Spaces:
Runtime error
Runtime error
Commit
·
889bde6
1
Parent(s):
bb2dfaa
Update inference.py
Browse files- inference.py +24 -8
inference.py
CHANGED
@@ -18,6 +18,7 @@ import torch
|
|
18 |
import torchvision
|
19 |
|
20 |
from lib.fish_eye.tracker import Tracker
|
|
|
21 |
|
22 |
|
23 |
### Configuration options
|
@@ -329,11 +330,24 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_a
|
|
329 |
|
330 |
return json_data
|
331 |
|
332 |
-
def do_associative_tracking(raw_detections, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|
333 |
|
334 |
if (gp): gp(0, "Tracking...")
|
335 |
|
336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
|
338 |
# Initialize tracker
|
339 |
clip_info = {
|
@@ -342,14 +356,16 @@ def do_associative_tracking(raw_detections, image_meter_width, image_meter_heigh
|
|
342 |
'image_meter_width': image_meter_width,
|
343 |
'image_meter_height': image_meter_height
|
344 |
}
|
345 |
-
tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
346 |
|
347 |
# Run tracking
|
348 |
-
with tqdm(total=len(
|
349 |
-
for i, key in enumerate(sorted(
|
350 |
-
if gp: gp(i / len(
|
351 |
-
|
352 |
-
|
|
|
|
|
353 |
tracker.update(boxes)
|
354 |
else:
|
355 |
tracker.update()
|
|
|
18 |
import torchvision
|
19 |
|
20 |
from lib.fish_eye.tracker import Tracker
|
21 |
+
from lib.fish_eye.associative import Associate
|
22 |
|
23 |
|
24 |
### Configuration options
|
|
|
330 |
|
331 |
return json_data
|
332 |
|
333 |
+
def do_associative_tracking(raw_detections, image_meter_width, image_meter_height, gp=None, conf_thresh=0.2, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|
334 |
|
335 |
if (gp): gp(0, "Tracking...")
|
336 |
|
337 |
+
low_dets = []
|
338 |
+
high_dets = []
|
339 |
+
for batch in raw_detections:
|
340 |
+
for frame in batch:
|
341 |
+
low_frame = []
|
342 |
+
high_frame = []
|
343 |
+
for bbox in frame:
|
344 |
+
if bbox[4] > conf_thresh:
|
345 |
+
high_frame.append(bbox)
|
346 |
+
else:
|
347 |
+
low_frame.append(bbox)
|
348 |
+
low_dets.append(low_frame)
|
349 |
+
high_dets.append(high_frame)
|
350 |
+
|
351 |
|
352 |
# Initialize tracker
|
353 |
clip_info = {
|
|
|
356 |
'image_meter_width': image_meter_width,
|
357 |
'image_meter_height': image_meter_height
|
358 |
}
|
359 |
+
tracker = Tracker(clip_info, algorithm=Associate, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
360 |
|
361 |
# Run tracking
|
362 |
+
with tqdm(total=len(low_dets), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
|
363 |
+
for i, key in enumerate(sorted(low_dets.keys())):
|
364 |
+
if gp: gp(i / len(low_dets), pbar.__str__())
|
365 |
+
low_boxes = low_dets[key]
|
366 |
+
high_boxes = high_dets[key]
|
367 |
+
boxes = (low_boxes, high_boxes)
|
368 |
+
if len(low_boxes) + len(high_boxes) > 0:
|
369 |
tracker.update(boxes)
|
370 |
else:
|
371 |
tracker.update()
|