oskarastrom commited on
Commit
889bde6
·
1 Parent(s): bb2dfaa

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +24 -8
inference.py CHANGED
@@ -18,6 +18,7 @@ import torch
18
  import torchvision
19
 
20
  from lib.fish_eye.tracker import Tracker
 
21
 
22
 
23
  ### Configuration options
@@ -329,11 +330,24 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_a
329
 
330
  return json_data
331
 
332
- def do_associative_tracking(raw_detections, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
333
 
334
  if (gp): gp(0, "Tracking...")
335
 
336
- print(len(raw_detections))
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  # Initialize tracker
339
  clip_info = {
@@ -342,14 +356,16 @@ def do_associative_tracking(raw_detections, image_meter_width, image_meter_heigh
342
  'image_meter_width': image_meter_width,
343
  'image_meter_height': image_meter_height
344
  }
345
- tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
346
 
347
  # Run tracking
348
- with tqdm(total=len(all_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
349
- for i, key in enumerate(sorted(all_preds.keys())):
350
- if gp: gp(i / len(all_preds), pbar.__str__())
351
- boxes = all_preds[key]
352
- if boxes is not None:
 
 
353
  tracker.update(boxes)
354
  else:
355
  tracker.update()
 
18
  import torchvision
19
 
20
  from lib.fish_eye.tracker import Tracker
21
+ from lib.fish_eye.associative import Associate
22
 
23
 
24
  ### Configuration options
 
330
 
331
  return json_data
332
 
333
+ def do_associative_tracking(raw_detections, image_meter_width, image_meter_height, gp=None, conf_thresh=0.2, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
334
 
335
  if (gp): gp(0, "Tracking...")
336
 
337
+ low_dets = []
338
+ high_dets = []
339
+ for batch in raw_detections:
340
+ for frame in batch:
341
+ low_frame = []
342
+ high_frame = []
343
+ for bbox in frame:
344
+ if bbox[4] > conf_thresh:
345
+ high_frame.append(bbox)
346
+ else:
347
+ low_frame.append(bbox)
348
+ low_dets.append(low_frame)
349
+ high_dets.append(high_frame)
350
+
351
 
352
  # Initialize tracker
353
  clip_info = {
 
356
  'image_meter_width': image_meter_width,
357
  'image_meter_height': image_meter_height
358
  }
359
+ tracker = Tracker(clip_info, algorithm=Associate, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
360
 
361
  # Run tracking
362
+ with tqdm(total=len(low_dets), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
363
+ for i, key in enumerate(sorted(low_dets.keys())):
364
+ if gp: gp(i / len(low_dets), pbar.__str__())
365
+ low_boxes = low_dets[key]
366
+ high_boxes = high_dets[key]
367
+ boxes = (low_boxes, high_boxes)
368
+ if len(low_boxes) + len(high_boxes) > 0:
369
  tracker.update(boxes)
370
  else:
371
  tracker.update()