# File-view header captured with this source (kept as a comment so the file parses):
# author: oskarastrom
# commit: 73ba285 — "min_travel for inference"
# size: 3.2 kB
import project_path
import argparse
from track_detection import main as track
import sys
import numpy as np
sys.path.append('..')
sys.path.append('../caltech-fish-counting')
from evaluate import evaluate
class Object:
    """Bare attribute container.

    ``main`` instantiates it and assigns ``detections``, ``metadata``,
    ``output`` and ``tracker`` on the instance to build the argument object
    handed to ``track``.
    """
def _parse_associativity(spec, config):
    """Fill the associativity entries of ``config`` from a CLI spec string.

    Accepted forms are ``"boost[:power[:decay]]"`` and
    ``"bytetrack[:low_conf[:high_conf]]"``. The optional numeric fields are
    stored as floats under ``boost_power``/``boost_decay`` or
    ``low_conf_threshold``/``high_conf_threshold``. Extra fields are ignored,
    and empty fields (e.g. ``"boost::0.9"``) are skipped so a later field
    can be set without the earlier one.

    Returns True when ``spec`` names a known associativity type (and
    ``config`` was updated), False otherwise (``config`` left untouched).
    Raises ValueError if a non-empty numeric field is not a valid float.
    """
    # Checked in this order; matching is by prefix, as in the original parser.
    variants = {
        "boost": ("boost_power", "boost_decay"),
        "bytetrack": ("low_conf_threshold", "high_conf_threshold"),
    }
    for name, keys in variants.items():
        if spec.startswith(name):
            config['associativity'] = name
            fields = spec.split(":")[1:]
            for key, field in zip(keys, fields):
                if field:
                    config[key] = float(field)
            return True
    return False


def _print_metrics(result, tracker_name='tracker'):
    """Print the combined-sequence metrics for ``tracker_name`` from ``result``.

    ``result`` is the value returned by ``evaluate``; metrics are read from
    ``result['MotChallenge2DBox'][tracker_name]['COMBINED_SEQ']['pedestrian']``.
    HOTA, MOTA, IDF1 and nMAE are printed scaled by 100, followed by the raw
    miss-count ratio (nMAE numerator/denominator).
    """
    metrics = result['MotChallenge2DBox'][tracker_name]['COMBINED_SEQ']['pedestrian']
    # HOTA is stored as a collection and averaged here; presumably one value
    # per threshold — confirm against evaluate().
    print('HOTA:', np.mean(metrics['HOTA']['HOTA']) * 100)
    print('MOTA:', metrics['CLEAR']['MOTA'] * 100)
    print('IDF1:', metrics['Identity']['IDF1'] * 100)
    print('nMAE:', metrics['nMAE']['nMAE'] * 100)
    print('misscounts:', f"{metrics['nMAE']['nMAE_numer']}/{metrics['nMAE']['nMAE_denom']}")


def main(args):
    """Run the tracker on stored detections and print MOT evaluation metrics.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments (see ``argument_parser``). Numeric options may be
        strings or numbers; they are converted here. ``associativity`` may be
        empty/None (no associativity) or a spec accepted by
        ``_parse_associativity``.

    Returns
    -------
    The object returned by ``evaluate``, or None when ``args.associativity``
    names an unknown type (in which case nothing is tracked or evaluated).
    """
    config = {
        'conf_threshold': float(args.conf_threshold),
        'nms_iou': float(args.nms_iou),
        'min_length': float(args.min_length),
        'min_travel': float(args.min_travel),
        'max_age': int(args.max_age),
        'iou_threshold': float(args.iou_threshold),
        'min_hits': int(args.min_hits),
        'associativity': None,
    }
    # Validate before doing any work; a falsy spec ('' or None) means "none".
    if args.associativity and not _parse_associativity(args.associativity, config):
        print("INVALID ASSOCIATIVITY TYPE:", args.associativity, file=sys.stderr)
        return None

    # Attribute bag consumed by track(); evaluate() reads the same paths below.
    infer_args = Object()
    infer_args.detections = args.detection_dir
    infer_args.metadata = "../frames/metadata"
    infer_args.output = "../frames/result_testing"
    infer_args.tracker = 'tracker'

    print("verbose", args.verbose)
    track(infer_args, config=config, verbose=args.verbose)
    # NOTE(review): the trailing positional True is interpreted by evaluate();
    # its meaning is not visible here — confirm.
    result = evaluate(infer_args.output, "../frames/MOT", infer_args.metadata,
                      infer_args.tracker, True)
    _print_metrics(result, infer_args.tracker)
    return result
def argument_parser():
    """Build the command-line parser for this tracking-evaluation script.

    Every option has a default, so none is required. Numeric options carry
    ``type=`` converters so malformed values are rejected at parse time;
    ``main`` still casts them, so programmatic callers may pass strings.

    Returns
    -------
    argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(
        description="Run the tracker on stored detections and evaluate the result.")
    parser.add_argument("--detection_dir", default="../frames/detection_storage",
                        help="Location of stored detections passed to the tracker "
                             "(default: %(default)s).")
    parser.add_argument("--conf_threshold", type=float, default=0.3,
                        help="Tracker config 'conf_threshold' (default: %(default)s).")
    parser.add_argument("--nms_iou", type=float, default=0.3,
                        help="Tracker config 'nms_iou' (default: %(default)s).")
    parser.add_argument("--min_length", type=float, default=0.3,
                        help="Tracker config 'min_length' (default: %(default)s).")
    parser.add_argument("--min_travel", type=float, default=0,
                        help="Tracker config 'min_travel' (default: %(default)s).")
    parser.add_argument("--max_age", type=int, default=20,
                        help="Tracker config 'max_age' (default: %(default)s).")
    parser.add_argument("--iou_threshold", type=float, default=0.01,
                        help="Tracker config 'iou_threshold' (default: %(default)s).")
    parser.add_argument("--min_hits", type=int, default=11,
                        help="Tracker config 'min_hits' (default: %(default)s).")
    parser.add_argument("--associativity", default='',
                        help="Associativity mode: empty for none, "
                             "'boost[:power[:decay]]' or "
                             "'bytetrack[:low_conf[:high_conf]]' (default: none).")
    parser.add_argument("--verbose", action='store_true',
                        help="Pass verbose=True to the tracker.")
    return parser
if __name__ == "__main__":
    # main() returns None when it rejects its arguments (unknown associativity);
    # exit non-zero so shell scripts and parameter sweeps can detect the failure.
    if main(argument_parser().parse_args()) is None:
        sys.exit(1)