File size: 1,019 Bytes
1a29f83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from transformers import PretrainedConfig


class TotalClassifierConfig(PretrainedConfig):
    """Configuration for the ``total_classifier`` model.

    Stores the hyperparameters for a CNN backbone followed by a recurrent
    layer and a linear head. The model class itself is not shown here, so the
    meanings below are inferred from names and should be confirmed against it.

    Args:
        backbone: Name of the CNN backbone. Presumably a ``timm`` model
            identifier (default ``"tf_efficientnetv2_b0"``) — confirm.
        feature_dim: Feature width produced after the backbone. Presumably
            the per-step input size of the RNN — confirm.
        cnn_dropout: Dropout probability applied on the CNN side.
        in_chans: Number of input image channels (default 1).
        rnn_type: Recurrent layer type, e.g. ``"GRU"``. Accepted values are
            decided by the model code, not validated here.
        rnn_num_layers: Number of stacked recurrent layers.
        rnn_dropout: Dropout probability for the recurrent layer.
        num_classes: Number of output classes (default 117).
        seq_len: Sequence length the model is configured for.
        linear_dropout: Dropout probability before the linear head.
        image_size: Input image size, ``(256, 256)`` by default. A list is
            converted to a tuple (see the note in the body).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "total_classifier"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_b0",
        feature_dim: int = 192,
        cnn_dropout: float = 0.1,
        in_chans: int = 1,
        rnn_type: str = "GRU",
        rnn_num_layers: int = 1,
        rnn_dropout: float = 0.0,
        num_classes: int = 117,
        seq_len: int = 512,
        linear_dropout: float = 0.1,
        image_size: tuple[int, int] = (256, 256),
        **kwargs,
    ):
        # Backbone / CNN settings.
        self.backbone = backbone
        self.feature_dim = feature_dim
        self.cnn_dropout = cnn_dropout
        self.in_chans = in_chans

        # Recurrent-layer settings.
        self.rnn_type = rnn_type
        self.rnn_num_layers = rnn_num_layers
        self.rnn_dropout = rnn_dropout

        # Head / data settings.
        self.num_classes = num_classes
        self.seq_len = seq_len
        self.linear_dropout = linear_dropout

        # JSON has no tuple type, so a saved config reloads ``image_size`` as
        # a list. Converting it back keeps fresh and reloaded configs
        # consistent. Other types (e.g. an int) are stored unchanged so
        # existing callers are unaffected.
        if isinstance(image_size, list):
            image_size = tuple(image_size)
        self.image_size = image_size

        super().__init__(**kwargs)