import argparse
import os
import sys

import numpy as np

# NOTE(review): imported for its side effects — presumably it extends sys.path
# so the project-local imports below resolve. Keep it before them. TODO confirm.
import project_subpath  # noqa: F401
from frames_to_tracks import main as infer
from backend.InferenceConfig import InferenceConfig

# Make the sibling `caltech-fish-counting` checkout importable so that its
# `evaluate` module can be loaded below (must run before that import).
current_dir = os.path.dirname(os.path.realpath(__file__))
pardir = os.path.dirname(current_dir)
sys.path.append(os.path.join(pardir, "../caltech-fish-counting/"))
from evaluate import evaluate  # noqa: E402


class Object(object):
    """Bare attribute container used to build the argument object for `infer`."""

    pass


def main(args):
    """Run inference on a directory of frames, save the tracks, and evaluate them.

    Tracks are produced by ``frames_to_tracks.main`` and then scored with the
    ``evaluate`` function from the 'caltech-fish-counting' repository. The
    HOTA, MOTA, IDF1, nMAE and misscount summaries are printed to stdout.
    All ``../frames/...`` paths are resolved relative to the current working
    directory, not relative to this file.

    Args:
        args: namespace (e.g. from ``argument_parser().parse_args()``) with:
            weights (str): path to weights; if missing/empty, the historical
                default ``models/v5m_896_300best.pt`` is used
            conf_threshold (float): confidence cutoff for detection filtering
            nms_iou (float): non-maximum suppression IOU threshold
            min_length (float): minimum length of fish in meters in order to count
            max_length (float): maximum length of fish in meters in order to
                count. Disable with 0
            min_travel (float): minimum travel distance in meters of track in
                order to count
            max_age (int): maximum time between detections before a fish is
                forgotten by the tracker
            min_hits (int): minimum length of track in frames in order to count
            associativity (str): string representation of tracking method with
                corresponding hyperparameters separated by ':'

    Returns:
        The raw result object returned by ``evaluate``.
    """
    metadata_dir = "../frames/metadata"
    result_dir = "../frames/result"

    infer_args = Object()
    infer_args.metadata = metadata_dir
    infer_args.frames = "../frames/images"
    infer_args.location = "kenai-val"
    infer_args.output = result_dir
    # Honour --weights (it used to be parsed but silently ignored); fall back
    # to the previously hard-coded checkpoint when the caller gives none.
    infer_args.weights = getattr(args, "weights", None) or "models/v5m_896_300best.pt"

    # Casts are kept so programmatic callers may still pass strings.
    config = InferenceConfig(
        conf_thresh=float(args.conf_threshold),
        nms_iou=float(args.nms_iou),
        min_hits=int(args.min_hits),
        max_age=int(args.max_age),
        min_length=float(args.min_length),
        max_length=float(args.max_length),
        min_travel=float(args.min_travel),
    )
    config.enable_tracker_from_string(args.associativity)

    infer(infer_args, config=config, verbose=False)

    result = evaluate(result_dir, "../frames/MOT", metadata_dir, "tracker", True,
                      location=infer_args.location)
    metrics = result['MotChallenge2DBox']['tracker']['COMBINED_SEQ']['pedestrian']

    # HOTA is averaged before scaling; the other metrics are read as-is.
    print('HOTA:', np.mean(metrics['HOTA']['HOTA']) * 100)
    print('MOTA:', metrics['CLEAR']['MOTA'] * 100)
    print('IDF1:', metrics['Identity']['IDF1'] * 100)
    print('nMAE:', metrics['nMAE']['nMAE'] * 100)
    print('misscounts:', f"{metrics['nMAE']['nMAE_numer']}/{metrics['nMAE']['nMAE_denom']}")

    return result


def argument_parser():
    """Build the CLI parser; defaults are taken from a fresh ``InferenceConfig``."""
    default = InferenceConfig()
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, default=default.weights,
                        help="Path to weights")
    parser.add_argument("--conf_threshold", type=float, default=default.conf_thresh,
                        help="Confidence cutoff for detection filtering")
    parser.add_argument("--nms_iou", type=float, default=default.nms_iou,
                        help="Non-maximum Suppression IOU threshold")
    parser.add_argument("--min_length", type=float, default=default.min_length,
                        help="Minimum length of fish in meters in order to count")
    parser.add_argument("--max_length", type=float, default=default.max_length,
                        help="Maximum length of fish in meters in order to count. Disable with 0")
    parser.add_argument("--min_travel", type=float, default=default.min_travel,
                        help="Minimum travel distance in meters of track in order to count.")
    parser.add_argument("--max_age", type=int, default=default.max_age,
                        help="Maximum time between detections before a fish is forgotten by the tracker")
    parser.add_argument("--min_hits", type=int, default=default.min_hits,
                        help="Minimum length of track in frames in order to count")
    parser.add_argument("--associativity", type=str, default='',
                        help="String representation of tracking method with corresponding hyperparameters separated by ':'")
    return parser


if __name__ == "__main__":
    args = argument_parser().parse_args()
    main(args)