File size: 6,065 Bytes
7a4b92f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import project_path

import torch
from tqdm import tqdm
from functools import partial
import numpy as np
import json
from unittest.mock import patch

# assumes yolov5 on sys.path
from lib.yolov5.models.experimental import attempt_load
from lib.yolov5.utils.torch_utils import select_device
from lib.yolov5.utils.general import non_max_suppression
from lib.yolov5.utils.general import clip_boxes, scale_boxes

from lib.fish_eye.tracker import Tracker

### Configuration options
# Default checkpoint path; used as the default weights argument of
# do_full_inference() and setup_model().
WEIGHTS = 'models/v5m_896_300best.pt'
# will need to configure these based on GPU hardware
# Default batch size assumed by do_detection() when scaling its progress bar.
BATCH_SIZE = 32

# conf_thres and iou_thres are passed to non_max_suppression() in do_detection();
# min_length is passed to Tracker.finalize() in do_tracking().
conf_thres = 0.3 # detection
iou_thres  = 0.3 # NMS IOU
min_length = 0.3 # minimum fish length, in meters
###

def norm(bbox, w, h):
    """
    Return a normalized copy of a bounding box; the input is left untouched.
    Elements 0 and 2 are divided by the image width, elements 1 and 3 by the
    image height. Any further elements are copied through unchanged.
    Args:
        bbox: list of length 4. Can be [x,y,w,h] or [x0,y0,x1,y1]
        w: image width
        h: image height
    """
    scaled = bbox.copy()
    # x-like entries sit at even positions, y-like entries at odd positions
    for position, divisor in enumerate((w, h, w, h)):
        scaled[position] /= divisor
    return scaled

def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS):
    """
    Run the whole pipeline: load the model, detect on every batch, then track.
    Args:
        dataloader: batches to run detection on (forwarded to do_detection)
        image_meter_width: forwarded to do_tracking
        image_meter_height: forwarded to do_tracking
        gp: optional progress callback, forwarded to both stages
        weights: path of the model weights to load
    Returns:
        whatever do_tracking returns for the collected detections
    """
    # Stage 1: model setup
    model, device = setup_model(weights)

    # Stage 2: per-frame detections
    detections = do_detection(dataloader, model, device, gp=gp)

    # Stage 3: link detections across frames
    return do_tracking(detections, image_meter_width, image_meter_height, gp=gp)
    

def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=32):
    """
    Load the detection model onto the best available device and warm it up.
    Args:
        weights_fp: path of the weights file handed to attempt_load
        imgsz: side length of the square dummy image used for the warm-up pass
        batch_size: forwarded to select_device
    Returns:
        (model, device): the model in eval mode (half precision when not on CPU)
        and the device object returned by select_device
    """
    if torch.cuda.is_available():
        device = select_device('0', batch_size=batch_size)
    else:
        print("CUDA not available. Using CPU inference.")
        device = select_device('cpu', batch_size=batch_size)

    # Setup model for inference
    model = attempt_load(weights_fp, device=device)
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()
    model.eval()

    # Warm-up: run one dummy forward pass on non-CPU devices. The dummy tensor
    # is only built when it is actually used, and the pass runs under no_grad
    # (as do_detection does) so no autograd graph is kept for a discarded result.
    if half:
        warmup_img = torch.zeros((1, 3, imgsz, imgsz), device=device).half()
        with torch.no_grad():
            model(warmup_img)

    return model, device
                       
def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
    """
    Run the detector over every batch of a dataloader and collect per-frame boxes.
    Args:
        dataloader: sized iterable yielding (img, _, shapes) batches. shapes[i][0]
            is indexed as (height, width) of the original frame; shapes[i][1] is
            forwarded to scale_boxes (presumably ratio/pad info -- confirm
            against the dataloader)
        model: detection model, called as model(img, augment=False)
        device: device the batches are moved to; non-CPU devices use half precision
        gp: optional progress callback, called as gp(fraction_complete, message)
        batch_size: nominal batch size, only used to scale the progress bar
    Returns:
        dict keyed by (batch index, index within batch). A value is None when the
        frame has no detections, otherwise an array with one
        [x0, y0, x1, y1, confidence] row per detection, the coordinates divided
        by the original frame width / height.
    """
    if gp:
        gp(0, "Detection...")

    # keep predictions to feed them ordered into the Tracker
    # TODO: how to deal with large files?
    frame_preds = {}
    n_batches = len(dataloader)

    # Run detection
    with tqdm(total=n_batches * batch_size, desc="Running detection", ncols=0) as pbar:
        for batch_i, (img, _, shapes) in enumerate(dataloader):
            if gp:
                gp(batch_i / n_batches, str(pbar))

            img = img.to(device, non_blocking=True)
            # uint8 to fp16 off-CPU, fp32 on CPU
            if device.type != 'cpu':
                img = img.half()
            else:
                img = img.float()
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            _, _, height, width = img.shape  # batch size, channels, height, width

            # Run model & NMS
            with torch.no_grad():
                inf_out, _ = model(img, augment=False)
                output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)

            # Format results, one frame at a time
            for si, pred in enumerate(output):
                # Clip boxes to the network input bounds (in place)
                clip_boxes(pred, (height, width))
                xyxy = pred[:, :4].clone()
                scores = pred[:, 4].tolist()
                # Map the copied boxes back to the original frame shape (in place)
                scale_boxes(img[si].shape[1:], xyxy, shapes[si][0], shapes[si][1])

                # Tracker input format: normalized xyxy plus the detection confidence
                frame_boxes = None
                if xyxy.shape[0] > 0:
                    orig_w = shapes[si][0][1]
                    orig_h = shapes[si][0][0]
                    frame_boxes = np.stack([
                        [*norm(coords, w=orig_w, h=orig_h), score]
                        for coords, score in zip(xyxy.tolist(), scores)
                    ])
                frame_preds[(batch_i, si)] = frame_boxes

            pbar.update(batch_size)

    return frame_preds

def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
    """
    Feed per-frame detections to a Tracker in sorted key order and finalize it.
    Args:
        all_preds: dict of frame key -> detections, or None for a frame with no
            detections (the structure built by do_detection)
        image_meter_width: stored in the clip info handed to the Tracker
        image_meter_height: stored in the clip info handed to the Tracker
        gp: optional progress callback, called as gp(fraction_complete, message)
    Returns:
        the result of tracker.finalize(min_length=min_length)
    """
    if gp:
        gp(0, "Tracking...")

    n_frames = len(all_preds)

    # Initialize tracker
    clip_info = {
        'start_frame': 0,
        'end_frame': n_frames,
        'image_meter_width': image_meter_width,
        'image_meter_height': image_meter_height
    }
    tracker_args = {'max_age': 9, 'min_hits': 0, 'iou_threshold': 0.01}
    tracker = Tracker(clip_info, args=tracker_args, min_hits=11)

    # Run tracking
    with tqdm(total=n_frames, desc="Running tracking", ncols=0) as pbar:
        for idx, frame_key in enumerate(sorted(all_preds)):
            if gp:
                gp(idx / n_frames, str(pbar))
            detections = all_preds[frame_key]
            # Frames without detections still advance the tracker by one step
            if detections is None:
                tracker.update()
            else:
                tracker.update(detections)
            pbar.update(1)

    return tracker.finalize(min_length=min_length)


@patch('json.encoder.c_make_encoder', None)
def json_dump_round_float(some_object, out_path, num_digits=4):
    """Write a json file to disk with a specified level of precision.
    See: https://gist.github.com/Sukonnik-Illia/ed9b2bec1821cad437d1b8adb17406a3

    The decorator disables the C encoder for the duration of the call so that
    json uses its pure-Python encoder, whose float formatter is swapped for a
    fixed-precision one. Both patches are undone when the call returns.

    Args:
        some_object: JSON-serializable object to write
        out_path: path of the file to create or overwrite
        num_digits: number of digits written after the decimal point of every
            finite float
    Returns:
        the result of json.dump (None)
    """
    import math  # local import keeps this function independent of file-level imports

    fmt_str = '{:.' + str(num_digits) + 'f}'
    # saving original method
    of = json.encoder._make_iterencode

    def inner(*args, **kwargs):
        args = list(args)
        # fifth argument is the float formatter, which we replace
        default_floatstr = args[4]

        def rounded_floatstr(o):
            # NaN / +-Infinity go through the encoder's own formatter so that
            # its allow_nan handling and spelling are kept; '{:f}' would emit
            # 'nan' / 'inf', which JSON parsers reject.
            if not math.isfinite(o):
                return default_floatstr(o)
            return fmt_str.format(o)

        args[4] = rounded_floatstr
        return of(*args, **kwargs)

    # The file is opened in a with-block so the handle is flushed and closed
    # even if serialization fails.
    with patch('json.encoder._make_iterencode', wraps=inner), open(out_path, 'w') as out_file:
        return json.dump(some_object, out_file, indent=2)