oskarastrom commited on
Commit
e8f4d7e
·
1 Parent(s): 2482ba4

Inference config

Browse files
Files changed (3) hide show
  1. inference.py +9 -6
  2. scripts/inferEval.py +34 -0
  3. scripts/infer_frames.py +22 -10
inference.py CHANGED
@@ -24,9 +24,12 @@ WEIGHTS = 'models/v5m_896_300best.pt'
24
  # will need to configure these based on GPU hardware
25
  BATCH_SIZE = 32
26
 
27
- conf_thres = 0.3 # detection
28
- iou_thres = 0.3 # NMS IOU
29
- min_length = 0.3 # minimum fish length, in meters
 
 
 
30
  ###
31
 
32
  def norm(bbox, w, h):
@@ -131,7 +134,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
131
 
132
  return inference, width, height
133
 
134
- def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE):
135
  """
136
  Args:
137
  frames_dir: a directory containing frames to be evaluated
@@ -177,7 +180,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
177
 
178
  return all_preds, real_width, real_height
179
 
180
- def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
181
 
182
  if (gp): gp(0, "Tracking...")
183
 
@@ -188,7 +191,7 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
188
  'image_meter_width': image_meter_width,
189
  'image_meter_height': image_meter_height
190
  }
191
- tracker = Tracker(clip_info, args={ 'max_age': 20, 'min_hits': 0, 'iou_threshold': 0.01}, min_hits=11)
192
 
193
  # Run tracking
194
  with tqdm(total=len(all_preds), desc="Running tracking", ncols=0) as pbar:
 
24
  # will need to configure these based on GPU hardware
25
  BATCH_SIZE = 32
26
 
27
+ CONF_THRES = 0.3 # detection
28
+ NMS_IOU = 0.3 # NMS IOU
29
+ MIN_LENGTH = 0.3 # minimum fish length, in meters
30
+ MAX_AGE = 20 # time until missing fish get's new id
31
+ IOU_THRES = 0.01 # IOU threshold for tracking
32
+ MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
33
  ###
34
 
35
  def norm(bbox, w, h):
 
134
 
135
  return inference, width, height
136
 
137
+ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU):
138
  """
139
  Args:
140
  frames_dir: a directory containing frames to be evaluated
 
180
 
181
  return all_preds, real_width, real_height
182
 
183
+ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH):
184
 
185
  if (gp): gp(0, "Tracking...")
186
 
 
191
  'image_meter_width': image_meter_width,
192
  'image_meter_height': image_meter_height
193
  }
194
+ tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
195
 
196
  # Run tracking
197
  with tqdm(total=len(all_preds), desc="Running tracking", ncols=0) as pbar:
scripts/inferEval.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import project_path
2
+ import argparse
3
+ from infer_frames import main as infer
4
+ import sys
5
+ sys.path.append('..')
6
+ sys.path.append('../caltech-fish-counting')
7
+
8
+ from evaluate import evaluate
9
+
10
+ class Object(object):
11
+ pass
12
+
13
+ def main(args):
14
+
15
+ infer_args = Object()
16
+ infer_args.metadata = "../caltech-fish-counting/data/metadata"
17
+ infer_args.frames = "../caltech-fish-counting/data/images"
18
+ infer_args.output = "../caltech-fish-counting/data/result"
19
+ infer_args.weights = "models/v5m_896_300best.pt"
20
+ infer_args.config = args.config
21
+
22
+ infer(infer_args)
23
+
24
+ evaluate("../frames/result_testing", "../frames/MOT", "../frames/metadata", "tracker", True)
25
+
26
+
27
+ def argument_parser():
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("--config", required=True, help="Config object. Required.")
30
+ return parser
31
+
32
+ if __name__ == "__main__":
33
+ args = argument_parser().parse_args()
34
+ main(args)
scripts/infer_frames.py CHANGED
@@ -26,9 +26,20 @@ def main(args):
26
  print("In task...")
27
  print("Cuda available in task?", torch.cuda.is_available())
28
 
 
 
 
 
 
 
 
 
 
 
 
29
  dirname = args.frames
30
 
31
- locations = ["test"]
32
  for loc in locations:
33
 
34
  in_loc_dir = os.path.join(dirname, loc)
@@ -39,6 +50,9 @@ def main(args):
39
  print(out_dir)
40
  print(metadata_path)
41
 
 
 
 
42
  seq_list = os.listdir(in_loc_dir)
43
  idx = 1
44
  for seq in seq_list:
@@ -47,11 +61,11 @@ def main(args):
47
  print(" ")
48
  idx += 1
49
  in_seq_dir = os.path.join(in_loc_dir, seq)
50
- infer_seq(in_seq_dir, out_dir, seq, args.weights, metadata_path)
51
 
52
- def infer_seq(in_dir, out_dir, seq_name, weights, metadata_path):
53
 
54
- gradio_progress = lambda p, m: 0
55
 
56
  image_meter_width = -1
57
  image_meter_height = -1
@@ -68,21 +82,18 @@ def infer_seq(in_dir, out_dir, seq_name, weights, metadata_path):
68
 
69
  # create dataloader
70
  dataloader = create_dataloader_frames_only(in_dir)
71
-
72
- # run detection + tracking
73
- model, device = setup_model(weights)
74
 
75
  try:
76
- inference, width, height = do_detection(dataloader, model, device, gp=gradio_progress)
77
  except:
78
  print("Error in " + seq_name)
79
  with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
80
  f.write("ERROR")
81
  return
82
 
83
- all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, gp=gradio_progress)
84
 
85
- results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gradio_progress)
86
 
87
  mot_rows = []
88
  for frame in results['frames']:
@@ -118,6 +129,7 @@ def argument_parser():
118
  parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
119
  parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
120
  parser.add_argument("--output", required=True, help="Path to output directory. Required.")
 
121
  parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: ../models/v5m_896_300best.pt")
122
  return parser
123
 
 
26
  print("In task...")
27
  print("Cuda available in task?", torch.cuda.is_available())
28
 
29
+ # setup config
30
+ config = json.loads(args.config)
31
+ if "conf_threshold" not in config: config['conf_threshold'] = 0.3
32
+ if "nms_iou" not in config: config['nms_iou'] = 0.3
33
+ if "min_length" not in config: config['min_length'] = 0.3
34
+ if "max_age" not in config: config['max_age'] = 20
35
+ if "iou_threshold" not in config: config['iou_threshold'] = 0.01
36
+ if "min_hits" not in config: config['min_hits'] = 11
37
+
38
+ print(config)
39
+
40
  dirname = args.frames
41
 
42
+ locations = ["kenai-val"]
43
  for loc in locations:
44
 
45
  in_loc_dir = os.path.join(dirname, loc)
 
50
  print(out_dir)
51
  print(metadata_path)
52
 
53
+ # run detection + tracking
54
+ model, device = setup_model(args.weights)
55
+
56
  seq_list = os.listdir(in_loc_dir)
57
  idx = 1
58
  for seq in seq_list:
 
61
  print(" ")
62
  idx += 1
63
  in_seq_dir = os.path.join(in_loc_dir, seq)
64
+ infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path)
65
 
66
+ def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path):
67
 
68
+ #progress_log = lambda p, m: 0
69
 
70
  image_meter_width = -1
71
  image_meter_height = -1
 
82
 
83
  # create dataloader
84
  dataloader = create_dataloader_frames_only(in_dir)
 
 
 
85
 
86
  try:
87
+ inference, width, height = do_detection(dataloader, model, device)
88
  except:
89
  print("Error in " + seq_name)
90
  with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
91
  f.write("ERROR")
92
  return
93
 
94
+ all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'])
95
 
96
+ results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'])
97
 
98
  mot_rows = []
99
  for frame in results['frames']:
 
129
  parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
130
  parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
131
  parser.add_argument("--output", required=True, help="Path to output directory. Required.")
132
+ parser.add_argument("--config", default="{}", help="Config object. Required.")
133
  parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: ../models/v5m_896_300best.pt")
134
  return parser
135