Spaces:
Runtime error
Runtime error
import argparse | |
import torch | |
import os | |
import json | |
from tqdm import tqdm | |
import project_subpath | |
from backend.dataloader import create_dataloader_frames_only | |
from backend.inference import setup_model, do_detection | |
def main(args, verbose=False):
    """
    Construct and save raw YOLOv5 detections for every sequence of one location.

    Frames are read from ``<args.frames>/<args.location>/<sequence>/`` and the
    results are written to ``<args.output>/<args.location>/<sequence>/``.

    Args:
        args (argparse.Namespace): parsed command-line arguments providing
            ``frames`` (root of the frame directory tree), ``location`` (name of
            the location sub-directory), ``output`` (root directory where
            detections are stored) and ``weights`` (path to the model weights
            passed to ``setup_model``).
        verbose (bool): forwarded to ``detect_location``.

    Raises:
        FileNotFoundError: if ``<args.frames>/<args.location>`` is not a directory.
            This is checked before the model is loaded so a bad path fails fast.
    """
    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())

    in_loc_dir = os.path.join(args.frames, args.location)
    out_loc_dir = os.path.join(args.output, args.location)

    # Validate the input location before paying for model setup; otherwise a
    # typo in --frames/--location is only reported after the weights load.
    if not os.path.isdir(in_loc_dir):
        raise FileNotFoundError(f"Location directory not found: {in_loc_dir}")

    model, device = setup_model(args.weights)

    print(in_loc_dir)
    print(out_loc_dir)
    detect_location(in_loc_dir, out_loc_dir, model, device, verbose)
def detect_location(in_loc_dir, out_loc_dir, model, device, verbose):
    """
    Run detection on every sequence directory inside one location directory.

    Each visible sub-directory ``<in_loc_dir>/<seq>`` is handed to ``detect``.
    Its results go to ``<out_loc_dir>/<seq>``, which is created if missing.
    Hidden entries (names starting with ".") and plain files are skipped.

    Args:
        in_loc_dir (str): directory whose sub-directories are frame sequences.
        out_loc_dir (str): directory under which per-sequence outputs are written.
        model: model object from ``setup_model``; passed through to ``detect``.
        device: device from ``setup_model``; passed through to ``detect``.
        verbose (bool): forwarded to ``detect``.
    """
    # os.listdir returns entries in arbitrary order; sort so processing order
    # (and the progress bar) is reproducible across runs and machines.
    seq_list = sorted(os.listdir(in_loc_dir))
    with tqdm(total=len(seq_list), desc="...", ncols=0) as pbar:
        for seq in seq_list:
            # Advance before any skip so skipped entries still count toward total.
            pbar.update(1)
            in_seq_dir = os.path.join(in_loc_dir, seq)
            # Only directories hold frame sequences; skip hidden entries
            # (e.g. ".DS_Store") and stray files instead of creating outputs for them.
            if seq.startswith(".") or not os.path.isdir(in_seq_dir):
                continue
            pbar.set_description("Processing " + seq)
            out_seq_dir = os.path.join(out_loc_dir, seq)
            os.makedirs(out_seq_dir, exist_ok=True)
            detect(in_seq_dir, out_seq_dir, model, device, verbose)
def detect(in_seq_dir, out_seq_dir, model, device, verbose):
    """
    Detect objects in one frame sequence and persist the results.

    Two files are written into ``out_seq_dir``:
      * ``pred.json``    -- ``image_shapes``, ``width`` and ``height`` from ``do_detection``
      * ``inference.pt`` -- the raw ``inference`` object, saved with ``torch.save``

    Args:
        in_seq_dir (str): directory containing the sequence's frames.
        out_seq_dir (str): existing directory that receives the output files.
        model: model object passed through to ``do_detection``.
        device: device passed through to ``do_detection``.
        verbose (bool): forwarded to ``do_detection``.
    """
    loader = create_dataloader_frames_only(in_seq_dir)
    results = do_detection(loader, model, device, verbose=verbose)
    inference, image_shapes, width, height = results

    # Key order matches the on-disk layout: image_shapes, width, height.
    metadata = dict(image_shapes=image_shapes, width=width, height=height)
    json_path = os.path.join(out_seq_dir, 'pred.json')
    with open(json_path, 'w') as handle:
        json.dump(metadata, handle)

    tensor_path = os.path.join(out_seq_dir, 'inference.pt')
    torch.save(inference, tensor_path)
def argument_parser():
    """
    Build the command-line parser for this script.

    Every option has a default, so none is actually required. Each help string
    uses argparse's ``%(default)s`` so ``--help`` always shows the real default
    instead of a hand-copied (and previously wrong) value.

    Returns:
        argparse.ArgumentParser: parser exposing ``--frames``, ``--location``,
        ``--output``, ``--weights`` and ``--verbose``.
    """
    parser = argparse.ArgumentParser(
        description="Construct and save raw YOLOv5 detections for every sequence of one location."
    )
    parser.add_argument("--frames", default="../frames/images",
                        help="Path to frame directory. Default: %(default)s")
    parser.add_argument("--location", default="kenai-val",
                        help="Name of location dir. Default: %(default)s")
    parser.add_argument("--output", default="../frames/detections/detection_storage/",
                        help="Path to output directory. Default: %(default)s")
    parser.add_argument("--weights", default='models/v5m_896_300best.pt',
                        help="Path to saved YOLOv5 weights. Default: %(default)s")
    # main() accepts verbose=...; expose it so it is reachable from the CLI.
    parser.add_argument("--verbose", action="store_true",
                        help="Pass verbose=True to the detection routine.")
    return parser
if __name__ == "__main__":
    args = argument_parser().parse_args()
    # getattr keeps this working even if the parser defines no --verbose flag.
    main(args, verbose=getattr(args, "verbose", False))