fisheye-experimental / inference.py
oskarastrom's picture
First Commit
7a4b92f
raw
history blame
6.07 kB
import project_path
import torch
from tqdm import tqdm
from functools import partial
import numpy as np
import json
from unittest.mock import patch
# assumes yolov5 on sys.path
from lib.yolov5.models.experimental import attempt_load
from lib.yolov5.utils.torch_utils import select_device
from lib.yolov5.utils.general import non_max_suppression
from lib.yolov5.utils.general import clip_boxes, scale_boxes
from lib.fish_eye.tracker import Tracker
### Configuration options
# Default checkpoint path; used as the default for do_full_inference(weights=...)
# and setup_model(weights_fp=...), which hands it to attempt_load.
WEIGHTS = 'models/v5m_896_300best.pt'
# will need to configure these based on GPU hardware
# Default for do_detection(batch_size=...), where it sizes the progress bar.
BATCH_SIZE = 32
conf_thres = 0.3 # detection confidence threshold, passed to non_max_suppression
iou_thres = 0.3 # NMS IOU threshold, passed to non_max_suppression
min_length = 0.3 # minimum fish length, in meters; passed to tracker.finalize
###
def norm(bbox, w, h):
    """
    Scale a bounding box by the image dimensions.

    The input is copied first, so the caller's box is left untouched.

    Args:
        bbox: list of length 4. Can be [x,y,w,h] or [x0,y0,x1,y1]
        w: image width
        h: image height

    Returns:
        A copy of bbox whose entries 0 and 2 are divided by w and whose
        entries 1 and 3 are divided by h.
    """
    scaled = bbox.copy()
    # Horizontal quantities sit at even indices, vertical ones at odd indices.
    for idx, extent in enumerate((w, h, w, h)):
        scaled[idx] /= extent
    return scaled
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS):
    """
    Run the complete pipeline: load the model, detect, then track.

    Args:
        dataloader: batches to run detection on; handed to do_detection
        image_meter_width: forwarded unchanged to do_tracking
        image_meter_height: forwarded unchanged to do_tracking
        gp: optional progress callback, forwarded to both stages
        weights: checkpoint path given to setup_model

    Returns:
        Whatever do_tracking returns for the collected detections.
    """
    net, dev = setup_model(weights)
    detections = do_detection(dataloader, net, dev, gp=gp)
    return do_tracking(detections, image_meter_width, image_meter_height, gp=gp)
def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=32):
    """
    Load the detection model onto the best available device and warm it up.

    Args:
        weights_fp: checkpoint path passed to attempt_load
        imgsz: side length of the square dummy image used for the warm-up pass
        batch_size: forwarded to select_device

    Returns:
        (model, device). The model is in eval mode, and converted to half
        precision when the selected device is not the CPU.
    """
    if torch.cuda.is_available():
        device = select_device('0', batch_size=batch_size)
    else:
        print("CUDA not available. Using CPU inference.")
        device = select_device('cpu', batch_size=batch_size)

    # Setup model for inference
    model = attempt_load(weights_fp, device=device)
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()
    model.eval()

    # Warm-up: run one dummy forward pass (skipped on CPU, as before).
    # no_grad keeps the throwaway pass from building an autograd graph.
    if half:
        dummy = torch.zeros((1, 3, imgsz, imgsz), device=device).half()
        with torch.no_grad():
            model(dummy)

    return model, device
def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
    """
    Run the detector over every batch and collect per-frame boxes.

    Args:
        dataloader: sized iterable yielding (img, _, shapes) batches.
            shapes[si][0] is used as the original (height, width) of frame si
            and shapes[si][1] is passed to scale_boxes as its last argument
            (assumed to be the YOLOv5 ratio/pad info -- confirm against the
            dataloader).
        model: detection model, called as model(img, augment=False)
        device: torch device the batch is moved to; any non-CPU device
            switches the input to half precision
        gp: optional progress callback, called as gp(fraction_done, message)
        batch_size: only used to size and advance the progress bar

    Returns:
        dict keyed by (batch index, index within batch). Each value is an
        ndarray with one row [x0, y0, x1, y1, confidence] per detection, with
        coordinates divided by the original frame width/height, or None when
        the frame has no detections.
    """
    if gp:
        gp(0, "Detection...")

    # keep predictions to feed them ordered into the Tracker
    # TODO: how to deal with large files?
    all_preds = {}
    num_batches = len(dataloader)

    # Run detection
    with tqdm(total=num_batches * batch_size, desc="Running detection", ncols=0) as pbar:
        for batch_i, (img, _, shapes) in enumerate(dataloader):
            if gp:
                gp(batch_i / num_batches, str(pbar))

            img = img.to(device, non_blocking=True)
            img = img.half() if device.type != 'cpu' else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            _, _, height, width = img.shape  # batch size, channels, height, width

            # Run model & NMS
            with torch.no_grad():
                inf_out, _ = model(img, augment=False)
                output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)

            # Format results
            for si, pred in enumerate(output):
                # Clip boxes to image bounds and resize to input shape
                clip_boxes(pred, (height, width))
                box = pred[:, :4].clone()  # xyxy
                confs = pred[:, 4].clone().tolist()
                scale_boxes(img[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape

                # get boxes into tracker input format:
                # normalized xyxy followed by the detection confidence
                boxes = None
                if box.shape[0]:
                    orig_h, orig_w = shapes[si][0][0], shapes[si][0][1]
                    boxes = np.stack([
                        [*norm(bb, orig_w, orig_h), conf]
                        for bb, conf in zip(box.tolist(), confs)
                    ])

                # Frames without detections are stored as None so that
                # do_tracking still advances the tracker for them.
                all_preds[(batch_i, si)] = boxes

            pbar.update(batch_size)

    return all_preds
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
    """
    Feed the per-frame detections to the Tracker in key order.

    Args:
        all_preds: dict of frame key -> boxes array, or None for a frame
            without detections (as produced by do_detection)
        image_meter_width: stored in the clip info handed to the Tracker
        image_meter_height: stored in the clip info handed to the Tracker
        gp: optional progress callback, called as gp(fraction_done, message)

    Returns:
        The result of tracker.finalize(min_length=min_length).
    """
    if gp:
        gp(0, "Tracking...")

    # Initialize tracker
    num_frames = len(all_preds)
    clip_info = {
        'start_frame': 0,
        'end_frame': num_frames,
        'image_meter_width': image_meter_width,
        'image_meter_height': image_meter_height
    }
    tracker_args = {'max_age': 9, 'min_hits': 0, 'iou_threshold': 0.01}
    tracker = Tracker(clip_info, args=tracker_args, min_hits=11)

    # Run tracking
    with tqdm(total=num_frames, desc="Running tracking", ncols=0) as pbar:
        for step, frame_key in enumerate(sorted(all_preds)):
            if gp:
                gp(step / num_frames, str(pbar))

            frame_boxes = all_preds[frame_key]
            if frame_boxes is None:
                # No detections: still advance the tracker by one frame.
                tracker.update()
            else:
                tracker.update(frame_boxes)
            pbar.update(1)

    return tracker.finalize(min_length=min_length)
@patch('json.encoder.c_make_encoder', None)
def json_dump_round_float(some_object, out_path, num_digits=4):
    """Write a json file to disk with a specified level of precision.

    See: https://gist.github.com/Sukonnik-Illia/ed9b2bec1821cad437d1b8adb17406a3

    Finite floats are written with exactly num_digits digits after the
    decimal point. Non-finite floats (NaN, +/-Infinity) are left to the
    standard encoder so the output stays readable by json.load.

    Args:
        some_object: JSON-serializable object to write
        out_path: destination file path; opened in text write mode
        num_digits: number of digits kept after the decimal point

    Returns:
        None (the return value of json.dump).
    """
    import math  # local import: only this function needs it

    fmt_str = '{:.' + str(num_digits) + 'f}'
    # saving original method
    make_iterencode = json.encoder._make_iterencode

    def patched_make_iterencode(*args, **kwargs):
        args = list(args)
        # fifth positional argument is the float formatter, which we wrap
        default_floatstr = args[4]

        def rounded_floatstr(o):
            if math.isfinite(o):
                return fmt_str.format(o)
            # NaN / Infinity: keep the stock representation and allow_nan check
            return default_floatstr(o)

        args[4] = rounded_floatstr
        return make_iterencode(*args, **kwargs)

    with patch('json.encoder._make_iterencode', wraps=patched_make_iterencode):
        # Context manager guarantees the file is flushed and closed.
        with open(out_path, 'w') as out_file:
            return json.dump(some_object, out_file, indent=2)