Spaces:
Runtime error
Runtime error
Commit
·
29e11ce
1
Parent(s):
1eadfef
New Elwha Model
Browse files- app.py +6 -11
- gradio_scripts/upload_ui.py +9 -1
- inference.py +17 -6
- main.py +3 -9
- models/{YsEKvE20 → YsEKvE20.pt} +0 -0
app.py
CHANGED
@@ -21,13 +21,13 @@ state = {
|
|
21 |
'total': 1,
|
22 |
'annotation_index': -1,
|
23 |
'frame_index': 0,
|
24 |
-
'
|
25 |
}
|
26 |
result = {}
|
27 |
|
28 |
|
29 |
# Called when an Aris file is uploaded for inference
|
30 |
-
def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age):
|
31 |
|
32 |
# Reset Result
|
33 |
reset_state(result, state)
|
@@ -39,6 +39,7 @@ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_ag
|
|
39 |
'iou_thresh': iou_thresh,
|
40 |
'min_hits': min_hits,
|
41 |
'max_age': max_age,
|
|
|
42 |
}
|
43 |
|
44 |
print(" ")
|
@@ -138,7 +139,7 @@ def infer_next(_, progress=gr.Progress()):
|
|
138 |
file_name = file_info[0].split("/")[-1]
|
139 |
bytes = file_info[1]
|
140 |
valid, file_path, dir_name = save_data(bytes, file_name)
|
141 |
-
|
142 |
print("Directory: ", dir_name)
|
143 |
print("Aris input: ", file_path)
|
144 |
print(" ")
|
@@ -153,17 +154,11 @@ def infer_next(_, progress=gr.Progress()):
|
|
153 |
# Send uploaded file to AWS
|
154 |
upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
|
155 |
|
156 |
-
hyperparams = state['hyperparams']
|
157 |
-
|
158 |
# Do inference
|
159 |
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
|
160 |
file_path,
|
161 |
-
|
162 |
-
|
163 |
-
iou_thresh = hyperparams['iou_thresh'],
|
164 |
-
min_hits = hyperparams['min_hits'],
|
165 |
-
max_age = hyperparams['max_age'],
|
166 |
-
gradio_progress=set_progress
|
167 |
)
|
168 |
|
169 |
# Store result for that file
|
|
|
21 |
'total': 1,
|
22 |
'annotation_index': -1,
|
23 |
'frame_index': 0,
|
24 |
+
'hyperparams': {}
|
25 |
}
|
26 |
result = {}
|
27 |
|
28 |
|
29 |
# Called when an Aris file is uploaded for inference
|
30 |
+
def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, use_associative):
|
31 |
|
32 |
# Reset Result
|
33 |
reset_state(result, state)
|
|
|
39 |
'iou_thresh': iou_thresh,
|
40 |
'min_hits': min_hits,
|
41 |
'max_age': max_age,
|
42 |
+
'use_associative_tracking': use_associative,
|
43 |
}
|
44 |
|
45 |
print(" ")
|
|
|
139 |
file_name = file_info[0].split("/")[-1]
|
140 |
bytes = file_info[1]
|
141 |
valid, file_path, dir_name = save_data(bytes, file_name)
|
142 |
+
|
143 |
print("Directory: ", dir_name)
|
144 |
print("Aris input: ", file_path)
|
145 |
print(" ")
|
|
|
154 |
# Send uploaded file to AWS
|
155 |
upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
|
156 |
|
|
|
|
|
157 |
# Do inference
|
158 |
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
|
159 |
file_path,
|
160 |
+
hyperparams = state['hyperparams'],
|
161 |
+
gradio_progress = set_progress
|
|
|
|
|
|
|
|
|
162 |
)
|
163 |
|
164 |
# Store result for that file
|
gradio_scripts/upload_ui.py
CHANGED
@@ -4,7 +4,8 @@ from gradio_scripts.file_reader import File
|
|
4 |
|
5 |
models = {
|
6 |
'master': 'models/v5m_896_300best.pt',
|
7 |
-
'elwha': 'models/YsEE20.pt'
|
|
|
8 |
}
|
9 |
|
10 |
def Upload_Gradio(gradio_components):
|
@@ -29,6 +30,13 @@ def Upload_Gradio(gradio_components):
|
|
29 |
settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
|
30 |
settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
gradio_components['hyperparams'] = settings
|
33 |
|
34 |
#Input field for aris submission
|
|
|
4 |
|
5 |
models = {
|
6 |
'master': 'models/v5m_896_300best.pt',
|
7 |
+
'elwha': 'models/YsEE20.pt',
|
8 |
+
'elwha_train': 'models/YsEKtE20.pt',
|
9 |
}
|
10 |
|
11 |
def Upload_Gradio(gradio_components):
|
|
|
30 |
settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
|
31 |
settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
|
32 |
|
33 |
+
with gr.Row():
|
34 |
+
gr.Markdown("Associative Tracking")
|
35 |
+
settings.append(gr.Checkbox(value=False, label="Enabled"))
|
36 |
+
with gr.Row():
|
37 |
+
settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
|
38 |
+
settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
|
39 |
+
|
40 |
gradio_components['hyperparams'] = settings
|
41 |
|
42 |
#Input field for aris submission
|
inference.py
CHANGED
@@ -48,9 +48,18 @@ def norm(bbox, w, h):
|
|
48 |
bb[3] /= h
|
49 |
return bb
|
50 |
|
51 |
-
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
model, device = setup_model(
|
54 |
|
55 |
load = False
|
56 |
save = False
|
@@ -78,15 +87,17 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
78 |
return
|
79 |
|
80 |
|
81 |
-
outputs = do_suppression(inference, conf_thres=
|
82 |
|
83 |
-
|
|
|
|
|
84 |
|
85 |
-
|
86 |
|
87 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
88 |
|
89 |
-
results = do_tracking(all_preds, image_meter_width, image_meter_height, min_hits=min_hits, max_age=max_age, gp=gp)
|
90 |
|
91 |
return results
|
92 |
|
|
|
48 |
bb[3] /= h
|
49 |
return bb
|
50 |
|
51 |
+
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, hyperparams={}):
|
52 |
+
|
53 |
+
# Load hyperparameters
|
54 |
+
if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
|
55 |
+
if 'conf_thresh' not in hyperparams: hyperparams['conf_tresh'] = CONF_THRES
|
56 |
+
if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
|
57 |
+
if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
|
58 |
+
if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
|
59 |
+
if 'use_associative_tracking' not in hyperparams: hyperparams['use_associative_tracking'] = False
|
60 |
+
if 'AT_decay' not in hyperparams: hyperparams['AT_decay'] = MIN_HITS
|
61 |
|
62 |
+
model, device = setup_model(hyperparams['model'])
|
63 |
|
64 |
load = False
|
65 |
save = False
|
|
|
87 |
return
|
88 |
|
89 |
|
90 |
+
outputs = do_suppression(inference, conf_thres=hyperparams['conf_tresh'], iou_thres=hyperparams['iou_tresh'], gp=gp)
|
91 |
|
92 |
+
if hyperparams['use_associative_tracking']:
|
93 |
+
|
94 |
+
do_confidence_boost(inference, outputs, gp=gp)
|
95 |
|
96 |
+
outputs = do_suppression(inference, conf_thres=hyperparams['conf_tresh'], iou_thres=hyperparams['iou_tresh'], gp=gp)
|
97 |
|
98 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
99 |
|
100 |
+
results = do_tracking(all_preds, image_meter_width, image_meter_height, min_hits=hyperparams['min_hits'], max_age=hyperparams['max_age'], gp=gp)
|
101 |
|
102 |
return results
|
103 |
|
main.py
CHANGED
@@ -7,7 +7,7 @@ from dataloader import create_dataloader_aris
|
|
7 |
from inference import do_full_inference, json_dump_round_float
|
8 |
from visualizer import generate_video_batches
|
9 |
|
10 |
-
def predict_task(filepath,
|
11 |
"""
|
12 |
Main processing task to be run in gradio
|
13 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
@@ -45,18 +45,12 @@ def predict_task(filepath, weights, conf_thresh, iou_thresh, min_hits, max_age,
|
|
45 |
frame_rate = dataset.didson.info['framerate']
|
46 |
|
47 |
# run detection + tracking
|
48 |
-
results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress,
|
49 |
|
50 |
# re-index results if desired - this should be done before writing the file
|
51 |
results = prep_for_mm(results)
|
52 |
results = add_metadata_to_result(filepath, results)
|
53 |
-
results['metadata']['hyperparameters'] =
|
54 |
-
'model': weights,
|
55 |
-
'conf_thresh': conf_thresh,
|
56 |
-
'iou_thresh': iou_thresh,
|
57 |
-
'min_hits': min_hits,
|
58 |
-
'max_age': max_age
|
59 |
-
}
|
60 |
|
61 |
# write output to disk
|
62 |
json_dump_round_float(results, results_filepath)
|
|
|
7 |
from inference import do_full_inference, json_dump_round_float
|
8 |
from visualizer import generate_video_batches
|
9 |
|
10 |
+
def predict_task(filepath, hyperparams, gradio_progress=None):
|
11 |
"""
|
12 |
Main processing task to be run in gradio
|
13 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
|
|
45 |
frame_rate = dataset.didson.info['framerate']
|
46 |
|
47 |
# run detection + tracking
|
48 |
+
results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, hyperparams=hyperparams)
|
49 |
|
50 |
# re-index results if desired - this should be done before writing the file
|
51 |
results = prep_for_mm(results)
|
52 |
results = add_metadata_to_result(filepath, results)
|
53 |
+
results['metadata']['hyperparameters'] = hyperparams
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
# write output to disk
|
56 |
json_dump_round_float(results, results_filepath)
|
models/{YsEKvE20 → YsEKvE20.pt}
RENAMED
File without changes
|