oskarastrom commited on
Commit
29e11ce
·
1 Parent(s): 1eadfef

New Elwha Model

Browse files
app.py CHANGED
@@ -21,13 +21,13 @@ state = {
21
  'total': 1,
22
  'annotation_index': -1,
23
  'frame_index': 0,
24
- 'model': None
25
  }
26
  result = {}
27
 
28
 
29
  # Called when an Aris file is uploaded for inference
30
- def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age):
31
 
32
  # Reset Result
33
  reset_state(result, state)
@@ -39,6 +39,7 @@ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_ag
39
  'iou_thresh': iou_thresh,
40
  'min_hits': min_hits,
41
  'max_age': max_age,
 
42
  }
43
 
44
  print(" ")
@@ -138,7 +139,7 @@ def infer_next(_, progress=gr.Progress()):
138
  file_name = file_info[0].split("/")[-1]
139
  bytes = file_info[1]
140
  valid, file_path, dir_name = save_data(bytes, file_name)
141
-
142
  print("Directory: ", dir_name)
143
  print("Aris input: ", file_path)
144
  print(" ")
@@ -153,17 +154,11 @@ def infer_next(_, progress=gr.Progress()):
153
  # Send uploaded file to AWS
154
  upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
155
 
156
- hyperparams = state['hyperparams']
157
-
158
  # Do inference
159
  json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
160
  file_path,
161
- weights = hyperparams['model'],
162
- conf_thresh = hyperparams['conf_thresh'],
163
- iou_thresh = hyperparams['iou_thresh'],
164
- min_hits = hyperparams['min_hits'],
165
- max_age = hyperparams['max_age'],
166
- gradio_progress=set_progress
167
  )
168
 
169
  # Store result for that file
 
21
  'total': 1,
22
  'annotation_index': -1,
23
  'frame_index': 0,
24
+ 'hyperparams': {}
25
  }
26
  result = {}
27
 
28
 
29
  # Called when an Aris file is uploaded for inference
30
+ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, use_associative):
31
 
32
  # Reset Result
33
  reset_state(result, state)
 
39
  'iou_thresh': iou_thresh,
40
  'min_hits': min_hits,
41
  'max_age': max_age,
42
+ 'use_associative_tracking': use_associative,
43
  }
44
 
45
  print(" ")
 
139
  file_name = file_info[0].split("/")[-1]
140
  bytes = file_info[1]
141
  valid, file_path, dir_name = save_data(bytes, file_name)
142
+
143
  print("Directory: ", dir_name)
144
  print("Aris input: ", file_path)
145
  print(" ")
 
154
  # Send uploaded file to AWS
155
  upload_file(file_path, "fishcounting", "webapp_uploads/" + file_name)
156
 
 
 
157
  # Do inference
158
  json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
159
  file_path,
160
+ hyperparams = state['hyperparams'],
161
+ gradio_progress = set_progress
 
 
 
 
162
  )
163
 
164
  # Store result for that file
gradio_scripts/upload_ui.py CHANGED
@@ -4,7 +4,8 @@ from gradio_scripts.file_reader import File
4
 
5
  models = {
6
  'master': 'models/v5m_896_300best.pt',
7
- 'elwha': 'models/YsEE20.pt'
 
8
  }
9
 
10
  def Upload_Gradio(gradio_components):
@@ -29,6 +30,13 @@ def Upload_Gradio(gradio_components):
29
  settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
30
  settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
31
 
 
 
 
 
 
 
 
32
  gradio_components['hyperparams'] = settings
33
 
34
  #Input field for aris submission
 
4
 
5
  models = {
6
  'master': 'models/v5m_896_300best.pt',
7
+ 'elwha': 'models/YsEE20.pt',
8
+ 'elwha_train': 'models/YsEKtE20.pt',
9
  }
10
 
11
  def Upload_Gradio(gradio_components):
 
30
  settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
31
  settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
32
 
33
+ with gr.Row():
34
+ gr.Markdown("Associative Tracking")
35
+ settings.append(gr.Checkbox(value=False, label="Enabled"))
36
+ with gr.Row():
37
+ settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
38
+ settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
39
+
40
  gradio_components['hyperparams'] = settings
41
 
42
  #Input field for aris submission
inference.py CHANGED
@@ -48,9 +48,18 @@ def norm(bbox, w, h):
48
  bb[3] /= h
49
  return bb
50
 
51
- def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU, min_hits=MIN_HITS, max_age=MAX_AGE):
 
 
 
 
 
 
 
 
 
52
 
53
- model, device = setup_model(weights)
54
 
55
  load = False
56
  save = False
@@ -78,15 +87,17 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
78
  return
79
 
80
 
81
- outputs = do_suppression(inference, conf_thres=conf_thresh, iou_thres=nms_iou, gp=gp)
82
 
83
- #do_confidence_boost(inference, outputs, gp=gp)
 
 
84
 
85
- #new_outputs = do_suppression(inference, conf_thres=conf_thresh, iou_thres=nms_iou, gp=gp)
86
 
87
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
88
 
89
- results = do_tracking(all_preds, image_meter_width, image_meter_height, min_hits=min_hits, max_age=max_age, gp=gp)
90
 
91
  return results
92
 
 
48
  bb[3] /= h
49
  return bb
50
 
51
+ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, hyperparams={}):
52
+
53
+ # Load hyperparameters
54
+ if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
55
+ if 'conf_thresh' not in hyperparams: hyperparams['conf_tresh'] = CONF_THRES
56
+ if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
57
+ if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
58
+ if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
59
+ if 'use_associative_tracking' not in hyperparams: hyperparams['use_associative_tracking'] = False
60
+ if 'AT_decay' not in hyperparams: hyperparams['AT_decay'] = MIN_HITS
61
 
62
+ model, device = setup_model(hyperparams['model'])
63
 
64
  load = False
65
  save = False
 
87
  return
88
 
89
 
90
+ outputs = do_suppression(inference, conf_thres=hyperparams['conf_tresh'], iou_thres=hyperparams['iou_tresh'], gp=gp)
91
 
92
+ if hyperparams['use_associative_tracking']:
93
+
94
+ do_confidence_boost(inference, outputs, gp=gp)
95
 
96
+ outputs = do_suppression(inference, conf_thres=hyperparams['conf_tresh'], iou_thres=hyperparams['iou_tresh'], gp=gp)
97
 
98
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
99
 
100
+ results = do_tracking(all_preds, image_meter_width, image_meter_height, min_hits=hyperparams['min_hits'], max_age=hyperparams['max_age'], gp=gp)
101
 
102
  return results
103
 
main.py CHANGED
@@ -7,7 +7,7 @@ from dataloader import create_dataloader_aris
7
  from inference import do_full_inference, json_dump_round_float
8
  from visualizer import generate_video_batches
9
 
10
- def predict_task(filepath, weights, conf_thresh, iou_thresh, min_hits, max_age, gradio_progress=None):
11
  """
12
  Main processing task to be run in gradio
13
  - Writes aris frames to dirname(filepath)/frames/{i}.jpg
@@ -45,18 +45,12 @@ def predict_task(filepath, weights, conf_thresh, iou_thresh, min_hits, max_age,
45
  frame_rate = dataset.didson.info['framerate']
46
 
47
  # run detection + tracking
48
- results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, weights=weights, conf_thresh=conf_thresh, nms_iou=iou_thresh, min_hits=min_hits, max_age=max_age)
49
 
50
  # re-index results if desired - this should be done before writing the file
51
  results = prep_for_mm(results)
52
  results = add_metadata_to_result(filepath, results)
53
- results['metadata']['hyperparameters'] = {
54
- 'model': weights,
55
- 'conf_thresh': conf_thresh,
56
- 'iou_thresh': iou_thresh,
57
- 'min_hits': min_hits,
58
- 'max_age': max_age
59
- }
60
 
61
  # write output to disk
62
  json_dump_round_float(results, results_filepath)
 
7
  from inference import do_full_inference, json_dump_round_float
8
  from visualizer import generate_video_batches
9
 
10
+ def predict_task(filepath, hyperparams, gradio_progress=None):
11
  """
12
  Main processing task to be run in gradio
13
  - Writes aris frames to dirname(filepath)/frames/{i}.jpg
 
45
  frame_rate = dataset.didson.info['framerate']
46
 
47
  # run detection + tracking
48
+ results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, hyperparams=hyperparams)
49
 
50
  # re-index results if desired - this should be done before writing the file
51
  results = prep_for_mm(results)
52
  results = add_metadata_to_result(filepath, results)
53
+ results['metadata']['hyperparameters'] = hyperparams
 
 
 
 
 
 
54
 
55
  # write output to disk
56
  json_dump_round_float(results, results_filepath)
models/{YsEKvE20 → YsEKvE20.pt} RENAMED
File without changes