Spaces:
Runtime error
Runtime error
Commit
·
e991912
1
Parent(s):
7e4e0ac
Update inference.py
Browse files- inference.py +6 -5
inference.py
CHANGED
@@ -54,7 +54,6 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
54 |
|
55 |
# Load hyperparameters
|
56 |
if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
|
57 |
-
if 'conf_thresh' not in hyperparams: hyperparams['conf_thresh'] = CONF_THRES
|
58 |
if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
|
59 |
if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
|
60 |
if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
|
@@ -96,10 +95,10 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
96 |
if 'byte_low_conf' not in hyperparams: hyperparams['byte_low_conf'] = 0.1
|
97 |
if 'byte_high_conf' not in hyperparams: hyperparams['byte_high_conf'] = 0.3
|
98 |
|
99 |
-
low_outputs = do_suppression(inference, conf_thres=hyperparams['
|
100 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
101 |
|
102 |
-
high_outputs = do_suppression(inference, conf_thres=hyperparams['
|
103 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
104 |
|
105 |
results = do_associative_tracking(
|
@@ -109,7 +108,9 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
109 |
gp=gp)
|
110 |
else:
|
111 |
|
112 |
-
|
|
|
|
|
113 |
|
114 |
if hyperparams['associative_tracker'] == "Confidence Boost":
|
115 |
if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
|
@@ -117,7 +118,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
117 |
|
118 |
do_confidence_boost(inference, outputs, boost_power=hyperparams['boost_power'], boost_decay=hyperparams['boost_decay'], gp=gp)
|
119 |
|
120 |
-
outputs = do_suppression(inference, conf_thres=hyperparams['
|
121 |
|
122 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
123 |
|
|
|
54 |
|
55 |
# Load hyperparameters
|
56 |
if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
|
|
|
57 |
if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
|
58 |
if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
|
59 |
if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
|
|
|
95 |
if 'byte_low_conf' not in hyperparams: hyperparams['byte_low_conf'] = 0.1
|
96 |
if 'byte_high_conf' not in hyperparams: hyperparams['byte_high_conf'] = 0.3
|
97 |
|
98 |
+
low_outputs = do_suppression(inference, conf_thres=hyperparams['byte_low_conf'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
99 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
100 |
|
101 |
+
high_outputs = do_suppression(inference, conf_thres=hyperparams['byte_high_conf'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
102 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
103 |
|
104 |
results = do_associative_tracking(
|
|
|
108 |
gp=gp)
|
109 |
else:
|
110 |
|
111 |
+
if 'conf_thresh' not in hyperparams: hyperparams['conf_thresh'] = CONF_THRES
|
112 |
+
|
113 |
+
outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
114 |
|
115 |
if hyperparams['associative_tracker'] == "Confidence Boost":
|
116 |
if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
|
|
|
118 |
|
119 |
do_confidence_boost(inference, outputs, boost_power=hyperparams['boost_power'], boost_decay=hyperparams['boost_decay'], gp=gp)
|
120 |
|
121 |
+
outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
122 |
|
123 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
124 |
|