from transformers import PretrainedConfig
from typing import List


class GRAMTBinauralTimeConfig(PretrainedConfig):
    """Configuration for the GRAMT binaural time-domain model.

    Stores the encoder/decoder hyperparameters and the fixed spectrogram
    patching setup: a 2-channel (binaural) input split into patches of
    128 frequency bins x 2 time frames.
    """

    model_type = "gramt-binaural-time"
    model_size = "base"
    in_channels: int = 2
    patch_size = (128, 2)
    frequency_stride = 128
    time_stride = 2

    def __init__(
        self,
        decoder_mlp_ratio: float = 4.0,
        decoder_depth: int = 8,
        decoder_num_heads: int = 8,
        decoder_embedding_dim: int = 512,
        decoder_window_sizes: List[int] = [2, 5, 10, 25, 50, 0, 0, 0],
        encoder_num_layers: int = 12,
        encoder_num_heads: int = 12,
        encoder_hidden_dim: int = 768,
        encoder_mlp_ratio: float = 4.0,
        encoder_dropout: float = 0.0,
        encoder_attention_dropout: float = 0.0,
        encoder_norm_layer_eps: float = 1e-6,
        input_length: int = 200,
        num_mel_bins: int = 128,
        **kwargs,
    ):
        # Decoder hyperparameters
        self.decoder_mlp_ratio = decoder_mlp_ratio
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.decoder_embedding_dim = decoder_embedding_dim
        self.decoder_window_sizes = decoder_window_sizes

        # Encoder hyperparameters
        self.encoder_num_layers = encoder_num_layers
        self.encoder_num_heads = encoder_num_heads
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_mlp_ratio = encoder_mlp_ratio
        self.encoder_dropout = encoder_dropout
        self.encoder_attention_dropout = encoder_attention_dropout
        self.encoder_norm_layer_eps = encoder_norm_layer_eps

        # Input spectrogram shape (time frames x mel bins)
        self.input_length = input_length
        self.num_mel_bins = num_mel_bins
        super().__init__(**kwargs)
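

# Minimal usage sketch (illustrative only, not part of the model code itself):
# build the config with its defaults and with an override, then inspect it via
# the standard PretrainedConfig helpers.
if __name__ == "__main__":
    config = GRAMTBinauralTimeConfig()
    print(config.model_type, config.encoder_hidden_dim)  # gramt-binaural-time 768

    # Override a decoder hyperparameter; remaining kwargs flow into PretrainedConfig.
    small_config = GRAMTBinauralTimeConfig(decoder_depth=4)
    print(small_config.to_dict()["decoder_depth"])  # 4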