File size: 4,197 Bytes
3b4af99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# --------------------------------------------------------
# SenseTime
# Copyright (c) 2025 SenseTime
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import copy

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

class FlowConfig(PretrainedConfig):
    """Configuration object holding the hyper-parameters of the flow model.

    Top-level scalar settings are stored as attributes of the same name, and the
    two sub-modules are described by the nested ``encoder_config`` and
    ``decoder_config`` dictionaries. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    def __init__(
            self,
            input_size = 512,
            output_size= 80,
            spk_embed_dim = 192,
            output_type = 'mel',
            vocab_size = 6561,
            input_frame_rate = 25,
            only_mask_loss = True,
            token_mel_ratio=2,
            pre_lookahead_len=3,
            encoder_config={'output_size': 512,
                            'attention_heads': 8,
                            'linear_units': 2048,
                            'num_blocks': 6,
                            'dropout_rate': 0.1,
                            'positional_dropout_rate': 0.1,
                            'attention_dropout_rate': 0.1,
                            'normalize_before': True,
                            'input_layer': 'linear',
                            'pos_enc_layer_type': 'rel_pos_espnet',
                            'selfattention_layer_type': 'rel_selfattn',
                            'input_size': 512,
                            'use_cnn_module': False,
                            'macaron_style': False,
                            },
            decoder_config={'in_channels': 240,
                            'n_spks': 1,
                            'spk_emb_dim': 80,
                            'cfm_params': {
                                'sigma_min': 1e-06,
                                'solver': 'euler',
                                't_scheduler': 'cosine',
                                'training_cfg_rate': 0.2,
                                'inference_cfg_rate': 0.7,
                                'reg_loss_type': 'l1',
                                },
                            'estimator_config':{
                                'in_channels': 320,
                                'out_channels': 80,
                                'causal': True,
                                'channels': [256],
                                'dropout': 0.0,
                                'attention_head_dim': 64,
                                'n_blocks': 4,
                                'num_mid_blocks': 12,
                                'num_heads': 8,
                                'act_fn': 'gelu'
                                }
                            },
            **kwargs):
        """Store the flow-model hyper-parameters.

        Args:
            input_size, output_size, spk_embed_dim, output_type, vocab_size,
            input_frame_rate, only_mask_loss, token_mel_ratio, pre_lookahead_len:
                Scalar hyper-parameters, stored unchanged as attributes of the
                same name. Their meaning is defined by the model that consumes
                this config.
            encoder_config: Nested dict of encoder settings.
            decoder_config: Nested dict of decoder settings, including the
                ``cfm_params`` and ``estimator_config`` sub-dicts.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        # The default dict literals above are evaluated once and shared by every
        # call. Deep-copy them (and any caller-supplied dicts) so that mutating
        # one config's nested settings can never leak into other instances or
        # into the defaults.
        self.encoder_config = copy.deepcopy(encoder_config)
        self.decoder_config = copy.deepcopy(decoder_config)

        self.input_size = input_size
        self.output_size = output_size
        self.spk_embed_dim = spk_embed_dim
        self.output_type = output_type
        self.vocab_size = vocab_size
        self.input_frame_rate = input_frame_rate
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    def to_dict(self) -> dict:
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        The returned dictionary is a deep copy. Mutating it, including the nested
        ``encoder_config`` / ``decoder_config`` dicts, never affects this config.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        # NOTE(review): this bypasses PretrainedConfig.to_dict entirely, so any
        # extra bookkeeping the base implementation performs is skipped here.
        # Confirm that this is intended.
        output = copy.deepcopy(self.__dict__)
        # Re-read the model fields through attribute access, as the previous
        # implementation did. Deep-copy each one so that the nested dicts are
        # never aliased with the live config.
        for name in (
            'encoder_config',
            'decoder_config',
            'input_size',
            'output_size',
            'spk_embed_dim',
            'output_type',
            'vocab_size',
            'input_frame_rate',
            'only_mask_loss',
            'token_mel_ratio',
            'pre_lookahead_len',
        ):
            output[name] = copy.deepcopy(getattr(self, name))

        return output