oskarastrom commited on
Commit
c376f3c
·
1 Parent(s): a63e231

Autoload parameters

Browse files
Files changed (2) hide show
  1. InferenceConfig.py +13 -1
  2. gradio_scripts/upload_ui.py +14 -12
InferenceConfig.py CHANGED
@@ -5,6 +5,11 @@ class TrackerType(Enum):
5
  CONF_BOOST = 1
6
  BYTETRACK = 2
7
 
 
 
 
 
 
8
  ### Configuration options
9
  WEIGHTS = 'models/v5m_896_300best.pt'
10
  # will need to configure these based on GPU hardware
@@ -16,7 +21,7 @@ MAX_AGE = 20 # time until missing fish get's new id
16
  MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
17
  MIN_LENGTH = 0.3 # minimum fish length, in meters
18
  IOU_THRES = 0.01 # IOU threshold for tracking
19
- MIN_TRAVEL = -1 # Minimum distance a track has to travel
20
  DEFAULT_TRACKER = TrackerType.BYTETRACK
21
 
22
  class InferenceConfig:
@@ -50,6 +55,13 @@ class InferenceConfig:
50
  self.byte_low_conf = low
51
  self.byte_high_conf = high
52
 
 
 
 
 
 
 
 
53
  def to_dict(self):
54
  dict = {
55
  'weights': self.weights,
 
5
  CONF_BOOST = 1
6
  BYTETRACK = 2
7
 
8
+ def toString(val):
9
+ if val == TrackerType.NONE: return "None"
10
+ if val == TrackerType.CONF_BOOST: return "Confidence Boost"
11
+ if val == TrackerType.BYTETRACK: return "ByteTrack"
12
+
13
  ### Configuration options
14
  WEIGHTS = 'models/v5m_896_300best.pt'
15
  # will need to configure these based on GPU hardware
 
21
  MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
22
  MIN_LENGTH = 0.3 # minimum fish length, in meters
23
  IOU_THRES = 0.01 # IOU threshold for tracking
24
+ MIN_TRAVEL = 0 # Minimum distance a track has to travel
25
  DEFAULT_TRACKER = TrackerType.BYTETRACK
26
 
27
  class InferenceConfig:
 
55
  self.byte_low_conf = low
56
  self.byte_high_conf = high
57
 
58
+ def find_model(self, model_list):
59
+ for model_name, model_path in enumerate(model_list):
60
+ if model_path == self.weights:
61
+ return model_name
62
+ return None
63
+
64
+
65
  def to_dict(self):
66
  dict = {
67
  'weights': self.weights,
gradio_scripts/upload_ui.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  from gradio_scripts.file_reader import File
 
3
 
4
 
5
  models = {
@@ -17,35 +18,36 @@ def Upload_Gradio(gradio_components):
17
 
18
  gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
19
 
 
20
  settings = []
21
  with gr.Accordion("Advanced Settings", open=False):
22
- settings.append(gr.Dropdown(label="Model", value="master", choices=list(models.keys())))
23
 
24
  gr.Markdown("Detection Parameters")
25
  with gr.Row():
26
- settings.append(gr.Slider(0, 1, value=0.05, label="Confidence Threshold", info="Confidence cutoff for detection boxes"))
27
- settings.append(gr.Slider(0, 1, value=0.2, label="NMS IoU", info="IoU threshold for non-max suppression"))
28
 
29
  gr.Markdown("Tracking Parameters")
30
  with gr.Row():
31
- settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
32
- settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
33
 
34
- tracker = gr.Dropdown(["None", "Confidence Boost", "ByteTrack"], label="Associative Tracking", value="None")
35
  settings.append(tracker)
36
  with gr.Row(visible=False) as track_row:
37
- settings.append(gr.Slider(0, 5, value=1, label="Boost Power", info=""))
38
- settings.append(gr.Slider(0, 1, value=1, label="Boost Decay", info=""))
39
  tracker.change(lambda x: gr.update(visible=(x=="Confidence Boost")), tracker, track_row)
40
  with gr.Row(visible=False) as track_row:
41
- settings.append(gr.Slider(0, 1, value=0.1, label="Low Conf Threshold", info=""))
42
- settings.append(gr.Slider(0, 1, value=0.3, label="High Conf Threshold", info=""))
43
  tracker.change(lambda x: gr.update(visible=(x=="ByteTrack")), tracker, track_row)
44
 
45
  gr.Markdown("Other")
46
  with gr.Row():
47
- settings.append(gr.Slider(0, 3, value=0.3, label="Min Length", info="Minimum length of fish (meters) in order for it to count"))
48
- settings.append(gr.Slider(0, 5, value=1, label="Min Travel", info="Minimum travel distance of track (meters) in order for it to count"))
49
 
50
  gradio_components['hyperparams'] = settings
51
 
 
1
  import gradio as gr
2
  from gradio_scripts.file_reader import File
3
+ from InferenceConfig import InferenceConfig, TrackerType
4
 
5
 
6
  models = {
 
18
 
19
  gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
20
 
21
+ default_settings = InferenceConfig()
22
  settings = []
23
  with gr.Accordion("Advanced Settings", open=False):
24
+ settings.append(gr.Dropdown(label="Model", value=default_settings.find_model(models), choices=list(models.keys())))
25
 
26
  gr.Markdown("Detection Parameters")
27
  with gr.Row():
28
+ settings.append(gr.Slider(0, 1, value=default_settings.conf_thresh, label="Confidence Threshold", info="Confidence cutoff for detection boxes"))
29
+ settings.append(gr.Slider(0, 1, value=default_settings.nms_iou, label="NMS IoU", info="IoU threshold for non-max suppression"))
30
 
31
  gr.Markdown("Tracking Parameters")
32
  with gr.Row():
33
+ settings.append(gr.Slider(0, 100, value=default_settings.min_hits, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
34
+ settings.append(gr.Slider(0, 100, value=default_settings.max_age, label="Max Age", info="Max age of occlusion before track is split"))
35
 
36
+ tracker = gr.Dropdown(["None", "Confidence Boost", "ByteTrack"], value=TrackerType.toString(default_settings.associative_tracker), label="Associative Tracking")
37
  settings.append(tracker)
38
  with gr.Row(visible=False) as track_row:
39
+ settings.append(gr.Slider(0, 5, value=default_settings.boost_power, label="Boost Power", info=""))
40
+ settings.append(gr.Slider(0, 1, value=default_settings.boost_decay, label="Boost Decay", info=""))
41
  tracker.change(lambda x: gr.update(visible=(x=="Confidence Boost")), tracker, track_row)
42
  with gr.Row(visible=False) as track_row:
43
+ settings.append(gr.Slider(0, 1, value=default_settings.byte_low_conf, label="Low Conf Threshold", info=""))
44
+ settings.append(gr.Slider(0, 1, value=default_settings.byte_high_conf, label="High Conf Threshold", info=""))
45
  tracker.change(lambda x: gr.update(visible=(x=="ByteTrack")), tracker, track_row)
46
 
47
  gr.Markdown("Other")
48
  with gr.Row():
49
+ settings.append(gr.Slider(0, 3, value=default_settings.min_length, label="Min Length", info="Minimum length of fish (meters) in order for it to count"))
50
+ settings.append(gr.Slider(0, 5, value=default_settings.min_travel, label="Min Travel", info="Minimum travel distance of track (meters) in order for it to count"))
51
 
52
  gradio_components['hyperparams'] = settings
53