File size: 3,187 Bytes
2a572c2
 
a63e231
 
 
 
 
c376f3c
 
 
 
 
2a572c2
 
 
 
 
 
a63e231
 
 
2a572c2
 
c376f3c
a63e231
2a572c2
 
 
 
 
 
 
 
 
 
 
 
 
a63e231
 
 
 
 
 
 
2a572c2
 
 
 
 
 
 
 
 
 
 
 
c376f3c
cf87b0c
c376f3c
cf87b0c
c376f3c
 
cf87b0c
c376f3c
 
 
2a572c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from enum import Enum

class TrackerType(Enum):
    """Associative tracker options selectable in an inference configuration."""

    NONE = 0
    CONF_BOOST = 1
    BYTETRACK = 2

    # Intentionally declared without ``self`` or ``@staticmethod``: this lets
    # callers write either ``TrackerType.toString(member)`` or
    # ``member.toString()``.
    def toString(val):
        """Return the human-readable label for a TrackerType member.

        Members are checked with ``==`` in declaration order. Any value that is
        not one of the members yields ``None``.
        """
        labels = (
            (TrackerType.NONE, "None"),
            (TrackerType.CONF_BOOST, "Confidence Boost"),
            (TrackerType.BYTETRACK, "ByteTrack"),
        )
        for member, label in labels:
            if val == member:
                return label
        return None

### Configuration options

# Default model weights file.
WEIGHTS: str = 'models/v5m_896_300best.pt'

# Batch size; will need to be configured based on the GPU hardware.
BATCH_SIZE: int = 32

# Detection confidence threshold.
CONF_THRES: float = 0.05
# IoU threshold used for non-maximum suppression (NMS).
NMS_IOU: float = 0.25
# Time until a missing fish gets a new id (units not stated here;
# presumably frames — TODO confirm).
MAX_AGE: int = 20
# Minimum number of frames with a specific fish for it to count.
MIN_HITS: int = 11
# Minimum fish length, in meters.
MIN_LENGTH: float = 0.3
# IoU threshold for tracking.
IOU_THRES: float = 0.01
# Minimum distance a track has to travel (units not stated here).
MIN_TRAVEL: int = 0
# Associative tracker selected when none is chosen explicitly.
DEFAULT_TRACKER: TrackerType = TrackerType.BYTETRACK

class InferenceConfig:
    """Bundle of detection and tracking parameters for an inference run.

    Constructor defaults come from the module-level configuration constants.
    The associative tracker starts as ``DEFAULT_TRACKER`` and can be switched
    with the ``enable_*`` methods, which also record that tracker's parameters.
    """

    def __init__(self,
                 weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
                 min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL):
        self.weights = weights
        self.conf_thresh = conf_thresh
        self.nms_iou = nms_iou
        self.min_hits = min_hits
        self.max_age = max_age
        self.min_length = min_length
        self.min_travel = min_travel

        # Parameters for every tracker type are always stored; to_dict() only
        # emits the ones belonging to the currently selected tracker.
        self.associative_tracker = DEFAULT_TRACKER
        self.boost_power = 2
        self.boost_decay = 0.1
        self.byte_low_conf = 0.1
        self.byte_high_conf = 0.3

    def enable_sort_track(self):
        """Select no associative tracker (``TrackerType.NONE``)."""
        self.associative_tracker = TrackerType.NONE

    def enable_conf_boost(self, power, decay):
        """Select the Confidence Boost tracker and store its power/decay."""
        self.associative_tracker = TrackerType.CONF_BOOST
        self.boost_power = power
        self.boost_decay = decay

    def enable_byte_track(self, low, high):
        """Select ByteTrack and store its low/high confidence thresholds."""
        self.associative_tracker = TrackerType.BYTETRACK
        self.byte_low_conf = low
        self.byte_high_conf = high

    def find_model(self, model_list):
        """Return the position of ``self.weights`` within ``model_list``.

        Walks ``model_list`` in iteration order, printing each entry with its
        index. Returns the index of the first entry equal to ``self.weights``;
        if nothing matches, prints "not found" and returns ``None``.
        """
        print("weights", self.weights)
        # enumerate() yields (position, entry); the position is what callers get.
        for index, model_path in enumerate(model_list):
            print("Path", model_path, "->", index)
            if model_path == self.weights:
                return index
        print("not found")
        return None

    def to_dict(self):
        """Return the configuration as a plain dict.

        Always includes the shared detection/tracking parameters. Adds a
        'tracker' label plus only the parameters relevant to the selected
        associative tracker. If the selected tracker is not a known
        ``TrackerType``, no tracker keys are added.
        """
        params = {
            'weights': self.weights,
            'nms_iou': self.nms_iou,
            'min_hits': self.min_hits,
            'max_age': self.max_age,
            'min_length': self.min_length,
            'min_travel': self.min_travel,
        }

        # Add tracker-specific parameters; labels come from TrackerType so they
        # are defined in a single place.
        tracker = self.associative_tracker
        if tracker == TrackerType.BYTETRACK:
            # NOTE(review): conf_thresh is not emitted for ByteTrack; presumably
            # the low/high thresholds supersede it — confirm.
            params['tracker'] = TrackerType.toString(tracker)
            params['byte_low_conf'] = self.byte_low_conf
            params['byte_high_conf'] = self.byte_high_conf
        elif tracker == TrackerType.CONF_BOOST:
            params['tracker'] = TrackerType.toString(tracker)
            params['conf_thresh'] = self.conf_thresh
            params['boost_power'] = self.boost_power
            params['boost_decay'] = self.boost_decay
        elif tracker == TrackerType.NONE:
            params['tracker'] = TrackerType.toString(tracker)
            params['conf_thresh'] = self.conf_thresh

        return params