from typing import List

from transformers import PretrainedConfig


class GRAMTBinauralTimeConfig(PretrainedConfig):
    """Configuration for the GRAMT binaural time model."""

    model_type = "gramt-binaural-time"
    model_size = "base"
    in_channels: int = 2
    patch_size = (128, 2)
    frequency_stride = 128
    time_stride = 2

    def __init__(
        self,
        decoder_mlp_ratio: float = 4.0,
        decoder_depth: int = 8,
        decoder_num_heads: int = 8,
        decoder_embedding_dim: int = 512,
        decoder_window_sizes: List[int] = [2, 5, 10, 25, 50, 0, 0, 0],
        encoder_num_layers: int = 12,
        encoder_num_heads: int = 12,
        encoder_hidden_dim: int = 768,
        encoder_mlp_ratio: float = 4.0,
        encoder_dropout: float = 0.0,
        encoder_attention_dropout: float = 0.0,
        encoder_norm_layer_eps: float = 1e-6,
        input_length: int = 200,
        num_mel_bins: int = 128,
        **kwargs,
    ):
        # Decoder hyperparameters.
        self.decoder_mlp_ratio = decoder_mlp_ratio
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.decoder_embedding_dim = decoder_embedding_dim
        self.decoder_window_sizes = decoder_window_sizes
        # Encoder hyperparameters.
        self.encoder_num_layers = encoder_num_layers
        self.encoder_num_heads = encoder_num_heads
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_mlp_ratio = encoder_mlp_ratio
        self.encoder_dropout = encoder_dropout
        self.encoder_attention_dropout = encoder_attention_dropout
        self.encoder_norm_layer_eps = encoder_norm_layer_eps
        # Input spectrogram geometry.
        self.input_length = input_length
        self.num_mel_bins = num_mel_bins
        super().__init__(**kwargs)
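

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of constructing the config, overriding a couple of fields,
# and round-tripping it through the dict serialization inherited from
# PretrainedConfig (to_dict / from_dict). Guarded so it only runs when this
# module is executed directly.
if __name__ == "__main__":
    # Default configuration.
    config = GRAMTBinauralTimeConfig()
    print(config.model_type)          # "gramt-binaural-time"
    print(config.encoder_hidden_dim)  # 768

    # Override selected hyperparameters at construction time.
    small = GRAMTBinauralTimeConfig(decoder_depth=4, decoder_embedding_dim=256)

    # Round-trip through the serialization helpers from PretrainedConfig.
    restored = GRAMTBinauralTimeConfig.from_dict(small.to_dict())
    assert restored.decoder_depth == 4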