oskarastrom commited on
Commit
bb2dfaa
·
1 Parent(s): 5b63380

Update track_detection.py

Browse files
Files changed (1) hide show
  1. scripts/track_detection.py +8 -5
scripts/track_detection.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
  import torch
7
  import os
8
  from dataloader import create_dataloader_frames_only
9
- from inference import setup_model, do_detection, do_suppression, do_confidence_boost, format_predictions, do_tracking
10
  from visualizer import generate_video_batches
11
  import json
12
  from tqdm import tqdm
@@ -102,6 +102,13 @@ def track(in_loc_dir, out_loc_dir, metadata_path, seq, config, verbose):
102
 
103
  outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)
104
 
 
 
 
 
 
 
 
105
  if config['use_associative']:
106
 
107
  do_confidence_boost(inference, outputs, conf_power=config['boost_power'], conf_decay=config['boost_decay'], verbose=verbose)
@@ -110,10 +117,6 @@ def track(in_loc_dir, out_loc_dir, metadata_path, seq, config, verbose):
110
 
111
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, verbose=verbose)
112
 
113
- print(len(all_preds))
114
- print(all_preds[0][0])
115
-
116
- results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'], verbose=verbose)
117
 
118
  mot_rows = []
119
  for frame in results['frames']:
 
6
  import torch
7
  import os
8
  from dataloader import create_dataloader_frames_only
9
+ from inference import setup_model, do_detection, do_suppression, do_confidence_boost, format_predictions, do_tracking, do_associative_tracking
10
  from visualizer import generate_video_batches
11
  import json
12
  from tqdm import tqdm
 
102
 
103
  outputs = do_suppression(inference, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)
104
 
105
+ print(len(outputs))
106
+ print(len(outputs[0]))
107
+ print(outputs[0][0])
108
+
109
+ results = do_associative_tracking(outputs, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'], verbose=verbose)
110
+
111
+
112
  if config['use_associative']:
113
 
114
  do_confidence_boost(inference, outputs, conf_power=config['boost_power'], conf_decay=config['boost_decay'], verbose=verbose)
 
117
 
118
  all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, verbose=verbose)
119
 
 
 
 
 
120
 
121
  mot_rows = []
122
  for frame in results['frames']: