Spaces:
Runtime error
Runtime error
Commit
·
5657a6c
1
Parent(s):
034aaec
Added Elwha test model
Browse files- app.py +10 -5
- gradio_scripts/annotation_handler.py +1 -1
- gradio_scripts/upload_ui.py +7 -0
- models/YsEE20.pt +3 -0
app.py
CHANGED
@@ -9,7 +9,7 @@ from gradio_scripts.annotation_handler import init_frames
|
|
9 |
import json
|
10 |
from zipfile import ZipFile
|
11 |
import os
|
12 |
-
from gradio_scripts.upload_ui import Upload_Gradio
|
13 |
from gradio_scripts.result_ui import Result_Gradio, update_result, table_headers, info_headers, js_update_tab_labels
|
14 |
from dataloader import create_dataloader_aris
|
15 |
from aris import BEAM_WIDTH_DIR
|
@@ -20,18 +20,23 @@ state = {
|
|
20 |
'index': 1,
|
21 |
'total': 1,
|
22 |
'annotation_index': -1,
|
23 |
-
'frame_index': 0
|
|
|
24 |
}
|
25 |
result = {}
|
26 |
|
27 |
|
28 |
# Called when an Aris file is uploaded for inference
|
29 |
-
def on_aris_input(file_list):
|
|
|
|
|
|
|
30 |
|
31 |
# Reset Result
|
32 |
reset_state(result, state)
|
33 |
state['files'] = file_list
|
34 |
state['total'] = len(file_list)
|
|
|
35 |
|
36 |
# Update loading_space to start inference on first file
|
37 |
return {
|
@@ -139,7 +144,7 @@ def infer_next(_, progress=gr.Progress()):
|
|
139 |
upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
|
140 |
|
141 |
# Do inference
|
142 |
-
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(file_path, gradio_progress=set_progress)
|
143 |
|
144 |
# Store result for that file
|
145 |
result['json_result'].append(json_result)
|
@@ -365,7 +370,7 @@ with demo:
|
|
365 |
inference_comps = [inference_handler, master_tabs, components['cancelBtn'], components['skipBtn']]
|
366 |
|
367 |
# When a file is uploaded to the input, tell the inference_handler to start inference
|
368 |
-
input.upload(on_aris_input, input, inference_comps)
|
369 |
|
370 |
# When inference handler updates, tell result_handler to show the new result
|
371 |
# Also, add inference_handler as the output in order to have it display the progress
|
|
|
9 |
import json
|
10 |
from zipfile import ZipFile
|
11 |
import os
|
12 |
+
from gradio_scripts.upload_ui import Upload_Gradio, models
|
13 |
from gradio_scripts.result_ui import Result_Gradio, update_result, table_headers, info_headers, js_update_tab_labels
|
14 |
from dataloader import create_dataloader_aris
|
15 |
from aris import BEAM_WIDTH_DIR
|
|
|
20 |
'index': 1,
|
21 |
'total': 1,
|
22 |
'annotation_index': -1,
|
23 |
+
'frame_index': 0,
|
24 |
+
'model': None
|
25 |
}
|
26 |
result = {}
|
27 |
|
28 |
|
29 |
# Called when an Aris file is uploaded for inference
|
30 |
+
def on_aris_input(file_list, model_id):
|
31 |
+
|
32 |
+
print(model_id)
|
33 |
+
print(models[model_id] if model_id in models else models['master'])
|
34 |
|
35 |
# Reset Result
|
36 |
reset_state(result, state)
|
37 |
state['files'] = file_list
|
38 |
state['total'] = len(file_list)
|
39 |
+
state['model'] = models[model_id] if model_id in models else models['master']
|
40 |
|
41 |
# Update loading_space to start inference on first file
|
42 |
return {
|
|
|
144 |
upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
|
145 |
|
146 |
# Do inference
|
147 |
+
json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(file_path, weights=state['model'], gradio_progress=set_progress)
|
148 |
|
149 |
# Store result for that file
|
150 |
result['json_result'].append(json_result)
|
|
|
370 |
inference_comps = [inference_handler, master_tabs, components['cancelBtn'], components['skipBtn']]
|
371 |
|
372 |
# When a file is uploaded to the input, tell the inference_handler to start inference
|
373 |
+
input.upload(on_aris_input, [input, components['model_select']], inference_comps)
|
374 |
|
375 |
# When inference handler updates, tell result_handler to show the new result
|
376 |
# Also, add inference_handler as the output in order to have it display the progress
|
gradio_scripts/annotation_handler.py
CHANGED
@@ -39,7 +39,7 @@ def init_frames(dataset, preds, index, gp=None):
|
|
39 |
if gp: gp((index + i)/len(preds['frames']), "Extracting Frames")
|
40 |
|
41 |
# Extract frames
|
42 |
-
img_raw = dataset.didson.load_frames(start_frame=i, end_frame=i+1)[0]
|
43 |
image = cv2.resize(cv2.cvtColor(img_raw, cv2.COLOR_GRAY2BGR), (w, h))
|
44 |
#cv2.imwrite("annotation_frame_dir/" + str(i) + ".jpg", image)
|
45 |
retval, buffer = cv2.imencode('.jpg', image)
|
|
|
39 |
if gp: gp((index + i)/len(preds['frames']), "Extracting Frames")
|
40 |
|
41 |
# Extract frames
|
42 |
+
img_raw = dataset.didson.load_frames(start_frame=index+i, end_frame=index+i+1)[0]
|
43 |
image = cv2.resize(cv2.cvtColor(img_raw, cv2.COLOR_GRAY2BGR), (w, h))
|
44 |
#cv2.imwrite("annotation_frame_dir/" + str(i) + ".jpg", image)
|
45 |
retval, buffer = cv2.imencode('.jpg', image)
|
gradio_scripts/upload_ui.py
CHANGED
@@ -2,6 +2,11 @@ import gradio as gr
|
|
2 |
from gradio_scripts.file_reader import File
|
3 |
|
4 |
|
|
|
|
|
|
|
|
|
|
|
5 |
def Upload_Gradio(gradio_components):
|
6 |
with gr.Tabs():
|
7 |
|
@@ -10,6 +15,8 @@ def Upload_Gradio(gradio_components):
|
|
10 |
|
11 |
gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
|
12 |
|
|
|
|
|
13 |
#Input field for aris submission
|
14 |
gradio_components['input'] = File(file_types=[".aris", ".ddf"], type="binary", label="ARIS Input", file_count="multiple")
|
15 |
|
|
|
2 |
from gradio_scripts.file_reader import File
|
3 |
|
4 |
|
5 |
+
models = {
|
6 |
+
'master': 'models/v5m_896_300best.pt',
|
7 |
+
'elwha': 'models/YsEE20.pt'
|
8 |
+
}
|
9 |
+
|
10 |
def Upload_Gradio(gradio_components):
|
11 |
with gr.Tabs():
|
12 |
|
|
|
15 |
|
16 |
gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
|
17 |
|
18 |
+
gradio_components['model_select'] = gr.Dropdown(value="master", choices=list(models.keys()))
|
19 |
+
|
20 |
#Input field for aris submission
|
21 |
gradio_components['input'] = File(file_types=[".aris", ".ddf"], type="binary", label="ARIS Input", file_count="multiple")
|
22 |
|
models/YsEE20.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d661200690bc0e9d3e48b907fcf0a9fa9165b3e92861e275b21cc962c5970262
|
3 |
+
size 56791439
|