from enum import Enum


class TrackerType(Enum):
    """Identifies which associative tracker an ``InferenceConfig`` selects."""

    NONE = 0
    CONF_BOOST = 1
    BYTETRACK = 2

    def toString(val):
        """Return the display name of a tracker type.

        The single parameter doubles as ``self``, so both
        ``TrackerType.toString(member)`` and ``member.toString()`` work.
        Returns ``None`` when ``val`` is not one of the known members.
        """
        if val == TrackerType.NONE:
            return "None"
        if val == TrackerType.CONF_BOOST:
            return "Confidence Boost"
        if val == TrackerType.BYTETRACK:
            return "ByteTrack"
        return None


### Configuration options
WEIGHTS = 'models/v5m_896_300best.pt'
# will need to configure these based on GPU hardware
BATCH_SIZE = 32

CONF_THRES = 0.05  # detection
NMS_IOU = 0.25  # NMS IOU
MAX_AGE = 20  # time until missing fish gets new id
MIN_HITS = 11  # minimum number of frames with a specific fish for it to count
MIN_LENGTH = 0.3  # minimum fish length, in meters
MAX_LENGTH = 0  # maximum fish length, in meters
IOU_THRES = 0.01  # IOU threshold for tracking
MIN_TRAVEL = 0  # Minimum distance a track has to travel
DEFAULT_TRACKER = TrackerType.BYTETRACK


class InferenceConfig:
    """Bundle of detection / tracking parameters for one inference run.

    Every constructor argument defaults to the module-level constant of the
    matching name. The associative tracker starts as ``DEFAULT_TRACKER`` and
    is switched with the ``enable_*`` methods.
    """

    def __init__(self, weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
                 min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH,
                 max_length=MAX_LENGTH, min_travel=MIN_TRAVEL):
        self.weights = weights
        self.conf_thresh = conf_thresh
        self.nms_iou = nms_iou
        self.min_hits = min_hits
        self.max_age = max_age
        self.min_length = min_length
        self.max_length = max_length
        self.min_travel = min_travel

        # Tracker selection plus the per-tracker tunables; the values below
        # mirror the defaults of enable_conf_boost / enable_byte_track.
        self.associative_tracker = DEFAULT_TRACKER
        self.boost_power = 2
        self.boost_decay = 0.1
        self.byte_low_conf = 0.1
        self.byte_high_conf = 0.3

    def enable_sort_track(self):
        """Select ``TrackerType.NONE`` (no associative tracker)."""
        self.associative_tracker = TrackerType.NONE

    def enable_conf_boost(self, power=2, decay=0.1):
        """Select the confidence-boost tracker and store its parameters."""
        self.associative_tracker = TrackerType.CONF_BOOST
        self.boost_power = power
        self.boost_decay = decay

    def enable_byte_track(self, low=0.1, high=0.3):
        """Select ByteTrack and store its low/high confidence values."""
        self.associative_tracker = TrackerType.BYTETRACK
        self.byte_low_conf = low
        self.byte_high_conf = high

    @staticmethod
    def _parse_float_pair(spec):
        """Parse ``"<name>:<a>:<b>"`` into ``(a, b)`` floats.

        Returns ``None`` when there are not exactly three colon-separated
        fields or when either value is not a valid float.
        """
        fields = spec.split(":")
        if len(fields) != 3:
            return None
        try:
            return float(fields[1]), float(fields[2])
        except ValueError:
            return None

    def enable_tracker_from_string(self, associativity):
        """Configure the tracker from a compact string description.

        Accepted forms:
            ``""``                      -> no associative tracker
            ``"boost:<power>:<decay>"`` -> confidence boost
            ``"bytetrack:<low>:<high>"``-> ByteTrack

        Returns:
            ``True`` if the string was understood and applied; ``False``
            (after printing a diagnostic) otherwise. On failure the current
            configuration is left untouched.
        """
        if associativity == "":
            self.enable_sort_track()
            return True

        if associativity.startswith("boost"):
            values = self._parse_float_pair(associativity)
            if values is None:
                # Covers both a wrong field count and non-numeric values, so
                # bad input is reported instead of raising ValueError.
                print("INVALID PARAMETERS FOR CONFIDENCE BOOST:", associativity)
                return False
            power, decay = values
            self.enable_conf_boost(power=power, decay=decay)
            return True

        if associativity.startswith("bytetrack"):
            values = self._parse_float_pair(associativity)
            if values is None:
                print("INVALID PARAMETERS FOR BYTETRACK:", associativity)
                return False
            low, high = values
            self.enable_byte_track(low=low, high=high)
            return True

        print("INVALID ASSOCIATIVITY TYPE:", associativity)
        return False

    def find_model(self, model_list):
        """Return the key in ``model_list`` whose value equals ``self.weights``.

        ``model_list`` is a mapping of model name -> weights path. Each
        candidate is printed while searching. Returns ``None`` (after
        printing "not found") when no entry matches.
        """
        print("weights", self.weights)
        for model_name, model_path in model_list.items():
            print("Path", model_path, "->", model_name)
            if model_path == self.weights:
                return model_name
        print("not found")
        return None

    def to_dict(self):
        """Return the configuration as a plain dictionary.

        The common parameters are always present; tracker-specific keys are
        appended depending on ``self.associative_tracker``.
        """
        # NOTE(review): max_length is not exported here, and conf_thresh is
        # omitted for ByteTrack -- confirm both are intentional.
        config = {
            'weights': self.weights,
            'nms_iou': self.nms_iou,
            'min_hits': self.min_hits,
            'max_age': self.max_age,
            'min_length': self.min_length,
            'min_travel': self.min_travel,
        }

        # Add tracker specific parameters
        if self.associative_tracker == TrackerType.BYTETRACK:
            config['tracker'] = "ByteTrack"
            config['byte_low_conf'] = self.byte_low_conf
            config['byte_high_conf'] = self.byte_high_conf
        elif self.associative_tracker == TrackerType.CONF_BOOST:
            config['tracker'] = "Confidence Boost"
            config['conf_thresh'] = self.conf_thresh
            config['boost_power'] = self.boost_power
            config['boost_decay'] = self.boost_decay
        elif self.associative_tracker == TrackerType.NONE:
            config['tracker'] = "None"
            config['conf_thresh'] = self.conf_thresh

        return config