Commit 1ae5c71
Parent(s): 9ab5dcd
Update inference.py

inference.py CHANGED (+97 -22)
@@ -330,45 +330,36 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_a
 
     return json_data
 
-def do_associative_tracking(
+def do_associative_tracking(inference, image_shapes, width, height, image_meter_width, image_meter_height, gp=None, low_thresh=0.001, high_threshold=0.2, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, batch_size=BATCH_SIZE, verbose=True):
 
     if (gp): gp(0, "Tracking...")
 
     print("Preprocessing")
 
-
-
-
-
-
-
-            high_frame = []
-            for bbox in frame:
-                if bbox[4] > conf_thresh:
-                    high_frame.append(bbox)
-                else:
-                    low_frame.append(bbox)
-            low_dets.append(low_frame)
-            high_dets.append(high_frame)
-            pbar.update(1)
+
+    low_outputs = do_suppression(inference, conf_thres=low_thresh, iou_thres=iou_thres, gp=gp)
+    low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, verbose=verbose)
+
+    high_outputs = do_suppression(inference, conf_thres=high_threshold, iou_thres=iou_thres, gp=gp)
+    high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, verbose=verbose)
 
     print("Preprocess done")
 
     # Initialize tracker
     clip_info = {
         'start_frame': 0,
-        'end_frame': len(
+        'end_frame': len(low_preds),
         'image_meter_width': image_meter_width,
         'image_meter_height': image_meter_height
     }
     tracker = Tracker(clip_info, algorithm=Associate, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
 
     # Run tracking
-    with tqdm(total=len(
-        for i in range(len(
-            if gp: gp(i / len(
-            low_boxes =
-            high_boxes =
+    with tqdm(total=len(low_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
+        for i in range(len(low_preds)):
+            if gp: gp(i / len(low_preds), pbar.__str__())
+            low_boxes = low_preds[i]
+            high_boxes = high_preds[i]
             boxes = (low_boxes, high_boxes)
             if len(low_boxes) + len(high_boxes) > 0:
                 tracker.update(boxes)
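For context, the removed preprocessing loop split each frame's detections around a single confidence threshold before tracking; the new code produces the same two streams by running `do_suppression` twice, once at `low_thresh` and once at `high_threshold`, formatting each result with `format_predictions`. A minimal sketch of the split idea, assuming each box is a `[x1, y1, x2, y2, conf]` list (the helper name and `frames` input are illustrative, not part of inference.py):

```python
def split_by_confidence(frames, conf_thresh=0.2):
    """Split per-frame boxes into low/high confidence streams (illustrative)."""
    low_dets, high_dets = [], []
    for frame in frames:
        low_frame, high_frame = [], []
        for bbox in frame:
            if bbox[4] > conf_thresh:
                high_frame.append(bbox)  # confident boxes seed and extend tracks
            else:
                low_frame.append(bbox)   # weak boxes only rescue unmatched tracks
        low_dets.append(low_frame)
        high_dets.append(high_frame)
    return low_dets, high_dets
```

Feeding both streams to the `Associate` algorithm looks like the ByteTrack-style idea: high-confidence boxes create and extend tracks, while low-confidence boxes are only used to re-associate tracks that would otherwise be lost.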
@@ -455,6 +446,90 @@ def non_max_suppression(
         x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
 
 
+        # Check shape
+        n = x.shape[0]  # number of boxes
+        if not n:  # no boxes
+            continue
+        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
+
+        # Batched NMS
+        boxes = x[:, :4]  # boxes (offset by class), scores
+        scores = x[:, 4]
+        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+
+        i = i[:max_det]  # limit detections
+        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
+            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+            weights = iou * scores[None]  # box weights
+            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+            if redundant:
+                i = i[iou.sum(1) > 1]  # require redundancy
+
+        output[xi] = x[i]
+        if mps:
+            output[xi] = output[xi].to(device)
+
+    logging = False
+
+    return output
+
+
+def no_suppression(
+        prediction,
+        conf_thres=0.25,
+        iou_thres=0.45,
+        max_det=300,
+):
+    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
+
+    Returns:
+         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
+    """
+
+    # Checks
+    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
+        prediction = prediction[0]  # select only inference output
+
+    device = prediction.device
+    mps = 'mps' in device.type  # Apple MPS
+    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
+        prediction = prediction.cpu()
+    bs = prediction.shape[0]  # batch size
+    xc = prediction[..., 4] > conf_thres  # candidates
+
+    # Settings
+    # min_wh = 2  # (pixels) minimum box width and height
+    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
+    redundant = True  # require redundant detections
+    merge = False  # use merge-NMS
+
+    output = [torch.zeros((0, 6), device=prediction.device)] * bs
+    for xi, x in enumerate(prediction):  # image index, image inference
+
+
+        # Keep boxes that pass confidence threshold
+        x = x[xc[xi]]  # confidence
+
+        # If none remain process next image
+        if not x.shape[0]:
+            continue
+
+        # Compute conf
+        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+
+
+        # Box/Mask
+        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
+        mask = x[:, 6:]  # zero columns if no masks
+
+        # Detections matrix nx6 (xyxy, conf, cls)
+        conf, j = x[:, 5:6].max(1, keepdim=True)
+        x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
+
+
         # Check shape
         n = x.shape[0]  # number of boxes
         if not n:  # no boxes
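`merge` is disabled in this function (`merge = False`), but when the branch runs it replaces each NMS survivor with a score-weighted mean of every box that overlaps it above `iou_thres`. A small numeric sketch of that update, with made-up values:

```python
import torch
from torchvision.ops import box_iou

boxes = torch.tensor([[0., 0., 10., 10.],
                      [2., 2., 12., 12.]])  # IoU ~= 0.47
scores = torch.tensor([0.9, 0.6])
i = torch.tensor([0])                    # index kept by NMS

iou = box_iou(boxes[i], boxes) > 0.45    # which boxes overlap each survivor
weights = iou * scores[None]             # zero out non-overlaps, weight by score
merged = torch.mm(weights, boxes) / weights.sum(1, keepdim=True)
print(merged)  # tensor([[0.8, 0.8, 10.8, 10.8]]): pulled toward the overlap
```

With `redundant = True`, the line `i = i[iou.sum(1) > 1]` additionally drops survivors that no other box overlapped, filtering one-off spurious detections.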
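Despite its name, `no_suppression` keeps YOLOv5's standard decoding steps before any per-box suppression: each raw row is `(cx, cy, w, h, obj_conf, cls_conf, ...)`, class confidence is multiplied by objectness, and the box is converted to corner form. A worked single-row example (this local `xywh2xyxy` mirrors the YOLOv5 helper inference.py uses; the numbers are made up):

```python
import torch

def xywh2xyxy(x):
    # (center_x, center_y, width, height) -> (x1, y1, x2, y2)
    y = x.clone()
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y

x = torch.tensor([[8., 6., 4., 2., 0.9, 0.5]])  # one raw prediction row
x[:, 5:] *= x[:, 4:5]                           # conf = obj_conf * cls_conf -> 0.45
box = xywh2xyxy(x[:, :4])                       # tensor([[6., 5., 10., 7.]])
conf, j = x[:, 5:6].max(1, keepdim=True)        # best class score and its index
print(torch.cat((box, conf, j.float()), 1))     # [[6., 5., 10., 7., 0.45, 0.]]
```

Rows whose best class confidence still falls below `conf_thres` are then dropped by the trailing `[conf.view(-1) > conf_thres]` mask.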