# Spaces:
# Runtime error
# Runtime error
import argparse
import json
import os
import traceback

import torch
from tqdm import tqdm

# Must stay before the backend.* imports below (presumably it puts the project
# root on sys.path — confirm); it is imported only for that side effect.
import project_subpath  # noqa: F401
from backend.InferenceConfig import InferenceConfig
from backend.dataloader import create_dataloader_frames_only
from backend.inference import do_full_tracking, setup_model, do_detection
def main(args, config=None, verbose=True):
    """
    Run detection + tracking on every sequence of one location and save MOT results.

    Reads sequences from ``<args.frames>/<args.location>/``. Looks up each
    sequence's metadata in ``<args.metadata>/<args.location>.json``. Writes one
    ``<seq>.txt`` per sequence into ``<args.output>/<args.location>/tracker/data/``.
    That directory is created if missing.

    Args:
        args (argparse.Namespace): parsed CLI arguments with attributes
            ``frames``, ``location``, ``metadata``, ``output`` and ``weights``
            (see ``argument_parser()``; weights defaults to
            ``models/v5m_896_300best.pt``).
        config (InferenceConfig, optional): inference settings forwarded to
            tracking. A fresh ``InferenceConfig()`` is built when omitted, so
            separate calls never share one default instance.
        verbose (bool): print a header before each sequence and forward
            verbosity to detection/tracking.
    """
    if config is None:
        config = InferenceConfig()

    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())
    print("Config:", config.to_dict())

    loc = args.location
    in_loc_dir = os.path.join(args.frames, loc)
    out_dir = os.path.join(args.output, loc, "tracker", "data")
    metadata_path = os.path.join(args.metadata, loc + ".json")
    os.makedirs(out_dir, exist_ok=True)

    print(in_loc_dir)
    print(out_dir)
    print(metadata_path)

    # run detection + tracking
    model, device = setup_model(args.weights)

    # os.listdir order is arbitrary; sort so runs are reproducible.
    seq_list = sorted(os.listdir(in_loc_dir))
    total = len(seq_list)
    with tqdm(total=total, desc="...", ncols=0) as pbar:
        for idx, seq in enumerate(seq_list, start=1):
            pbar.set_description("Processing " + seq)
            if verbose:
                print(" ")
                print("(" + str(idx) + "/" + str(total) + ") " + seq)
                print(" ")
            in_seq_dir = os.path.join(in_loc_dir, seq)
            infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path, verbose)
            # Advance only after the sequence is done so the bar reflects finished work.
            pbar.update(1)
def _lookup_meter_dims(metadata_path, seq_name):
    """Return ``(meter_width, meter_height)`` for ``seq_name`` from the metadata JSON, or None.

    The JSON is iterated as a sequence of clip records (presumably a JSON
    array). Each record is read through the keys ``clip_name``,
    ``x_meter_start``/``x_meter_stop`` and ``y_meter_start``/``y_meter_stop``.
    """
    with open(metadata_path, "r", encoding="utf-8") as f:
        clips = json.load(f)
    dims = None
    # No early break: if a clip name is duplicated, the last matching record wins.
    for clip in clips:
        if clip["clip_name"] == seq_name:
            dims = (
                clip["x_meter_stop"] - clip["x_meter_start"],
                clip["y_meter_stop"] - clip["y_meter_start"],
            )
    return dims


def _format_mot_rows(frames, real_width, real_height):
    """Convert tracker frames into MOTChallenge-style text rows.

    Each fish's ``bbox`` is read as (x1, y1, x2, y2), presumably normalized to
    the frame size. It is scaled to pixels with ``real_width``/``real_height``.
    Each row is ``frame,id,left,top,width,height,-1,-1,-1,-1``, with frame
    number and fish id shifted to be 1-based.
    """
    rows = []
    for frame in frames:
        for fish in frame["fish"]:
            bbox = fish["bbox"]
            left = bbox[0] * real_width
            top = bbox[1] * real_height
            # Keep "scaled x2 minus scaled x1" rather than "scaled (x2 - x1)":
            # the float results can differ, which would change int() truncation.
            width = bbox[2] * real_width - bbox[0] * real_width
            height = bbox[3] * real_height - bbox[1] * real_height
            fields = [
                frame["frame_num"] + 1,
                fish["fish_id"] + 1,
                int(left),
                int(top),
                int(width),
                int(height),
                -1, -1, -1, -1,  # unused trailing MOT columns
            ]
            rows.append(",".join(str(v) for v in fields))
    return rows


def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path, verbose):
    """
    Run detection + tracking on one frame sequence and write its MOT result file.

    Args:
        in_dir (str): directory holding the sequence's frames.
        out_dir (str): directory that receives ``<seq_name>.txt`` on success, or
            ``ERROR_<seq_name>.txt`` when detection fails.
        config: inference config forwarded to ``do_full_tracking``.
        seq_name (str): sequence name; matched against ``clip_name`` in the metadata.
        model, device: as returned by ``setup_model``.
        metadata_path (str): JSON file with per-clip meter extents.
        verbose (bool): forwarded to ``do_detection`` and ``do_full_tracking``.

    Returns:
        None. A sequence with no metadata entry is skipped with a message.
        Any ``Exception`` raised by ``do_detection`` is reported; then an
        ``ERROR_<seq_name>.txt`` marker is written and the function returns.
    """
    dims = _lookup_meter_dims(metadata_path, seq_name)
    # None (not a numeric sentinel) means "not found": any number, even -1,
    # could be a genuine extent.
    if dims is None:
        print("No metadata found for file " + seq_name)
        return
    image_meter_width, image_meter_height = dims

    dataloader = create_dataloader_frames_only(in_dir)
    try:
        inference, image_shapes, width, height = do_detection(
            dataloader, model, device, verbose=verbose
        )
    except Exception:
        # Catch Exception (not everything) so Ctrl-C / SystemExit still stop the
        # whole run. The traceback is printed so failures can be diagnosed.
        print("Error in " + seq_name)
        traceback.print_exc()
        with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), "w") as f:
            f.write("ERROR")
        return

    # NOTE(review): presumably image_shapes[0][0][0] is the (height, width, ...)
    # shape of the first frame — confirm against do_detection.
    real_width = image_shapes[0][0][0][1]
    real_height = image_shapes[0][0][0][0]

    results = do_full_tracking(
        inference, image_shapes, image_meter_width, image_meter_height,
        width, height, config=config, gp=None, verbose=verbose,
    )

    mot_text = "\n".join(_format_mot_rows(results["frames"], real_width, real_height))
    with open(os.path.join(out_dir, seq_name + ".txt"), "w") as f:
        f.write(mot_text)
def argument_parser():
    """
    Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: the parser. ``--frames``, ``--location``,
        ``--metadata`` and ``--output`` are required. ``--weights`` is optional
        and defaults to ``models/v5m_896_300best.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
    parser.add_argument("--location", required=True, help="Name of location dir. Required.")
    parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    # %(default)s is expanded by argparse, so the help text cannot drift from the real default.
    parser.add_argument(
        "--weights",
        default="models/v5m_896_300best.pt",
        help="Path to saved YOLOv5 weights. Default: %(default)s",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: build the CLI parser, parse sys.argv, then run the
    # whole location with the default inference config.
    parser = argument_parser()
    args = parser.parse_args()
    main(args)