Spaces:
Runtime error
Runtime error
import project_path | |
from lib.yolov5.utils.general import clip_boxes, scale_boxes | |
import argparse | |
from datetime import datetime | |
import torch | |
import os | |
from dataloader import create_dataloader_frames_only | |
from inference import setup_model, do_detection, do_suppression, do_confidence_boost, format_predictions, do_tracking | |
from visualizer import generate_video_batches | |
import json | |
from tqdm import tqdm | |
import numpy as np | |
# Default detection/tracking settings; callers override individual keys via `config`.
DEFAULT_CONFIG = {
    'conf_threshold': 0.001,
    'nms_iou': 0.6,
    'min_length': 0.3,
    'max_age': 20,
    'iou_threshold': 0.01,
    'min_hits': 11,
}

# Location sub-directories (under --frames / --output) processed when none are given.
DEFAULT_LOCATIONS = ("kenai-val",)


def with_config_defaults(config=None):
    """Return a NEW config dict with every missing key filled from DEFAULT_CONFIG.

    The caller's dict is never modified. Keys already present in ``config``
    keep their values (and their order); missing defaults are appended in
    DEFAULT_CONFIG order.
    """
    merged = dict(config) if config else {}
    for key, value in DEFAULT_CONFIG.items():
        merged.setdefault(key, value)
    return merged


def main(args, config=None, verbose=True, locations=None):
    """Run detection over every sequence of each location and save predictions.

    For each location name ``loc``, the frames under ``args.frames/loc`` are
    processed by ``detect_location``, which writes one ``pred.json`` per
    sequence directory under ``args.output/loc``.

    Args:
        args: parsed command-line namespace; ``args.weights``, ``args.frames``
            and ``args.output`` are read.
        config: optional dict of setting overrides. Missing keys are filled
            from ``DEFAULT_CONFIG``; the dict passed in is not modified.
            ``conf_threshold`` and ``nms_iou`` are read by ``detect``.
        verbose: forwarded to the detection helpers.
        locations: iterable of location sub-directory names to process;
            defaults to ``DEFAULT_LOCATIONS``.

    TODO: Separate into subtasks in different queues; have a GPU-only queue.
    """
    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())

    # Work on a copy so neither a shared default nor the caller's dict is mutated.
    config = with_config_defaults(config)
    print(config)

    model, device = setup_model(args.weights)

    if locations is None:
        locations = DEFAULT_LOCATIONS
    for loc in locations:
        in_loc_dir = os.path.join(args.frames, loc)
        out_loc_dir = os.path.join(args.output, loc)
        print(in_loc_dir)
        print(out_loc_dir)
        detect_location(in_loc_dir, out_loc_dir, config, model, device, verbose)
def detect_location(in_loc_dir, out_loc_dir, config, model, device, verbose):
    """Run ``detect_seq`` for every non-hidden entry of one location directory.

    Each entry ``name`` of ``in_loc_dir`` (in ``os.listdir`` order) gets a
    matching output directory ``out_loc_dir/name``, created if missing.
    Entries whose name starts with "." are skipped but still advance the
    progress bar, so the bar always reaches its total.

    NOTE(review): non-directory entries are not filtered out; they are
    passed to ``detect_seq`` as-is.
    """
    entries = os.listdir(in_loc_dir)
    with tqdm(total=len(entries), desc="...", ncols=0) as progress:
        for name in entries:
            progress.update(1)
            if name.startswith("."):
                continue
            progress.set_description("Processing " + name)
            source_dir = os.path.join(in_loc_dir, name)
            target_dir = os.path.join(out_loc_dir, name)
            os.makedirs(target_dir, exist_ok=True)
            detect_seq(source_dir, target_dir, config, model, device, verbose)
def detect_seq(in_seq_dir, out_seq_dir, config, model, device, verbose):
    """Run detection on one sequence directory and write ``pred.json``.

    Every box returned by ``detect`` has the form
    ``[x1, y1, x2, y2, score, file_name]``. It is converted into a
    COCO-style annotation dict with an ``[x, y, width, height]`` bbox,
    ``category_id`` 0 and the file name as ``image_id``. ``None`` frames are
    skipped. The list is serialized to ``out_seq_dir/pred.json``.
    """
    annotations = [
        {
            'image_id': box[5],
            'category_id': 0,
            # Corner form (x1, y1, x2, y2) -> origin + size (x, y, w, h).
            'bbox': [box[0], box[1], box[2] - box[0], box[3] - box[1]],
            'score': box[4],
        }
        for frame in detect(in_seq_dir, config, model, device, verbose)
        if frame is not None
        for box in frame
    ]
    serialized = json.dumps(annotations)
    out_path = os.path.join(out_seq_dir, 'pred.json')
    with open(out_path, 'w') as out_file:
        out_file.write(serialized)
def iter_frame_predictions(outputs, image_shapes, file_names):
    """Yield ``(pred, image_shape, original_shape, file_name)`` for each image.

    Walks the per-batch outputs in order and pairs each image's predictions
    with its shape metadata and source file name. A running image counter is
    used, so any batch size works, including a short final batch.

    Args:
        outputs: iterable of batches; each batch is a sequence of per-image
            predictions.
        image_shapes: ``image_shapes[b][i]`` is an ``(image_shape,
            original_shape)`` pair for image ``i`` of batch ``b``.
        file_names: file names in the same order the images were loaded.
    """
    frame_idx = 0
    for batch_i, batch in enumerate(outputs):
        batch_shapes = image_shapes[batch_i]
        for si, pred in enumerate(batch):
            image_shape, original_shape = batch_shapes[si]
            yield pred, image_shape, original_shape, file_names[frame_idx]
            frame_idx += 1


def detect(in_dir, config, model, device, verbose):
    """Run the detector on every frame in ``in_dir`` and return per-frame boxes.

    Args:
        in_dir: frame directory, loaded with ``create_dataloader_frames_only``.
        config: must provide ``conf_threshold`` and ``nms_iou``, which are passed
            to ``do_suppression`` as ``conf_thres`` / ``iou_thres``.
        model, device: as returned by ``setup_model``.
        verbose: forwarded to ``do_detection`` and ``do_suppression``.

    Returns:
        A list with one entry per frame. Each entry is a list of
        ``[x1, y1, x2, y2, score, file_name]`` boxes. Boxes are first clipped to
        the ``(height, width)`` reported by ``do_detection`` and then rescaled
        with ``scale_boxes`` to the original shape.
    """
    dataloader = create_dataloader_frames_only(in_dir)
    inference, image_shapes, width, height = do_detection(dataloader, model, device, verbose=verbose)
    outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)

    frame_list = []
    # File names are matched by a running frame index rather than a hard-coded
    # batch size, so results stay aligned whatever batch size the loader uses.
    for pred, image_shape, original_shape, file_name in iter_frame_predictions(
        outputs, image_shapes, dataloader.files
    ):
        # Clip boxes (in place) to the detector input bounds.
        clip_boxes(pred, (height, width))
        boxes = pred[:, :4].clone()  # xyxy
        confs = pred[:, 4].tolist()  # tolist() already copies
        # NOTE(review): original_shape looks like (orig_shape, ratio_pad), matching
        # yolov5's scale_boxes(img1_shape, boxes, img0_shape, ratio_pad) — confirm.
        scale_boxes(image_shape, boxes, original_shape[0], original_shape[1])  # to original shape
        frame_list.append([[*bb, conf, file_name] for bb, conf in zip(boxes.tolist(), confs)])
    return frame_list
def argument_parser():
    """Build the command-line parser for this script.

    Options:
        --frames   (required) root directory with one sub-directory per location.
        --output   (required) root directory where per-location predictions go.
        --weights  path to the YOLOv5 weights file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    # Use %(default)s so the help text can never disagree with the real default
    # (it previously advertised '../models/...', which is not what is used).
    parser.add_argument(
        "--weights",
        default='models/v5m_896_300best.pt',
        help="Path to saved YOLOv5 weights. Default: %(default)s",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: parse CLI options and run the full detection pass.
    main(argument_parser().parse_args())