Spaces:
Runtime error
Runtime error
Commit
·
e8f4d7e
1
Parent(s):
2482ba4
Inference config
Browse files- inference.py +9 -6
- scripts/inferEval.py +34 -0
- scripts/infer_frames.py +22 -10
inference.py
CHANGED
@@ -24,9 +24,12 @@ WEIGHTS = 'models/v5m_896_300best.pt'
|
|
24 |
# will need to configure these based on GPU hardware
|
25 |
BATCH_SIZE = 32
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
30 |
###
|
31 |
|
32 |
def norm(bbox, w, h):
|
@@ -131,7 +134,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
|
|
131 |
|
132 |
return inference, width, height
|
133 |
|
134 |
-
def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE):
|
135 |
"""
|
136 |
Args:
|
137 |
frames_dir: a directory containing frames to be evaluated
|
@@ -177,7 +180,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
|
|
177 |
|
178 |
return all_preds, real_width, real_height
|
179 |
|
180 |
-
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
|
181 |
|
182 |
if (gp): gp(0, "Tracking...")
|
183 |
|
@@ -188,7 +191,7 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
|
|
188 |
'image_meter_width': image_meter_width,
|
189 |
'image_meter_height': image_meter_height
|
190 |
}
|
191 |
-
tracker = Tracker(clip_info, args={ 'max_age':
|
192 |
|
193 |
# Run tracking
|
194 |
with tqdm(total=len(all_preds), desc="Running tracking", ncols=0) as pbar:
|
|
|
24 |
# will need to configure these based on GPU hardware
|
25 |
BATCH_SIZE = 32
|
26 |
|
27 |
+
CONF_THRES = 0.3 # detection
|
28 |
+
NMS_IOU = 0.3 # NMS IOU
|
29 |
+
MIN_LENGTH = 0.3 # minimum fish length, in meters
|
30 |
+
MAX_AGE = 20 # time until missing fish get's new id
|
31 |
+
IOU_THRES = 0.01 # IOU threshold for tracking
|
32 |
+
MIN_HITS = 11 # minimum number of frames with a specific fish for it to count
|
33 |
###
|
34 |
|
35 |
def norm(bbox, w, h):
|
|
|
134 |
|
135 |
return inference, width, height
|
136 |
|
137 |
+
def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU):
|
138 |
"""
|
139 |
Args:
|
140 |
frames_dir: a directory containing frames to be evaluated
|
|
|
180 |
|
181 |
return all_preds, real_width, real_height
|
182 |
|
183 |
+
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH):
|
184 |
|
185 |
if (gp): gp(0, "Tracking...")
|
186 |
|
|
|
191 |
'image_meter_width': image_meter_width,
|
192 |
'image_meter_height': image_meter_height
|
193 |
}
|
194 |
+
tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
195 |
|
196 |
# Run tracking
|
197 |
with tqdm(total=len(all_preds), desc="Running tracking", ncols=0) as pbar:
|
scripts/inferEval.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import project_path
|
2 |
+
import argparse
|
3 |
+
from infer_frames import main as infer
|
4 |
+
import sys
|
5 |
+
sys.path.append('..')
|
6 |
+
sys.path.append('../caltech-fish-counting')
|
7 |
+
|
8 |
+
from evaluate import evaluate
|
9 |
+
|
10 |
+
class Object(object):
|
11 |
+
pass
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
|
15 |
+
infer_args = Object()
|
16 |
+
infer_args.metadata = "../caltech-fish-counting/data/metadata"
|
17 |
+
infer_args.frames = "../caltech-fish-counting/data/images"
|
18 |
+
infer_args.output = "../caltech-fish-counting/data/result"
|
19 |
+
infer_args.weights = "models/v5m_896_300best.pt"
|
20 |
+
infer_args.config = args.config
|
21 |
+
|
22 |
+
infer(infer_args)
|
23 |
+
|
24 |
+
evaluate("../frames/result_testing", "../frames/MOT", "../frames/metadata", "tracker", True)
|
25 |
+
|
26 |
+
|
27 |
+
def argument_parser():
|
28 |
+
parser = argparse.ArgumentParser()
|
29 |
+
parser.add_argument("--config", required=True, help="Config object. Required.")
|
30 |
+
return parser
|
31 |
+
|
32 |
+
if __name__ == "__main__":
|
33 |
+
args = argument_parser().parse_args()
|
34 |
+
main(args)
|
scripts/infer_frames.py
CHANGED
@@ -26,9 +26,20 @@ def main(args):
|
|
26 |
print("In task...")
|
27 |
print("Cuda available in task?", torch.cuda.is_available())
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
dirname = args.frames
|
30 |
|
31 |
-
locations = ["
|
32 |
for loc in locations:
|
33 |
|
34 |
in_loc_dir = os.path.join(dirname, loc)
|
@@ -39,6 +50,9 @@ def main(args):
|
|
39 |
print(out_dir)
|
40 |
print(metadata_path)
|
41 |
|
|
|
|
|
|
|
42 |
seq_list = os.listdir(in_loc_dir)
|
43 |
idx = 1
|
44 |
for seq in seq_list:
|
@@ -47,11 +61,11 @@ def main(args):
|
|
47 |
print(" ")
|
48 |
idx += 1
|
49 |
in_seq_dir = os.path.join(in_loc_dir, seq)
|
50 |
-
infer_seq(in_seq_dir, out_dir, seq,
|
51 |
|
52 |
-
def infer_seq(in_dir, out_dir, seq_name,
|
53 |
|
54 |
-
|
55 |
|
56 |
image_meter_width = -1
|
57 |
image_meter_height = -1
|
@@ -68,21 +82,18 @@ def infer_seq(in_dir, out_dir, seq_name, weights, metadata_path):
|
|
68 |
|
69 |
# create dataloader
|
70 |
dataloader = create_dataloader_frames_only(in_dir)
|
71 |
-
|
72 |
-
# run detection + tracking
|
73 |
-
model, device = setup_model(weights)
|
74 |
|
75 |
try:
|
76 |
-
inference, width, height = do_detection(dataloader, model, device
|
77 |
except:
|
78 |
print("Error in " + seq_name)
|
79 |
with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
|
80 |
f.write("ERROR")
|
81 |
return
|
82 |
|
83 |
-
all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height,
|
84 |
|
85 |
-
results = do_tracking(all_preds, image_meter_width, image_meter_height,
|
86 |
|
87 |
mot_rows = []
|
88 |
for frame in results['frames']:
|
@@ -118,6 +129,7 @@ def argument_parser():
|
|
118 |
parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
|
119 |
parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
|
120 |
parser.add_argument("--output", required=True, help="Path to output directory. Required.")
|
|
|
121 |
parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: ../models/v5m_896_300best.pt")
|
122 |
return parser
|
123 |
|
|
|
26 |
print("In task...")
|
27 |
print("Cuda available in task?", torch.cuda.is_available())
|
28 |
|
29 |
+
# setup config
|
30 |
+
config = json.loads(args.config)
|
31 |
+
if "conf_threshold" not in config: config['conf_threshold'] = 0.3
|
32 |
+
if "nms_iou" not in config: config['nms_iou'] = 0.3
|
33 |
+
if "min_length" not in config: config['min_length'] = 0.3
|
34 |
+
if "max_age" not in config: config['max_age'] = 20
|
35 |
+
if "iou_threshold" not in config: config['iou_threshold'] = 0.01
|
36 |
+
if "min_hits" not in config: config['min_hits'] = 11
|
37 |
+
|
38 |
+
print(config)
|
39 |
+
|
40 |
dirname = args.frames
|
41 |
|
42 |
+
locations = ["kenai-val"]
|
43 |
for loc in locations:
|
44 |
|
45 |
in_loc_dir = os.path.join(dirname, loc)
|
|
|
50 |
print(out_dir)
|
51 |
print(metadata_path)
|
52 |
|
53 |
+
# run detection + tracking
|
54 |
+
model, device = setup_model(args.weights)
|
55 |
+
|
56 |
seq_list = os.listdir(in_loc_dir)
|
57 |
idx = 1
|
58 |
for seq in seq_list:
|
|
|
61 |
print(" ")
|
62 |
idx += 1
|
63 |
in_seq_dir = os.path.join(in_loc_dir, seq)
|
64 |
+
infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path)
|
65 |
|
66 |
+
def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path):
|
67 |
|
68 |
+
#progress_log = lambda p, m: 0
|
69 |
|
70 |
image_meter_width = -1
|
71 |
image_meter_height = -1
|
|
|
82 |
|
83 |
# create dataloader
|
84 |
dataloader = create_dataloader_frames_only(in_dir)
|
|
|
|
|
|
|
85 |
|
86 |
try:
|
87 |
+
inference, width, height = do_detection(dataloader, model, device)
|
88 |
except:
|
89 |
print("Error in " + seq_name)
|
90 |
with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
|
91 |
f.write("ERROR")
|
92 |
return
|
93 |
|
94 |
+
all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'])
|
95 |
|
96 |
+
results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'])
|
97 |
|
98 |
mot_rows = []
|
99 |
for frame in results['frames']:
|
|
|
129 |
parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
|
130 |
parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
|
131 |
parser.add_argument("--output", required=True, help="Path to output directory. Required.")
|
132 |
+
parser.add_argument("--config", default="{}", help="Config object. Required.")
|
133 |
parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: ../models/v5m_896_300best.pt")
|
134 |
return parser
|
135 |
|