Spaces:
Runtime error
Runtime error
Commit
·
d8d9ab6
1
Parent(s):
ea6a784
Max length hyperparameter
Browse files- InferenceConfig.py +3 -1
- app.py +2 -1
- gradio_scripts/upload_ui.py +2 -1
- inference.py +4 -4
InferenceConfig.py
CHANGED
@@ -20,6 +20,7 @@ NMS_IOU = 0.25 # NMS IOU
|
|
20 |
MAX_AGE = 20 # time until missing fish get's new id
|
21 |
MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
|
22 |
MIN_LENGTH = 0.3 # minimum fish length, in meters
|
|
|
23 |
IOU_THRES = 0.01 # IOU threshold for tracking
|
24 |
MIN_TRAVEL = 0 # Minimum distance a track has to travel
|
25 |
DEFAULT_TRACKER = TrackerType.BYTETRACK
|
@@ -27,13 +28,14 @@ DEFAULT_TRACKER = TrackerType.BYTETRACK
|
|
27 |
class InferenceConfig:
|
28 |
def __init__(self,
|
29 |
weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
|
30 |
-
min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL):
|
31 |
self.weights = weights
|
32 |
self.conf_thresh = conf_thresh
|
33 |
self.nms_iou = nms_iou
|
34 |
self.min_hits = min_hits
|
35 |
self.max_age = max_age
|
36 |
self.min_length = min_length
|
|
|
37 |
self.min_travel = min_travel
|
38 |
|
39 |
self.associative_tracker = DEFAULT_TRACKER
|
|
|
20 |
MAX_AGE = 20 # time until missing fish get's new id
|
21 |
MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
|
22 |
MIN_LENGTH = 0.3 # minimum fish length, in meters
|
23 |
+
MAX_LENGTH = 1.5 # maximum fish length, in meters
|
24 |
IOU_THRES = 0.01 # IOU threshold for tracking
|
25 |
MIN_TRAVEL = 0 # Minimum distance a track has to travel
|
26 |
DEFAULT_TRACKER = TrackerType.BYTETRACK
|
|
|
28 |
class InferenceConfig:
|
29 |
def __init__(self,
|
30 |
weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
|
31 |
+
min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH, max_length=MAX_LENGTH, min_travel=MIN_TRAVEL):
|
32 |
self.weights = weights
|
33 |
self.conf_thresh = conf_thresh
|
34 |
self.nms_iou = nms_iou
|
35 |
self.min_hits = min_hits
|
36 |
self.max_age = max_age
|
37 |
self.min_length = min_length
|
38 |
+
self.max_length = max_length
|
39 |
self.min_travel = min_travel
|
40 |
|
41 |
self.associative_tracker = DEFAULT_TRACKER
|
app.py
CHANGED
@@ -31,7 +31,7 @@ result = {}
|
|
31 |
|
32 |
|
33 |
# Called when an Aris file is uploaded for inference
|
34 |
-
def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, associative_tracker, boost_power, boost_decay, byte_low_conf, byte_high_conf, min_length, min_travel, output_formats):
|
35 |
|
36 |
print(output_formats)
|
37 |
|
@@ -48,6 +48,7 @@ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_ag
|
|
48 |
min_hits = min_hits,
|
49 |
max_age = max_age,
|
50 |
min_length = min_length,
|
|
|
51 |
min_travel = min_travel,
|
52 |
)
|
53 |
|
|
|
31 |
|
32 |
|
33 |
# Called when an Aris file is uploaded for inference
|
34 |
+
def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, associative_tracker, boost_power, boost_decay, byte_low_conf, byte_high_conf, min_length, max_length, min_travel, output_formats):
|
35 |
|
36 |
print(output_formats)
|
37 |
|
|
|
48 |
min_hits = min_hits,
|
49 |
max_age = max_age,
|
50 |
min_length = min_length,
|
51 |
+
max_length = max_length,
|
52 |
min_travel = min_travel,
|
53 |
)
|
54 |
|
gradio_scripts/upload_ui.py
CHANGED
@@ -49,7 +49,8 @@ def Upload_Gradio(gradio_components):
|
|
49 |
gr.Markdown("Other")
|
50 |
with gr.Row():
|
51 |
settings.append(gr.Slider(0, 3, value=default_settings.min_length, label="Min Length", info="Minimum length of fish (meters) in order for it to count"))
|
52 |
-
settings.append(gr.Slider(0,
|
|
|
53 |
|
54 |
gradio_components['hyperparams'] = settings
|
55 |
|
|
|
49 |
gr.Markdown("Other")
|
50 |
with gr.Row():
|
51 |
settings.append(gr.Slider(0, 3, value=default_settings.min_length, label="Min Length", info="Minimum length of fish (meters) in order for it to count"))
|
52 |
+
settings.append(gr.Slider(0, 3, value=default_settings.max_length, label="Max Length", info="Maximum length of fish (meters) in order for it to count"))
|
53 |
+
settings.append(gr.Slider(0, 10, value=default_settings.min_travel, label="Min Travel", info="Minimum travel distance of track (meters) in order for it to count"))
|
54 |
|
55 |
gradio_components['hyperparams'] = settings
|
56 |
|
inference.py
CHANGED
@@ -62,11 +62,11 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
62 |
if config.associative_tracker == TrackerType.BYTETRACK:
|
63 |
|
64 |
# Find low confidence detections
|
65 |
-
low_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
|
66 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
67 |
|
68 |
# Find high confidence detections
|
69 |
-
high_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
|
70 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
71 |
|
72 |
# Perform associative tracking (ByteTrack)
|
@@ -78,7 +78,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
78 |
else:
|
79 |
|
80 |
# Find confident detections
|
81 |
-
outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
82 |
|
83 |
if config.associative_tracker == TrackerType.CONF_BOOST:
|
84 |
|
@@ -86,7 +86,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
86 |
do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
|
87 |
|
88 |
# Find confident detections from boosted list
|
89 |
-
outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
90 |
|
91 |
# Format confident detections
|
92 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
|
|
62 |
if config.associative_tracker == TrackerType.BYTETRACK:
|
63 |
|
64 |
# Find low confidence detections
|
65 |
+
low_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)
|
66 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
67 |
|
68 |
# Find high confidence detections
|
69 |
+
high_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)
|
70 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
71 |
|
72 |
# Perform associative tracking (ByteTrack)
|
|
|
78 |
else:
|
79 |
|
80 |
# Find confident detections
|
81 |
+
outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)
|
82 |
|
83 |
if config.associative_tracker == TrackerType.CONF_BOOST:
|
84 |
|
|
|
86 |
do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
|
87 |
|
88 |
# Find confident detections from boosted list
|
89 |
+
outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)
|
90 |
|
91 |
# Format confident detections
|
92 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|