File size: 5,821 Bytes
809371f
25ba50d
809371f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73afcbb
 
809371f
 
 
 
 
711b619
 
809371f
 
 
 
 
 
 
 
 
 
f639061
809371f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91c1277
809371f
409a900
91c1277
 
 
 
 
 
b9c5fdf
 
25bfc19
91c1277
809371f
91c1277
4744754
809371f
 
 
 
 
 
 
 
 
9bcdb1d
 
 
38ce66e
ef7e2aa
809371f
 
 
f639061
711b619
809371f
 
 
1548828
809371f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1953d07
f639061
809371f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import project_path
from lib.yolov5.utils.torch_utils import select_device
from lib.yolov5.utils.general import clip_boxes, scale_boxes
import argparse
from datetime import datetime
import torch
import os
from dataloader import create_dataloader_frames_only
from inference import setup_model, do_detection, do_suppression, do_confidence_boost, format_predictions, do_tracking
from visualizer import generate_video_batches
import json
from tqdm import tqdm
import numpy as np


def main(args, config=None, verbose=True):
    """
    Run the tracking pipeline over every configured location.

    For each location under ``args.detections``, creates the output
    directory ``args.output/<loc>/<args.tracker>/data`` and runs
    :func:`track_location` against the per-location metadata JSON at
    ``args.metadata/<loc>.json``.

    Args:
        args: parsed CLI namespace with ``detections``, ``output``,
            ``metadata`` and ``tracker`` attributes (see argument_parser).
        config (dict | None): tracking hyperparameters; missing keys are
            filled with defaults in place, so a caller-supplied dict is
            mutated (preserves the original behavior).
        verbose (bool): forwarded to the tracking helpers.

    TODO: Separate into subtasks in different queues; have a GPU-only queue.
    """

    # BUGFIX: the original used a mutable default argument (config={});
    # the shared dict would accumulate defaults across calls.
    if config is None:
        config = {}

    # Fill in defaults without overwriting caller-supplied values.
    config.setdefault('conf_threshold', 0.3)   # was tuned down from 0.001
    config.setdefault('nms_iou', 0.3)          # was tuned down from 0.6
    config.setdefault('min_length', 0.3)
    config.setdefault('max_age', 20)
    config.setdefault('iou_threshold', 0.01)
    config.setdefault('min_hits', 11)
    config.setdefault('use_associative', False)
    config.setdefault('boost_power', 1)
    config.setdefault('boost_decay', 1)

    print(config)

    locations = [
        "kenai-val"
    ]
    for loc in locations:

        in_loc_dir = os.path.join(args.detections, loc)
        out_loc_dir = os.path.join(args.output, loc, args.tracker, "data")
        os.makedirs(out_loc_dir, exist_ok=True)
        metadata_path = os.path.join(args.metadata, loc + ".json")
        print(in_loc_dir)
        print(out_loc_dir)
        print(metadata_path)

        track_location(in_loc_dir, out_loc_dir, metadata_path, config, verbose)


                
def track_location(in_loc_dir, out_loc_dir, metadata_path, config, verbose):
    """Run tracking for every sequence directory found under *in_loc_dir*.

    Hidden entries (names starting with '.') are skipped; each remaining
    sequence is handed to :func:`track` with the same config/metadata.
    """
    sequence_names = os.listdir(in_loc_dir)

    with tqdm(total=len(sequence_names), desc="...", ncols=0) as progress:
        for name in sequence_names:
            progress.update(1)

            # Ignore hidden files/dirs such as .DS_Store.
            if name.startswith("."):
                continue

            progress.set_description("Processing " + name)
            track(in_loc_dir, out_loc_dir, metadata_path, name, config, verbose)

def track(in_loc_dir, out_loc_dir, metadata_path, seq, config, verbose):
    """Track one sequence's detections and write MOT-format rows to disk.

    Reads ``pred.json`` and ``inference.pt`` from the sequence directory,
    runs NMS (optionally with an associative confidence boost followed by a
    second NMS pass), formats and tracks the predictions, then writes one
    comma-separated MOT row per fish per frame to ``<out_loc_dir>/<seq>.txt``.
    """
    json_path = os.path.join(in_loc_dir, seq, 'pred.json')
    inference_path = os.path.join(in_loc_dir, seq, 'inference.pt')
    out_path = os.path.join(out_loc_dir, seq + ".txt")

    # Load raw model outputs onto GPU if one is available.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    inference = torch.load(inference_path, map_location=device)

    # Detection metadata produced alongside the inference tensors.
    with open(json_path, 'r') as f:
        detection = json.load(f)
    image_shapes = detection['image_shapes']
    width = detection['width']
    height = detection['height']

    # Look up the sequence's physical extent; -1 sentinels remain if the
    # clip is absent from the metadata file.
    image_meter_width = -1
    image_meter_height = -1
    with open(metadata_path, 'r') as f:
        for clip in json.loads(f.read()):
            if clip['clip_name'] == seq:
                image_meter_width = clip['x_meter_stop'] - clip['x_meter_start']
                image_meter_height = clip['y_meter_start'] - clip['y_meter_stop']

    outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)

    if config['use_associative']:
        # Boost confidences using associations from the first NMS pass,
        # then suppress again on the boosted scores.
        do_confidence_boost(inference, outputs, conf_power=config['boost_power'], conf_decay=config['boost_decay'], verbose=verbose)
        outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)

    all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, verbose=verbose)

    results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'], verbose=verbose)

    # Convert normalized boxes to MOT rows:
    # frame,id,bb_left,bb_top,bb_width,bb_height,-1,-1,-1,-1
    # (frame and fish ids are shifted to be 1-based per MOT convention).
    mot_rows = []
    for frame in results['frames']:
        frame_field = str(frame['frame_num'] + 1)
        for fish in frame['fish']:
            x1, y1, x2, y2 = fish['bbox']
            left = int(x1 * real_width)
            top = int(y1 * real_height)
            box_w = int(x2 * real_width - x1 * real_width)
            box_h = int(y2 * real_height - y1 * real_height)
            fields = [
                frame_field,
                str(fish['fish_id'] + 1),
                str(left),
                str(top),
                str(box_w),
                str(box_h),
                "-1",
                "-1",
                "-1",
                "-1",
            ]
            mot_rows.append(",".join(fields))

    with open(out_path, 'w') as f:
        f.write("\n".join(mot_rows))

def argument_parser():
    """Build the CLI parser for the tracking script.

    Returns:
        argparse.ArgumentParser: parser with --detections, --output,
        --metadata (all required) and optional --tracker.
    """
    parser = argparse.ArgumentParser()
    # BUGFIX: help texts were copy-pasted — --metadata and --tracker both
    # claimed to be "Path to output directory. Required." and --tracker
    # claimed to be required despite having a default.
    parser.add_argument("--detections", required=True, help="Path to per-location detections directory. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    parser.add_argument("--metadata", required=True, help="Path to per-location metadata JSON directory. Required.")
    parser.add_argument("--tracker", default='tracker', help="Tracker name used in the output path. Default: 'tracker'.")
    return parser

if __name__ == "__main__":
    # CLI entry point: parse arguments and run the pipeline with the
    # default tracking config (main fills in all hyperparameter defaults).
    args = argument_parser().parse_args()
    main(args)