File size: 4,429 Bytes
128e4f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import torch
import os
import json
from tqdm import tqdm

import project_subpath

from backend.InferenceConfig import InferenceConfig
from backend.inference import do_full_tracking


def main(args, config=None, verbose=True):
    """
    Convert raw detections to tracks and save the tracking json result.

    Args:
        args: parsed CLI namespace (see argument_parser) with attributes:
            detections (str): path to raw detections directory. Required.
            output (str): where tracking result will be stored. Required.
            metadata (str): path to metadata directory. Required.
            location (str): name of the location subdirectory. Required.
            tracker (str): arbitrary name of tracker folder that you want to
                save trajectories to.
        config (InferenceConfig, optional): tracking configuration. A fresh
            default InferenceConfig is created per call when omitted.
        verbose (bool): forwarded to the tracker for progress output.
    """
    # BUGFIX: the original signature used `config=InferenceConfig()`, a
    # mutable default evaluated once at import time and shared by every
    # call. Use the None-sentinel idiom so each call gets its own config.
    if config is None:
        config = InferenceConfig()

    print("running detections_to_tracks.py with:", config.to_dict())

    loc = args.location

    in_loc_dir = os.path.join(args.detections, loc)
    out_loc_dir = os.path.join(args.output, loc, args.tracker, "data")
    os.makedirs(out_loc_dir, exist_ok=True)
    metadata_path = os.path.join(args.metadata, loc + ".json")
    print(in_loc_dir)
    print(out_loc_dir)
    print(metadata_path)

    track_location(in_loc_dir, out_loc_dir, metadata_path, config, verbose)


                
def track_location(in_loc_dir, out_loc_dir, metadata_path, config, verbose):
    """Run tracking for every sequence directory found under ``in_loc_dir``."""
    sequences = os.listdir(in_loc_dir)

    with tqdm(total=len(sequences), desc="...", ncols=0) as progress:
        for seq_name in sequences:
            progress.update(1)
            # Hidden entries (e.g. ".DS_Store") are not sequence dirs — skip,
            # but still count them toward the progress total as listed.
            if seq_name.startswith("."):
                continue
            progress.set_description("Processing " + seq_name)
            track(in_loc_dir, out_loc_dir, metadata_path, seq_name, config, verbose)

def track(in_loc_dir, out_loc_dir, metadata_path, seq, config, verbose):
    """
    Track one sequence and write the result as a MOT-format ``<seq>.txt``.

    Args:
        in_loc_dir (str): location directory holding ``<seq>/pred.json`` and
            ``<seq>/inference.pt``.
        out_loc_dir (str): directory that receives ``<seq>.txt``.
        metadata_path (str): path to the location's metadata JSON (a list of
            per-clip objects keyed by ``clip_name``).
        seq (str): sequence (clip) directory name.
        config: InferenceConfig forwarded to ``do_full_tracking``.
        verbose (bool): forwarded to ``do_full_tracking``.
    """
    json_path = os.path.join(in_loc_dir, seq, 'pred.json')
    inference_path = os.path.join(in_loc_dir, seq, 'inference.pt')
    out_path = os.path.join(out_loc_dir, seq + ".txt")

    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # inference files produced by this pipeline / trusted sources.
    inference = torch.load(inference_path, map_location=device)

    # read detection info (per-frame shapes plus nominal frame size)
    with open(json_path, 'r') as f:
        detection = json.load(f)
    image_shapes = detection['image_shapes']
    width = detection['width']
    height = detection['height']

    # read metadata; the -1 sentinels survive when the clip is not listed
    image_meter_width = -1
    image_meter_height = -1
    with open(metadata_path, 'r') as f:
        json_object = json.load(f)
    for sequence in json_object:
        if sequence['clip_name'] == seq:
            image_meter_width = sequence['x_meter_stop'] - sequence['x_meter_start']
            image_meter_height = sequence['y_meter_start'] - sequence['y_meter_stop']
            break  # BUGFIX: stop at the first match instead of scanning on

    # assume all images in the sequence have the same shape
    real_width = image_shapes[0][0][0][1]
    real_height = image_shapes[0][0][0][0]

    # perform tracking
    results = do_full_tracking(inference, image_shapes, image_meter_width,
                               image_meter_height, width, height,
                               config=config, gp=None, verbose=verbose)

    # write tracking result
    with open(out_path, 'w') as f:
        f.write(_results_to_mot_text(results, real_width, real_height))


def _results_to_mot_text(results, real_width, real_height):
    """Format tracking results as MOT-challenge CSV text.

    Each row is: frame (1-based), id (1-based), bb_left, bb_top, w, h,
    conf, x, y, z — the last four are unused here and written as -1.
    ``results['frames'][i]['fish'][j]['bbox']`` is assumed to be a
    normalized [x1, y1, x2, y2] box (fractions of the frame size).
    """
    mot_rows = []
    for frame in results['frames']:
        for fish in frame['fish']:
            bbox = fish['bbox']
            # BUGFIX(naming): the original called this 'right', but it is
            # the LEFT edge of the box in pixels (MOT bb_left column).
            left = bbox[0] * real_width
            top = bbox[1] * real_height
            w = bbox[2] * real_width - bbox[0] * real_width
            h = bbox[3] * real_height - bbox[1] * real_height
            row = [
                str(frame['frame_num'] + 1),
                str(fish['fish_id'] + 1),
                str(int(left)),
                str(int(top)),
                str(int(w)),
                str(int(h)),
                "-1",
                "-1",
                "-1",
                "-1",
            ]
            mot_rows.append(",".join(row))
    return "\n".join(mot_rows)

def argument_parser():
    """Build the CLI parser for detections_to_tracks.

    Returns:
        argparse.ArgumentParser with --detections, --location, --output,
        --metadata (all required) and optional --tracker (default 'tracker').
    """
    parser = argparse.ArgumentParser()
    # BUGFIX(help text): this is the raw detections directory, not a frame
    # directory (it previously said "Path to frame directory").
    parser.add_argument("--detections", required=True, help="Path to raw detections directory. Required.")
    parser.add_argument("--location", required=True, help="Name of location dir. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
    parser.add_argument("--tracker", default='tracker', help="Tracker name.")
    return parser

# Script entry point: parse the CLI arguments and run the pipeline with the
# default InferenceConfig and verbose output.
if __name__ == "__main__":
    args = argument_parser().parse_args()
    main(args)