Spaces:
Runtime error
Runtime error
Commit
·
91158dd
1
Parent(s):
1f98f08
Reverse tracking test
Browse files- inference.py +2 -2
- lib/fish_eye/associative.py +4 -3
- lib/fish_eye/tracker.py +10 -3
inference.py
CHANGED
@@ -342,11 +342,11 @@ def do_associative_tracking(low_preds, high_preds, image_meter_width, image_mete
|
|
342 |
'image_meter_width': image_meter_width,
|
343 |
'image_meter_height': image_meter_height
|
344 |
}
|
345 |
-
tracker = Tracker(clip_info, algorithm=Associate, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
346 |
|
347 |
# Run tracking
|
348 |
with tqdm(total=len(low_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
|
349 |
-
for i, key in enumerate(sorted(low_preds.keys())):
|
350 |
if gp: gp(i / len(low_preds), pbar.__str__())
|
351 |
low_boxes = low_preds[key]
|
352 |
high_boxes = high_preds[key]
|
|
|
342 |
'image_meter_width': image_meter_width,
|
343 |
'image_meter_height': image_meter_height
|
344 |
}
|
345 |
+
tracker = Tracker(clip_info, algorithm=Associate, reversed=True, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
346 |
|
347 |
# Run tracking
|
348 |
with tqdm(total=len(low_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
|
349 |
+
for i, key in enumerate(sorted(low_preds.keys(), reverse=True)):
|
350 |
if gp: gp(i / len(low_preds), pbar.__str__())
|
351 |
low_boxes = low_preds[key]
|
352 |
high_boxes = high_preds[key]
|
lib/fish_eye/associative.py
CHANGED
@@ -147,6 +147,7 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
|
|
147 |
|
148 |
iou_matrix = iou_batch(detections, trackers)
|
149 |
|
|
|
150 |
if min(iou_matrix.shape) > 0:
|
151 |
a = (iou_matrix > iou_threshold).astype(np.int32)
|
152 |
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
|
@@ -165,15 +166,15 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
|
|
165 |
if(t not in matched_indices[:,1]):
|
166 |
unmatched_trackers.append(t)
|
167 |
|
168 |
-
|
169 |
matches = []
|
170 |
for m in matched_indices:
|
171 |
-
if(iou_matrix[m[0], m[1]]<iou_threshold):
|
172 |
unmatched_detections.append(m[0])
|
173 |
unmatched_trackers.append(m[1])
|
174 |
else:
|
175 |
matches.append(m.reshape(1,2))
|
176 |
-
if(len(matches)==0):
|
177 |
matches = np.empty((0,2),dtype=int)
|
178 |
else:
|
179 |
matches = np.concatenate(matches,axis=0)
|
|
|
147 |
|
148 |
iou_matrix = iou_batch(detections, trackers)
|
149 |
|
150 |
+
# find
|
151 |
if min(iou_matrix.shape) > 0:
|
152 |
a = (iou_matrix > iou_threshold).astype(np.int32)
|
153 |
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
|
|
|
166 |
if(t not in matched_indices[:,1]):
|
167 |
unmatched_trackers.append(t)
|
168 |
|
169 |
+
# filter out matched with low IOU
|
170 |
matches = []
|
171 |
for m in matched_indices:
|
172 |
+
if (iou_matrix[m[0], m[1]] < iou_threshold):
|
173 |
unmatched_detections.append(m[0])
|
174 |
unmatched_trackers.append(m[1])
|
175 |
else:
|
176 |
matches.append(m.reshape(1,2))
|
177 |
+
if (len(matches) == 0):
|
178 |
matches = np.empty((0,2),dtype=int)
|
179 |
else:
|
180 |
matches = np.concatenate(matches,axis=0)
|
lib/fish_eye/tracker.py
CHANGED
@@ -9,12 +9,16 @@ from sort import Sort
|
|
9 |
from associative import Associate
|
10 |
|
11 |
class Tracker:
|
12 |
-
def __init__(self, clip_info, algorithm=Sort, args={'max_age':1, 'min_hits':0, 'iou_threshold':0.05}, min_hits=3):
|
13 |
self.algorithm = algorithm(**args)
|
14 |
self.fish_ids = Counter()
|
|
|
15 |
self.min_hits = min_hits
|
16 |
self.json_data = deepcopy(clip_info)
|
17 |
-
|
|
|
|
|
|
|
18 |
self.json_data['frames'] = []
|
19 |
|
20 |
# Boxes should be given in normalized [x1,y1,x2,y2,c]
|
@@ -59,7 +63,10 @@ class Tracker:
|
|
59 |
'frame_num': self.frame_id,
|
60 |
'fish': new_frame_entries
|
61 |
})
|
62 |
-
self.
|
|
|
|
|
|
|
63 |
|
64 |
def finalize(self, output_path=None, min_length=-1.0): # vert_margin=0.0
|
65 |
json_data = deepcopy(self.json_data)
|
|
|
9 |
from associative import Associate
|
10 |
|
11 |
class Tracker:
|
12 |
+
def __init__(self, clip_info, algorithm=Sort, args={'max_age':1, 'min_hits':0, 'iou_threshold':0.05}, min_hits=3, reversed=False):
|
13 |
self.algorithm = algorithm(**args)
|
14 |
self.fish_ids = Counter()
|
15 |
+
self.reversed = reversed
|
16 |
self.min_hits = min_hits
|
17 |
self.json_data = deepcopy(clip_info)
|
18 |
+
if reversed:
|
19 |
+
self.frame_id = self.json_data['end_frame']
|
20 |
+
else:
|
21 |
+
self.frame_id = self.json_data['start_frame']
|
22 |
self.json_data['frames'] = []
|
23 |
|
24 |
# Boxes should be given in normalized [x1,y1,x2,y2,c]
|
|
|
63 |
'frame_num': self.frame_id,
|
64 |
'fish': new_frame_entries
|
65 |
})
|
66 |
+
if self.reversed:
|
67 |
+
self.frame_id -= 1
|
68 |
+
else:
|
69 |
+
self.frame_id += 1
|
70 |
|
71 |
def finalize(self, output_path=None, min_length=-1.0): # vert_margin=0.0
|
72 |
json_data = deepcopy(self.json_data)
|