# fisheye-experimental / scripts /full_detect_frames.py
# oskarastrom's picture
# Tracking script
# 809371f
# raw
# history blame
# 4.89 kB
import project_path
from lib.yolov5.utils.general import clip_boxes, scale_boxes
import argparse
from datetime import datetime
import torch
import os
from dataloader import create_dataloader_frames_only
from inference import setup_model, do_detection, do_suppression, do_confidence_boost, format_predictions, do_tracking
from visualizer import generate_video_batches
import json
from tqdm import tqdm
import numpy as np
def main(args, config=None, verbose=True):
    """
    Run detection over every sequence of each configured location.

    Fills in default values for any missing config entries, loads the model
    from ``args.weights`` and, for each name in the hard-coded location list,
    hands ``<args.frames>/<location>`` and ``<args.output>/<location>`` to
    ``detect_location``.

    Args:
        args: parsed command-line arguments; the ``frames``, ``output`` and
            ``weights`` attributes are read.
        config (dict, optional): detection settings. Missing keys are filled
            in place with the defaults listed below, so a dict passed by the
            caller is updated. When omitted, a fresh dict is created for this
            call.
        verbose (bool): forwarded to ``detect_location``.

    TODO: Separate into subtasks in different queues; have a GPU-only queue.
    """
    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())

    # A fresh dict per call: a `{}` default would be shared between calls and
    # silently keep the values written below.
    if config is None:
        config = {}

    # setup config
    # NOTE(review): only 'conf_threshold' and 'nms_iou' are read by the code
    # visible in this file; the remaining keys look like tracking settings.
    defaults = {
        'conf_threshold': 0.001,
        'nms_iou': 0.6,
        'min_length': 0.3,
        'max_age': 20,
        'iou_threshold': 0.01,
        'min_hits': 11,
    }
    for key, value in defaults.items():
        config.setdefault(key, value)
    print(config)

    model, device = setup_model(args.weights)

    locations = [
        "kenai-val"
    ]
    for loc in locations:
        in_loc_dir = os.path.join(args.frames, loc)
        out_loc_dir = os.path.join(args.output, loc)
        print(in_loc_dir)
        print(out_loc_dir)
        detect_location(in_loc_dir, out_loc_dir, config, model, device, verbose)
def detect_location(in_loc_dir, out_loc_dir, config, model, device, verbose):
    """
    Run ``detect_seq`` on every entry of a location directory.

    Entries whose name starts with "." are skipped. For each remaining
    entry the matching output directory ``<out_loc_dir>/<entry>`` is created
    if needed before ``detect_seq`` is called. The progress bar advances
    once per directory entry, skipped ones included.
    """
    entries = os.listdir(in_loc_dir)
    progress = tqdm(total=len(entries), desc="...", ncols=0)
    with progress:
        for name in entries:
            progress.update(1)
            if not name.startswith("."):
                progress.set_description("Processing " + name)
                source_dir = os.path.join(in_loc_dir, name)
                target_dir = os.path.join(out_loc_dir, name)
                os.makedirs(target_dir, exist_ok=True)
                detect_seq(source_dir, target_dir, config, model, device, verbose)
def detect_seq(in_seq_dir, out_seq_dir, config, model, device, verbose):
    """
    Detect objects in one sequence and write them to ``pred.json``.

    Every detection returned by ``detect`` is turned into a dict with the
    keys ``image_id`` (element 5 of the detection), ``category_id`` (always
    0), ``bbox`` (``[x1, y1, x2 - x1, y2 - y1]`` built from elements 0-3)
    and ``score`` (element 4). The resulting list is serialized as JSON to
    ``<out_seq_dir>/pred.json``.
    """
    frames = detect(in_seq_dir, config, model, device, verbose)

    # Flatten the per-frame lists, ignoring frames reported as None.
    records = [
        {
            'image_id': det[5],
            'category_id': 0,
            'bbox': [det[0], det[1], det[2] - det[0], det[3] - det[1]],
            'score': det[4]
        }
        for detections in frames
        if detections is not None
        for det in detections
    ]

    # Serialize before opening so a failure leaves no truncated file behind.
    payload = json.dumps(records)
    pred_path = os.path.join(out_seq_dir, 'pred.json')
    with open(pred_path, 'w') as out_file:
        out_file.write(payload)
def detect(in_dir, config, model, device, verbose):
    """
    Run detection and suppression on the frames found in ``in_dir``.

    Args:
        in_dir: directory handed to ``create_dataloader_frames_only``.
        config (dict): must provide ``'conf_threshold'`` and ``'nms_iou'``,
            which are passed to ``do_suppression``.
        model, device: forwarded to ``do_detection``.
        verbose: forwarded to ``do_detection`` and ``do_suppression``.

    Returns:
        list: one entry per processed image; each entry is a list of
        detections of the form ``[x1, y1, x2, y2, conf, file_name]`` with
        the box scaled to the original image.
    """
    # create dataloader
    dataloader = create_dataloader_frames_only(in_dir)

    inference, image_shapes, width, height = do_detection(dataloader, model, device, verbose=verbose)

    outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)

    file_names = dataloader.files

    frame_list = []
    # Running position in file_names. Counting images as they are visited
    # keeps detections paired with the right file for any batch size; the
    # previous `batch_i * 32 + si` lookup was only correct for batches of 32.
    # Assumes predictions come back in the same order as dataloader.files,
    # as the previous lookup did.
    image_index = 0
    for batch_i, batch in enumerate(outputs):

        batch_shapes = image_shapes[batch_i]

        # Format results
        for si, pred in enumerate(batch):
            (image_shape, original_shape) = batch_shapes[si]

            # Clip boxes to image bounds and resize to input shape
            clip_boxes(pred, (height, width))
            boxes = pred[:, :4].clone()  # xyxy
            confs = pred[:, 4].clone().tolist()
            scale_boxes(image_shape, boxes, original_shape[0], original_shape[1])  # to original shape

            file_name = file_names[image_index]
            image_index += 1

            frame = [[*bb, conf, file_name] for bb, conf in zip(boxes.tolist(), confs)]
            frame_list.append(frame)

    return frame_list
def argument_parser():
    """
    Build the command-line parser for this script.

    Options:
        --frames   (required) path to the frame directory.
        --output   (required) path to the output directory.
        --weights  path to saved YOLOv5 weights; defaults to
                   ``models/v5m_896_300best.pt``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    # %(default)s keeps the help text in sync with the real default; the old
    # hand-written text advertised a different path ("../models/...").
    parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: %(default)s")
    return parser
# Script entry point: parse the command line and run detection.
if __name__ == "__main__":
    main(argument_parser().parse_args())