File size: 2,534 Bytes
2a572c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from enum import Enum

### Configuration options
# Default model weights file.
WEIGHTS: str = 'models/v5m_896_300best.pt'
# Inference batch size -- needs to be configured based on the GPU hardware.
BATCH_SIZE: int = 32

# Detection confidence threshold.
CONF_THRES: float = 0.05
# IoU threshold used by non-maximum suppression (NMS).
NMS_IOU: float = 0.2
# How long a missing fish keeps its id before it gets a new one
# (presumably measured in frames -- confirm against the tracker).
MAX_AGE: int = 14
# Minimum number of frames a specific fish must appear in for it to count.
MIN_HITS: int = 16
# Minimum fish length, in meters.
MIN_LENGTH: float = 0.3
# IoU threshold used for tracking.
IOU_THRES: float = 0.01
# Minimum distance a track has to travel
# (NOTE(review): a negative value presumably disables this filter -- confirm).
MIN_TRAVEL: int = -1

class TrackerType(Enum):
    """Which associative tracker an InferenceConfig has selected."""

    # No associative tracker (the InferenceConfig default).
    NONE = 0
    # "Confidence Boost" tracker, selected via InferenceConfig.enable_conf_boost.
    CONF_BOOST = 1
    # ByteTrack tracker, selected via InferenceConfig.enable_byte_track.
    BYTETRACK = 2

class InferenceConfig:
    """Detection and tracking parameters for an inference run.

    Defaults come from the module-level configuration constants. No
    associative tracker is selected initially (TrackerType.NONE); calling
    enable_conf_boost() or enable_byte_track() selects one, and the most
    recent call wins.
    """

    def __init__(self,
                 weights: str = WEIGHTS,
                 conf_thresh: float = CONF_THRES,
                 nms_iou: float = NMS_IOU,
                 min_hits: int = MIN_HITS,
                 max_age: int = MAX_AGE,
                 min_length: float = MIN_LENGTH,
                 min_travel: float = MIN_TRAVEL) -> None:
        """Store the detection/tracking parameters.

        Args:
            weights: Path to the model weights file.
            conf_thresh: Detection confidence threshold.
            nms_iou: IoU threshold for non-maximum suppression.
            min_hits: Minimum number of frames a specific fish must appear
                in for it to count.
            max_age: How long a missing fish keeps its id before getting a
                new one (presumably in frames -- confirm).
            min_length: Minimum fish length, in meters.
            min_travel: Minimum distance a track has to travel.
        """
        self.weights = weights
        self.conf_thresh = conf_thresh
        self.nms_iou = nms_iou
        self.min_hits = min_hits
        self.max_age = max_age
        self.min_length = min_length
        self.min_travel = min_travel

        # Tracker selection. The tracker-specific values below start at 1 and
        # are only reported by to_dict() once the matching enable_* method
        # has selected that tracker (and overwritten them).
        self.associative_tracker = TrackerType.NONE
        self.boost_power = 1
        self.boost_decay = 1
        self.byte_low_conf = 1
        self.byte_high_conf = 1

    def enable_conf_boost(self, power, decay) -> None:
        """Select the Confidence Boost tracker and store its parameters.

        Args:
            power: Stored as ``boost_power``.
            decay: Stored as ``boost_decay``.
        """
        self.associative_tracker = TrackerType.CONF_BOOST
        self.boost_power = power
        self.boost_decay = decay

    def enable_byte_track(self, low, high) -> None:
        """Select the ByteTrack tracker and store its confidence parameters.

        Args:
            low: Stored as ``byte_low_conf``.
            high: Stored as ``byte_high_conf``.
        """
        self.associative_tracker = TrackerType.BYTETRACK
        self.byte_low_conf = low
        self.byte_high_conf = high

    def to_dict(self) -> dict:
        """Return the configuration as a plain dictionary.

        The result always contains 'weights', 'nms_iou', 'min_hits',
        'max_age', 'min_length' and 'min_travel', followed by a
        human-readable 'tracker' name and the parameters of the selected
        tracker. 'conf_thresh' is included for the NONE and CONF_BOOST
        trackers but not for ByteTrack (presumably because ByteTrack uses
        byte_low_conf/byte_high_conf instead -- confirm). If
        associative_tracker holds anything other than a TrackerType member,
        only the common keys are returned.
        """
        # Named `config` rather than `dict` to avoid shadowing the builtin.
        config = {
            'weights': self.weights,
            'nms_iou': self.nms_iou,
            'min_hits': self.min_hits,
            'max_age': self.max_age,
            'min_length': self.min_length,
            'min_travel': self.min_travel,
        }

        # Add tracker-specific parameters (insertion order kept stable so any
        # serialized output is unchanged).
        tracker = self.associative_tracker
        if tracker is TrackerType.BYTETRACK:
            config['tracker'] = "ByteTrack"
            config['byte_low_conf'] = self.byte_low_conf
            config['byte_high_conf'] = self.byte_high_conf
        elif tracker is TrackerType.CONF_BOOST:
            config['tracker'] = "Confidence Boost"
            config['conf_thresh'] = self.conf_thresh
            config['boost_power'] = self.boost_power
            config['boost_decay'] = self.boost_decay
        elif tracker is TrackerType.NONE:
            config['tracker'] = "None"
            config['conf_thresh'] = self.conf_thresh

        return config