oskarastrom committed
Commit 1ae5c71
1 Parent(s): 9ab5dcd

Update inference.py

Files changed (1):
  1. inference.py +97 -22
inference.py CHANGED
@@ -330,45 +330,36 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_a
 
     return json_data
 
-def do_associative_tracking(raw_detections, image_meter_width, image_meter_height, gp=None, conf_thresh=0.2, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
+def do_associative_tracking(inference, image_shapes, width, height, image_meter_width, image_meter_height, gp=None, low_thresh=0.001, high_threshold=0.2, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, batch_size=BATCH_SIZE, verbose=True):
 
     if (gp): gp(0, "Tracking...")
 
     print("Preprocessing")
 
-    low_dets = []
-    high_dets = []
-    with tqdm(total=len(raw_detections), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
-        for batch in raw_detections:
-            for frame in batch:
-                low_frame = []
-                high_frame = []
-                for bbox in frame:
-                    if bbox[4] > conf_thresh:
-                        high_frame.append(bbox)
-                    else:
-                        low_frame.append(bbox)
-                low_dets.append(low_frame)
-                high_dets.append(high_frame)
-                pbar.update(1)
+
+    low_outputs = do_suppression(inference, conf_thres=low_thresh, iou_thres=iou_thres, gp=gp)
+    low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, verbose=verbose)
+
+    high_outputs = do_suppression(inference, conf_thres=high_threshold, iou_thres=iou_thres, gp=gp)
+    high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, verbose=verbose)
 
     print("Preprocess done")
 
     # Initialize tracker
     clip_info = {
         'start_frame': 0,
-        'end_frame': len(raw_detections),
+        'end_frame': len(low_preds),
         'image_meter_width': image_meter_width,
         'image_meter_height': image_meter_height
     }
     tracker = Tracker(clip_info, algorithm=Associate, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
 
     # Run tracking
-    with tqdm(total=len(low_dets), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
-        for i in range(len(low_dets)):
-            if gp: gp(i / len(low_dets), pbar.__str__())
-            low_boxes = low_dets[i]
-            high_boxes = high_dets[i]
+    with tqdm(total=len(low_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
+        for i in range(len(low_preds)):
+            if gp: gp(i / len(low_preds), pbar.__str__())
+            low_boxes = low_preds[i]
+            high_boxes = high_preds[i]
             boxes = (low_boxes, high_boxes)
             if len(low_boxes) + len(high_boxes) > 0:
                 tracker.update(boxes)
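This hunk replaces the manual per-frame confidence split (the removed tqdm loop testing `bbox[4] > conf_thresh`) with two `do_suppression` passes over the same raw `inference` output, one at `low_thresh=0.001` and one at `high_threshold=0.2`. Note that, unlike the removed code, the two resulting lists are not disjoint: every detection that clears the high threshold also appears in `low_preds`. The sketch below illustrates just the thresholding side of the two behaviours (`do_suppression`, judging by its `conf_thres`/`iou_thres` arguments, also applies NMS); the helper names, arrays, and the `[x1, y1, x2, y2, conf]` row layout (implied by the `bbox[4]` test) are illustrative assumptions, not code from the repo.

    import numpy as np

    def split_frame(dets, high_thresh=0.2):
        """Removed behaviour: disjoint partition of one frame's boxes on confidence."""
        dets = np.asarray(dets, dtype=float).reshape(-1, 5)
        return dets[dets[:, 4] <= high_thresh], dets[dets[:, 4] > high_thresh]

    def threshold_frame(dets, thresh):
        """New behaviour: one pass simply keeps every box above thresh."""
        dets = np.asarray(dets, dtype=float).reshape(-1, 5)
        return dets[dets[:, 4] > thresh]

    frame = [[10, 10, 50, 50, 0.90],   # strong detection
             [12, 14, 48, 52, 0.05]]   # weak detection, still useful for association
    low_old, high_old = split_frame(frame)    # disjoint: one box in each set
    low_new = threshold_frame(frame, 0.001)   # superset: both boxes survive
    high_new = threshold_frame(frame, 0.2)    # the strong box only
    # Either way, the tracker consumes per-frame pairs: tracker.update((low, high))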
@@ -455,6 +446,90 @@ def non_max_suppression(
         x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
 
 
+        # Check shape
+        n = x.shape[0]  # number of boxes
+        if not n:  # no boxes
+            continue
+        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
+
+        # Batched NMS
+        boxes = x[:, :4]  # boxes (offset by class), scores
+        scores = x[:, 4]
+        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+
+        i = i[:max_det]  # limit detections
+        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
+            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+            weights = iou * scores[None]  # box weights
+            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+            if redundant:
+                i = i[iou.sum(1) > 1]  # require redundancy
+
+        output[xi] = x[i]
+        if mps:
+            output[xi] = output[xi].to(device)
+
+    logging = False
+
+    return output
+
+
+def no_suppression(
+        prediction,
+        conf_thres=0.25,
+        iou_thres=0.45,
+        max_det=300,
+):
+    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
+
+    Returns:
+         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
+    """
+
+    # Checks
+    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
+        prediction = prediction[0]  # select only inference output
+
+    device = prediction.device
+    mps = 'mps' in device.type  # Apple MPS
+    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
+        prediction = prediction.cpu()
+    bs = prediction.shape[0]  # batch size
+    xc = prediction[..., 4] > conf_thres  # candidates
+
+    # Settings
+    # min_wh = 2  # (pixels) minimum box width and height
+    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
+    redundant = True  # require redundant detections
+    merge = False  # use merge-NMS
+
+    output = [torch.zeros((0, 6), device=prediction.device)] * bs
+    for xi, x in enumerate(prediction):  # image index, image inference
+
+
+        # Keep boxes that pass confidence threshold
+        x = x[xc[xi]]  # confidence
+
+        # If none remain process next image
+        if not x.shape[0]:
+            continue
+
+        # Compute conf
+        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+
+
+        # Box/Mask
+        box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
+        mask = x[:, 6:]  # zero columns if no masks
+
+        # Detections matrix nx6 (xyxy, conf, cls)
+        conf, j = x[:, 5:6].max(1, keepdim=True)
+        x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
+
+
         # Check shape
         n = x.shape[0]  # number of boxes
         if not n:  # no boxes
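The added tail of `non_max_suppression` is the standard YOLOv5 batched-NMS epilogue: sort by confidence, run `torchvision.ops.nms`, cap at `max_det`, and optionally merge overlapping boxes by a confidence-weighted mean (the hunk shows `merge = False` only inside `no_suppression`; the flag's value inside `non_max_suppression` is outside this diff). Below is a self-contained sketch of that merge update on toy tensors, using `torchvision.ops.box_iou` in place of the repo's `box_iou` helper; the boxes, scores, and threshold are made up for illustration.

    import torch
    from torchvision.ops import box_iou, nms

    # Two near-duplicate boxes and one distant box, with confidences.
    boxes = torch.tensor([[10., 10., 50., 50.],
                          [12., 12., 52., 52.],
                          [200., 200., 240., 240.]])
    scores = torch.tensor([0.9, 0.6, 0.8])
    iou_thres = 0.45

    keep = nms(boxes, scores, iou_thres)           # plain NMS keeps boxes 0 and 2

    # Merge step, mirroring the added code: weight every original box by
    # (overlaps a kept box) * (its confidence), then take the weighted mean.
    iou = box_iou(boxes[keep], boxes) > iou_thres  # (kept, all) overlap mask
    weights = iou * scores[None]                   # zero weight off the mask
    merged = torch.mm(weights, boxes) / weights.sum(1, keepdim=True)
    # merged[0] blends boxes 0 and 1 (weights 0.9 and 0.6); merged[1] is box 2 unchanged.

With `redundant = True`, the added code would afterwards drop kept boxes that overlap nothing but themselves (`i = i[iou.sum(1) > 1]`), which in this toy example would discard the isolated third box.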