Spaces:
Runtime error
Runtime error
import project_path | |
import torch | |
from tqdm import tqdm | |
from functools import partial | |
import numpy as np | |
import json | |
from unittest.mock import patch | |
# assumes yolov5 on sys.path | |
from lib.yolov5.models.experimental import attempt_load | |
from lib.yolov5.utils.torch_utils import select_device | |
from lib.yolov5.utils.general import non_max_suppression | |
from lib.yolov5.utils.general import clip_boxes, scale_boxes | |
from lib.fish_eye.tracker import Tracker | |
### Configuration options
# Default YOLOv5 checkpoint used by setup_model() and do_full_inference().
WEIGHTS = 'models/v5m_896_300best.pt'

# Batch size for detection; will need to configure this based on GPU hardware.
BATCH_SIZE = 32

# Post-processing thresholds.
conf_thres = 0.3  # detection confidence threshold passed to NMS
iou_thres = 0.3  # NMS IOU threshold
min_length = 0.3  # minimum fish length, in meters (passed to Tracker.finalize)
###
def norm(bbox, w, h):
    """
    Divide a bounding box's coordinates by the image dimensions.

    Args:
        bbox: mutable sequence whose first four entries are a box, either
            [x, y, w, h] or [x0, y0, x1, y1]. Any further entries are copied
            through unchanged.
        w: image width; divides entries 0 and 2.
        h: image height; divides entries 1 and 3.

    Returns:
        A copy of ``bbox`` (made with ``bbox.copy()``) with the first four
        entries normalized. The input itself is not modified.
    """
    scaled = bbox.copy()
    for idx, extent in enumerate((w, h, w, h)):
        scaled[idx] /= extent
    return scaled
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS):
    """
    Run the complete pipeline: load the model, detect, then track.

    Args:
        dataloader: batches consumed by ``do_detection``.
        image_meter_width: image width in meters, forwarded to ``do_tracking``.
        image_meter_height: image height in meters, forwarded to ``do_tracking``.
        gp: optional progress callback, forwarded to both stages.
        weights: path of the checkpoint loaded by ``setup_model``.

    Returns:
        Whatever ``do_tracking`` returns for the detections.
    """
    loaded_model, target_device = setup_model(weights)
    detections = do_detection(dataloader, loaded_model, target_device, gp=gp)
    return do_tracking(detections, image_meter_width, image_meter_height, gp=gp)
def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=BATCH_SIZE):
    """
    Load the YOLOv5 checkpoint onto the best available device.

    Args:
        weights_fp: path of the checkpoint passed to ``attempt_load``.
        imgsz: side length of the square dummy image used for the GPU warm-up.
        batch_size: forwarded to ``select_device``.

    Returns:
        (model, device): the model in eval mode, and the torch device it was
        loaded on. The model is converted to half precision on any non-CPU
        device.
    """
    if torch.cuda.is_available():
        device = select_device('0', batch_size=batch_size)
    else:
        print("CUDA not available. Using CPU inference.")
        device = select_device('cpu', batch_size=batch_size)

    model = attempt_load(weights_fp, device=device)
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()
    model.eval()

    # Run one dummy forward pass on the GPU so later batches do not pay the
    # warm-up cost. Gradients are never needed, so skip building a graph.
    if half:
        dummy = torch.zeros((1, 3, imgsz, imgsz), device=device, dtype=torch.half)
        with torch.no_grad():
            model(dummy)
    return model, device
def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
    """
    Run YOLOv5 inference and NMS over every batch of ``dataloader``.

    Args:
        dataloader: iterable of ``(img, _, shapes)`` batches. ``img`` is an
            image tensor batch that is scaled by 1/255 here. ``shapes[si]``
            is indexed as ``(original_hw, ratio_pad)`` for image ``si``.
            NOTE(review): this layout matches YOLOv5's val loader; confirm
            against callers.
        model: detection model, e.g. as returned by ``setup_model``.
        device: torch device the model lives on.
        gp: optional progress callback, called as ``gp(fraction, message)``.
        batch_size: nominal batch size; only used to size the progress bar.

    Returns:
        dict mapping ``(batch_index, index_in_batch)`` to either ``None`` (no
        detections) or an ``(N, 5)`` array of ``[x0, y0, x1, y1, conf]`` rows.
        Coordinates are divided by the original image width/height.
    """
    if gp:
        gp(0, "Detection...")

    # Keep predictions keyed by frame so they can be fed in order to the Tracker.
    # TODO: how to deal with large files? (all predictions are held in memory)
    all_preds = {}
    half = device.type != 'cpu'  # setup_model converts the model to fp16 off-CPU
    num_batches = len(dataloader)

    with tqdm(total=num_batches * batch_size, desc="Running detection", ncols=0) as pbar:
        for batch_i, (img, _, shapes) in enumerate(dataloader):
            if gp:
                gp(batch_i / num_batches, str(pbar))
            img = img.to(device, non_blocking=True)
            img = img.half() if half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            _, _, height, width = img.shape  # batch size, channels, height, width

            # Run model & NMS
            with torch.no_grad():
                inf_out, _ = model(img, augment=False)
                output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)

            for si, pred in enumerate(output):
                all_preds[(batch_i, si)] = _format_detections(pred, (height, width), shapes[si])
            pbar.update(batch_size)
    return all_preds


def _format_detections(pred, input_hw, shape_info):
    """
    Convert one image's NMS output into the tracker's input format.

    Args:
        pred: NMS output rows for one image; columns 0-3 are an xyxy box in
            network-input pixels and column 4 is the confidence.
        input_hw: (height, width) of the network input image.
        shape_info: ``(original_hw, ratio_pad)`` entry from the dataloader.

    Returns:
        ``None`` if ``pred`` is empty. Otherwise an ``(N, 5)`` numpy array of
        ``[x0, y0, x1, y1, conf]``, with coordinates rescaled to the original
        image and then divided by its width/height.
    """
    if pred.shape[0] == 0:
        return None
    orig_hw, ratio_pad = shape_info[0], shape_info[1]

    # Clip boxes to the network-input bounds, then map them back to the
    # original image (scale_boxes works in place on ``xyxy``).
    clip_boxes(pred, input_hw)
    xyxy = pred[:, :4].clone()
    confs = pred[:, 4].tolist()
    scale_boxes(input_hw, xyxy, orig_hw, ratio_pad)

    # Tracker input: normalized xyxy with the NMS confidence appended as a
    # fifth column. NOTE(review): whether Tracker uses the confidence is not
    # visible here.
    orig_h, orig_w = orig_hw[0], orig_hw[1]
    rows = [[*norm(bb, orig_w, orig_h), conf] for bb, conf in zip(xyxy.tolist(), confs)]
    return np.array(rows)
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
    """
    Feed per-frame detections, in sorted key order, through a Tracker.

    Args:
        all_preds: dict mapping sortable frame keys to detection arrays, or
            ``None`` for frames without detections. This is the output of
            ``do_detection``.
        image_meter_width: image width in meters, stored in the clip info.
        image_meter_height: image height in meters, stored in the clip info.
        gp: optional progress callback, called as ``gp(fraction, message)``.

    Returns:
        The result of ``Tracker.finalize(min_length=min_length)``.
    """
    if gp:
        gp(0, "Tracking...")

    frame_count = len(all_preds)
    # NOTE(review): min_hits is given both inside ``args`` (0) and as a keyword
    # (11); which one Tracker honors is not visible here. Confirm.
    tracker = Tracker(
        {
            'start_frame': 0,
            'end_frame': frame_count,
            'image_meter_width': image_meter_width,
            'image_meter_height': image_meter_height
        },
        args={'max_age': 9, 'min_hits': 0, 'iou_threshold': 0.01},
        min_hits=11,
    )

    with tqdm(total=frame_count, desc="Running tracking", ncols=0) as pbar:
        for idx, frame_key in enumerate(sorted(all_preds)):
            if gp:
                gp(idx / frame_count, str(pbar))
            frame_boxes = all_preds[frame_key]
            # Frames with no detections still advance the tracker.
            if frame_boxes is None:
                tracker.update()
            else:
                tracker.update(frame_boxes)
            pbar.update(1)

    return tracker.finalize(min_length=min_length)
def json_dump_round_float(some_object, out_path, num_digits=4):
    """Write a json file to disk with a specified level of precision.

    Temporarily patches ``json.encoder._make_iterencode`` so that the encoder
    used by ``json.dump`` formats every finite float with exactly
    ``num_digits`` decimals. Because ``indent`` is set, that encoder is the
    pure-Python one.

    Non-finite floats (NaN / +-inf) are handed to the encoder's own float
    formatter, so they still serialize as ``NaN`` / ``Infinity`` /
    ``-Infinity`` instead of the invalid ``nan`` / ``inf``.

    The patch is process-global while this call runs, so concurrent JSON
    encoding in other threads would also see it.

    See: https://gist.github.com/Sukonnik-Illia/ed9b2bec1821cad437d1b8adb17406a3

    Args:
        some_object: JSON-serializable object to write.
        out_path: destination file path (overwritten).
        num_digits: number of digits after the decimal point.

    Returns:
        None
    """
    import math

    fmt_str = '{:.' + str(num_digits) + 'f}'
    # saving original method
    original_make_iterencode = json.encoder._make_iterencode

    def inner(*args, **kwargs):
        args = list(args)
        # fifth positional argument is the float formatter, which we replace
        default_floatstr = args[4]

        def floatstr(o):
            if math.isfinite(o):
                return fmt_str.format(o)
            return default_floatstr(o)

        args[4] = floatstr
        return original_make_iterencode(*args, **kwargs)

    # The ``with`` guarantees the file is flushed and closed even on error.
    with patch('json.encoder._make_iterencode', new=inner), open(out_path, 'w') as f:
        json.dump(some_object, f, indent=2)