import project_path
import argparse
from datetime import datetime
import torch
import os
from dataloader import create_dataloader_frames_only
from inference import setup_model, do_detection, do_suppression, do_confidence_boost, format_predictions, do_tracking
from visualizer import generate_video_batches
import json
from tqdm import tqdm


def main(args, config=None, verbose=True):
    """
    Run detection and tracking over directories of pre-extracted frames and write
    MOT-format result files, one per sequence.
        - Reads frames from {args.frames}/{location}/{sequence}/
        - Reads per-location metadata from {args.metadata}/{location}.json
        - Writes MOT results to {args.output}/{location}/tracker/data/{sequence}.txt
    Args:
        args: parsed command-line arguments (frames, metadata, output, weights)
        config (dict, optional): overrides for detection/tracking parameters
        verbose (bool): print per-sequence progress

    TODO: Separate into subtasks in different queues; have a GPU-only queue.
    """
    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())

    # fill in defaults for any config values the caller did not set
    if config is None:
        config = {}
    config.setdefault('conf_threshold', 0.3)
    config.setdefault('nms_iou', 0.3)
    config.setdefault('min_length', 0.3)
    config.setdefault('max_age', 20)
    config.setdefault('iou_threshold', 0.01)
    config.setdefault('min_hits', 11)

    print(config)

    dirname = args.frames

    # Locations are subdirectories of the frames root; each location has its own
    # metadata file and output directory.
    locations = ["kenai-val"]
    for loc in locations:

        in_loc_dir = os.path.join(dirname, loc)
        out_dir = os.path.join(args.output, loc, "tracker", "data")
        metadata_path = os.path.join(args.metadata, loc + ".json")
        os.makedirs(out_dir, exist_ok=True)
        print(in_loc_dir)
        print(out_dir)
        print(metadata_path)

        # run detection + tracking
        model, device = setup_model(args.weights)

        seq_list = os.listdir(in_loc_dir)
        with tqdm(total=len(seq_list), desc="...", ncols=0) as pbar:
            for idx, seq in enumerate(seq_list, start=1):
                pbar.update(1)
                pbar.set_description("Processing " + seq)
                if verbose:
                    print(" ")
                    print("(" + str(idx) + "/" + str(len(seq_list)) + ") " + seq)
                    print(" ")
                in_seq_dir = os.path.join(in_loc_dir, seq)
                infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path, verbose)

def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path, verbose):

    # Look up the clip's physical extent (in meters) in the per-location metadata,
    # which is a JSON list of clip records with 'clip_name' and x/y meter bounds.
    image_meter_width = -1
    image_meter_height = -1
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    for clip in metadata:
        if clip['clip_name'] == seq_name:
            image_meter_width = clip['x_meter_stop'] - clip['x_meter_start']
            image_meter_height = clip['y_meter_stop'] - clip['y_meter_start']

    if image_meter_width == -1 or image_meter_height == -1:
        print("No metadata found for file " + seq_name)
        return

    # create dataloader
    dataloader = create_dataloader_frames_only(in_dir)

    try:
        inference, image_shapes, width, height = do_detection(dataloader, model, device, verbose=verbose)
    except Exception as e:
        print("Error in " + seq_name + ": " + str(e))
        with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
            f.write("ERROR: " + str(e))
        return


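    # First NMS pass, then boost raw detection confidences using the first-pass
    # outputs (do_confidence_boost), then re-run NMS so the boosted detections
    # can clear the confidence threshold.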
    outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)

    do_confidence_boost(inference, outputs, verbose=verbose)

    new_outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)

    all_preds, real_width, real_height = format_predictions(image_shapes, new_outputs, width, height)

    results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'], verbose=verbose)

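    # Convert the tracked, normalized bounding boxes into MOT-format rows:
    # frame, id, bb_left, bb_top, bb_width, bb_height, conf, x, y, z
    # (the last four fields are unused here and written as -1).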
    mot_rows = []
    for frame in results['frames']:
        for fish in frame['fish']:
            bbox = fish['bbox']
            left = bbox[0] * real_width
            top = bbox[1] * real_height
            w = (bbox[2] - bbox[0]) * real_width
            h = (bbox[3] - bbox[1]) * real_height

            row = [
                str(frame['frame_num'] + 1),  # MOT frame numbers are 1-based
                str(fish['fish_id'] + 1),     # MOT track ids are 1-based
                str(int(left)),
                str(int(top)),
                str(int(w)),
                str(int(h)),
                "-1", "-1", "-1", "-1",
            ]
            mot_rows.append(",".join(row))

    mot_text = "\n".join(mot_rows)

    with open(os.path.join(out_dir, seq_name + ".txt"), 'w') as f:
        f.write(mot_text)

    return

def argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
    parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: ../models/v5m_896_300best.pt")
    return parser
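
# Example invocation (paths below are illustrative placeholders):
#   python <this_script>.py --frames /path/to/frames --metadata /path/to/metadata \
#       --output /path/to/output --weights models/v5m_896_300best.pt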

if __name__ == "__main__":
    args = argument_parser().parse_args()
    main(args)