oskarastrom commited on
Commit
d8d9ab6
·
1 Parent(s): ea6a784

Max length hyperparameter

Browse files
Files changed (4) hide show
  1. InferenceConfig.py +3 -1
  2. app.py +2 -1
  3. gradio_scripts/upload_ui.py +2 -1
  4. inference.py +4 -4
InferenceConfig.py CHANGED
@@ -20,6 +20,7 @@ NMS_IOU = 0.25 # NMS IOU
20
  MAX_AGE = 20 # time until missing fish get's new id
21
  MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
22
  MIN_LENGTH = 0.3 # minimum fish length, in meters
 
23
  IOU_THRES = 0.01 # IOU threshold for tracking
24
  MIN_TRAVEL = 0 # Minimum distance a track has to travel
25
  DEFAULT_TRACKER = TrackerType.BYTETRACK
@@ -27,13 +28,14 @@ DEFAULT_TRACKER = TrackerType.BYTETRACK
27
  class InferenceConfig:
28
  def __init__(self,
29
  weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
30
- min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL):
31
  self.weights = weights
32
  self.conf_thresh = conf_thresh
33
  self.nms_iou = nms_iou
34
  self.min_hits = min_hits
35
  self.max_age = max_age
36
  self.min_length = min_length
 
37
  self.min_travel = min_travel
38
 
39
  self.associative_tracker = DEFAULT_TRACKER
 
20
  MAX_AGE = 20 # time until missing fish get's new id
21
  MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
22
  MIN_LENGTH = 0.3 # minimum fish length, in meters
23
+ MAX_LENGTH = 1.5 # maximum fish length, in meters
24
  IOU_THRES = 0.01 # IOU threshold for tracking
25
  MIN_TRAVEL = 0 # Minimum distance a track has to travel
26
  DEFAULT_TRACKER = TrackerType.BYTETRACK
 
28
  class InferenceConfig:
29
  def __init__(self,
30
  weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
31
+ min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH, max_length=MAX_LENGTH, min_travel=MIN_TRAVEL):
32
  self.weights = weights
33
  self.conf_thresh = conf_thresh
34
  self.nms_iou = nms_iou
35
  self.min_hits = min_hits
36
  self.max_age = max_age
37
  self.min_length = min_length
38
+ self.max_length = max_length
39
  self.min_travel = min_travel
40
 
41
  self.associative_tracker = DEFAULT_TRACKER
app.py CHANGED
@@ -31,7 +31,7 @@ result = {}
31
 
32
 
33
  # Called when an Aris file is uploaded for inference
34
- def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, associative_tracker, boost_power, boost_decay, byte_low_conf, byte_high_conf, min_length, min_travel, output_formats):
35
 
36
  print(output_formats)
37
 
@@ -48,6 +48,7 @@ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_ag
48
  min_hits = min_hits,
49
  max_age = max_age,
50
  min_length = min_length,
 
51
  min_travel = min_travel,
52
  )
53
 
 
31
 
32
 
33
  # Called when an Aris file is uploaded for inference
34
+ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, associative_tracker, boost_power, boost_decay, byte_low_conf, byte_high_conf, min_length, max_length, min_travel, output_formats):
35
 
36
  print(output_formats)
37
 
 
48
  min_hits = min_hits,
49
  max_age = max_age,
50
  min_length = min_length,
51
+ max_length = max_length,
52
  min_travel = min_travel,
53
  )
54
 
gradio_scripts/upload_ui.py CHANGED
@@ -49,7 +49,8 @@ def Upload_Gradio(gradio_components):
49
  gr.Markdown("Other")
50
  with gr.Row():
51
  settings.append(gr.Slider(0, 3, value=default_settings.min_length, label="Min Length", info="Minimum length of fish (meters) in order for it to count"))
52
- settings.append(gr.Slider(0, 5, value=default_settings.min_travel, label="Min Travel", info="Minimum travel distance of track (meters) in order for it to count"))
 
53
 
54
  gradio_components['hyperparams'] = settings
55
 
 
49
  gr.Markdown("Other")
50
  with gr.Row():
51
  settings.append(gr.Slider(0, 3, value=default_settings.min_length, label="Min Length", info="Minimum length of fish (meters) in order for it to count"))
52
+ settings.append(gr.Slider(0, 3, value=default_settings.max_length, label="Max Length", info="Maximum length of fish (meters) in order for it to count"))
53
+ settings.append(gr.Slider(0, 10, value=default_settings.min_travel, label="Min Travel", info="Minimum travel distance of track (meters) in order for it to count"))
54
 
55
  gradio_components['hyperparams'] = settings
56
 
inference.py CHANGED
@@ -62,11 +62,11 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
62
  if config.associative_tracker == TrackerType.BYTETRACK:
63
 
64
  # Find low confidence detections
65
- low_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
66
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
67
 
68
  # Find high confidence detections
69
- high_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
70
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
71
 
72
  # Perform associative tracking (ByteTrack)
@@ -78,7 +78,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
78
  else:
79
 
80
  # Find confident detections
81
- outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
82
 
83
  if config.associative_tracker == TrackerType.CONF_BOOST:
84
 
@@ -86,7 +86,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
86
  do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
87
 
88
  # Find confident detections from boosted list
89
- outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
90
 
91
  # Format confident detections
92
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
 
62
  if config.associative_tracker == TrackerType.BYTETRACK:
63
 
64
  # Find low confidence detections
65
+ low_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)
66
  low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
67
 
68
  # Find high confidence detections
69
+ high_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)
70
  high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
71
 
72
  # Perform associative tracking (ByteTrack)
 
78
  else:
79
 
80
  # Find confident detections
81
+ outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)
82
 
83
  if config.associative_tracker == TrackerType.CONF_BOOST:
84
 
 
86
  do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
87
 
88
  # Find confident detections from boosted list
89
+ outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)
90
 
91
  # Format confident detections
92
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)