File size: 2,854 Bytes
4ca828d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright 2025 LY Corporation
# ported from https://huggingface.co/line-corporation/clip-japanese-base/blob/main/configuration_clyp.py
from typing import Any, Dict, Optional, Union

from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from transformers import PretrainedConfig


class JapaneseMuLanConfig(PretrainedConfig):
    """Top-level Japanese MuLan configuration: a music encoder plus a text encoder.

    Args:
        music_encoder_config: Keyword arguments for
            ``JapaneseMuLanMusicEncoderConfig``, or an already-built instance
            of it (used as-is, not copied). ``None`` means all defaults.
        text_encoder_config: Keyword arguments for
            ``JapaneseMuLanTextEncoderConfig``, or an already-built instance
            of it (used as-is, not copied). ``None`` means all defaults.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # ``model_type`` is the class attribute PretrainedConfig uses to identify a
    # config: it is written into the serialized dict and must match the key
    # passed to ``AutoConfig.register``. The previous ``config_class`` attribute
    # is not read by PretrainedConfig; it is kept only so existing code that
    # reads it keeps working.
    model_type = "japanese-mulan"
    config_class = "japanese-mulan"

    def __init__(
        self,
        music_encoder_config: Optional[
            Union[Dict[str, Any], "JapaneseMuLanMusicEncoderConfig"]
        ] = None,
        text_encoder_config: Optional[
            Union[Dict[str, Any], "JapaneseMuLanTextEncoderConfig"]
        ] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Sub-configs arrive as plain dicts when deserialized, but callers may
        # also hand in config objects; ``**instance`` would raise TypeError, so
        # only non-instances are expanded into a new config.
        if not isinstance(music_encoder_config, JapaneseMuLanMusicEncoderConfig):
            music_encoder_config = JapaneseMuLanMusicEncoderConfig(
                **(music_encoder_config or {})
            )
        if not isinstance(text_encoder_config, JapaneseMuLanTextEncoderConfig):
            text_encoder_config = JapaneseMuLanTextEncoderConfig(
                **(text_encoder_config or {})
            )

        self.music_encoder_config = music_encoder_config
        self.text_encoder_config = text_encoder_config

    def to_diff_dict(self) -> Dict[str, Any]:
        """Serialize non-default attributes, with sub-configs as plain diff dicts.

        Starts from ``PretrainedConfig.to_diff_dict()`` and then overwrites the
        two sub-config entries with each sub-config's own diff dict (produced by
        ``_to_diff_dict``, which drops ``transformers_version``). Those entries
        are passed through ``dict_torch_dtype_to_str`` before merging.

        Returns:
            A dict suitable for JSON serialization of this config.
        """
        serializable_config_dict = super().to_diff_dict()
        sub_serializable_config_dict = {
            "music_encoder_config": _to_diff_dict(self.music_encoder_config),
            "text_encoder_config": _to_diff_dict(self.text_encoder_config),
        }
        self.dict_torch_dtype_to_str(sub_serializable_config_dict)
        serializable_config_dict.update(sub_serializable_config_dict)

        return serializable_config_dict


class JapaneseMuLanMusicEncoderConfig(PretrainedConfig):
    """Configuration of the music encoder.

    Args:
        model_name: Identifier of the backbone; defaults to ``"ast-base384"``.
        out_channels: Output channel count (presumably the embedding size).
        stride: Stride as a single int or as a pair. It is normalized with
            ``_pair``: an int ``s`` becomes ``(s, s)`` and an iterable becomes
            a tuple.
        n_bins: Number of bins of the input (presumably frequency bins —
            confirm against the encoder).
        n_frames: Number of input frames.
        n_pretrained_frames: Frame count associated with the pretrained
            backbone; when ``None`` it is taken to be ``n_frames``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "ast-base384",
        out_channels: int = 128,
        stride: _size_2_t = 10,
        n_bins: int = 128,
        n_frames: int = 1024,
        n_pretrained_frames: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Resolve derived/normalized values first, then store everything.
        stride_pair = _pair(stride)
        pretrained_frames = (
            n_frames if n_pretrained_frames is None else n_pretrained_frames
        )

        self.model_name = model_name
        self.out_channels = out_channels
        self.stride = stride_pair
        self.n_bins = n_bins
        self.n_frames = n_frames
        self.n_pretrained_frames = pretrained_frames


class JapaneseMuLanTextEncoderConfig(PretrainedConfig):
    """Configuration of the text encoder.

    Args:
        model_name: Identifier of the text backbone; defaults to
            ``"pkshatech/GLuCoSE-base-ja"`` (presumably a Hugging Face Hub id).
        out_channels: Output channel count (presumably the embedding size).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "pkshatech/GLuCoSE-base-ja",
        out_channels: int = 128,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Store both encoder settings as plain attributes so they serialize.
        self.model_name, self.out_channels = model_name, out_channels


def _to_diff_dict(c: PretrainedConfig) -> Dict[str, Any]:
    """Return ``c.to_diff_dict()`` with any ``transformers_version`` key removed.

    Used when serializing sub-configs in place of their own
    ``PretrainedConfig.to_diff_dict()`` output.

    Note:
        With transformers==4.38.1, ``PretrainedConfig.__repr__`` may be unable
        to render configs that contain sub-configs; this helper exists to work
        around that.
    """
    diff = c.to_diff_dict()
    # pop() with a default silently does nothing when the key is absent.
    diff.pop("transformers_version", None)
    return diff