oskarastrom commited on
Commit
e991912
·
1 Parent(s): 7e4e0ac

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +6 -5
inference.py CHANGED
@@ -54,7 +54,6 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
54
 
55
  # Load hyperparameters
56
  if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
57
- if 'conf_thresh' not in hyperparams: hyperparams['conf_thresh'] = CONF_THRES
58
  if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
59
  if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
60
  if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
@@ -96,10 +95,10 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
96
  if 'byte_low_conf' not in hyperparams: hyperparams['byte_low_conf'] = 0.1
97
  if 'byte_high_conf' not in hyperparams: hyperparams['byte_high_conf'] = 0.3
98
 
99
- low_outputs = do_suppression(inference, conf_thres=hyperparams['low_conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
100
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
101
 
102
- high_outputs = do_suppression(inference, conf_thres=hyperparams['high_conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
103
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
104
 
105
  results = do_associative_tracking(
@@ -109,7 +108,9 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
109
  gp=gp)
110
  else:
111
 
112
- outputs = do_suppression(inference, conf_thres=hyperparams['conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
 
 
113
 
114
  if hyperparams['associative_tracker'] == "Confidence Boost":
115
  if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
@@ -117,7 +118,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
117
 
118
  do_confidence_boost(inference, outputs, boost_power=hyperparams['boost_power'], boost_decay=hyperparams['boost_decay'], gp=gp)
119
 
120
- outputs = do_suppression(inference, conf_thres=hyperparams['conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
121
 
122
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
123
 
 
54
 
55
  # Load hyperparameters
56
  if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
 
57
  if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
58
  if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
59
  if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
 
95
  if 'byte_low_conf' not in hyperparams: hyperparams['byte_low_conf'] = 0.1
96
  if 'byte_high_conf' not in hyperparams: hyperparams['byte_high_conf'] = 0.3
97
 
98
+ low_outputs = do_suppression(inference, conf_thres=hyperparams['byte_low_conf'], iou_thres=hyperparams['iou_thresh'], gp=gp)
99
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
100
 
101
+ high_outputs = do_suppression(inference, conf_thres=hyperparams['byte_high_conf'], iou_thres=hyperparams['iou_thresh'], gp=gp)
102
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
103
 
104
  results = do_associative_tracking(
 
108
  gp=gp)
109
  else:
110
 
111
+ if 'conf_thresh' not in hyperparams: hyperparams['conf_thresh'] = CONF_THRES
112
+
113
+ outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
114
 
115
  if hyperparams['associative_tracker'] == "Confidence Boost":
116
  if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
 
118
 
119
  do_confidence_boost(inference, outputs, boost_power=hyperparams['boost_power'], boost_decay=hyperparams['boost_decay'], gp=gp)
120
 
121
+ outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
122
 
123
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
124