|
|
|
|
|
|
|
|
from typing import Any, Dict, Optional, Union

from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from transformers import PretrainedConfig
|
|
|
|
|
|
|
|
class JapaneseMuLanConfig(PretrainedConfig):
    """Composite configuration for Japanese MuLan.

    Bundles a music-encoder configuration and a text-encoder configuration,
    stored as ``JapaneseMuLanMusicEncoderConfig`` and
    ``JapaneseMuLanTextEncoderConfig`` instances respectively.

    Args:
        music_encoder_config: Keyword arguments for
            ``JapaneseMuLanMusicEncoderConfig``, or an already-built instance
            of it. ``None`` means "use that class's defaults".
        text_encoder_config: Keyword arguments for
            ``JapaneseMuLanTextEncoderConfig``, or an already-built instance
            of it. ``None`` means "use that class's defaults".
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # transformers identifies a config class by its ``model_type`` class
    # attribute (it is serialized with the config and used by
    # ``AutoConfig.register``); the original code only set ``config_class``.
    model_type = "japanese-mulan"
    # Retained for backward compatibility with any code reading this
    # attribute; transformers does not consult ``config_class`` on configs.
    config_class = "japanese-mulan"

    def __init__(
        self,
        music_encoder_config: Optional[
            Union[Dict[str, Any], "JapaneseMuLanMusicEncoderConfig"]
        ] = None,
        text_encoder_config: Optional[
            Union[Dict[str, Any], "JapaneseMuLanTextEncoderConfig"]
        ] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        if music_encoder_config is None:
            music_encoder_config = {}

        if text_encoder_config is None:
            text_encoder_config = {}

        # Accept ready-made sub-config objects as well as plain kwargs dicts;
        # ``**instance`` would otherwise raise TypeError.
        if isinstance(music_encoder_config, JapaneseMuLanMusicEncoderConfig):
            self.music_encoder_config = music_encoder_config
        else:
            self.music_encoder_config = JapaneseMuLanMusicEncoderConfig(
                **music_encoder_config
            )

        if isinstance(text_encoder_config, JapaneseMuLanTextEncoderConfig):
            self.text_encoder_config = text_encoder_config
        else:
            self.text_encoder_config = JapaneseMuLanTextEncoderConfig(
                **text_encoder_config
            )

    def to_diff_dict(self) -> Dict[str, Any]:
        """Return the serializable diff of this config, with nested sub-configs.

        Starts from ``PretrainedConfig.to_diff_dict()`` and then replaces the
        two sub-config entries with their own diff dicts (produced by
        ``_to_diff_dict``, which strips ``transformers_version``). Any
        ``torch_dtype`` values inside those sub-dicts are converted to strings.
        """
        serializable_config_dict = super().to_diff_dict()

        sub_serializable_config_dict = {
            "music_encoder_config": _to_diff_dict(self.music_encoder_config),
            "text_encoder_config": _to_diff_dict(self.text_encoder_config),
        }
        # Make nested dicts JSON-serializable (torch.dtype -> str).
        self.dict_torch_dtype_to_str(sub_serializable_config_dict)
        serializable_config_dict.update(sub_serializable_config_dict)

        return serializable_config_dict
|
|
|
|
|
|
|
|
class JapaneseMuLanMusicEncoderConfig(PretrainedConfig):
    """Configuration of the music encoder used by Japanese MuLan.

    Args:
        model_name: Name of the encoder backbone (default ``"ast-base384"``).
        out_channels: Stored as ``out_channels``; presumably the encoder's
            output feature size — confirm against the model code.
        stride: An int or a pair; always stored as a 2-tuple via ``_pair``.
        n_bins: Stored as ``n_bins``; presumably the number of frequency bins
            of the input — confirm against the model code.
        n_frames: Stored as ``n_frames``; presumably the input length in frames.
        n_pretrained_frames: Stored as ``n_pretrained_frames``; when ``None``
            it falls back to ``n_frames``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "ast-base384",
        out_channels: int = 128,
        stride: _size_2_t = 10,
        n_bins: int = 128,
        n_frames: int = 1024,
        n_pretrained_frames: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Attribute assignment order is kept as-is so serialized key order
        # does not change.
        self.model_name = model_name
        self.out_channels = out_channels
        # An int stride ``s`` becomes ``(s, s)``; an iterable becomes a tuple.
        self.stride = _pair(stride)
        self.n_bins = n_bins
        self.n_frames = n_frames
        self.n_pretrained_frames = (
            n_frames if n_pretrained_frames is None else n_pretrained_frames
        )
|
|
|
|
|
|
|
|
class JapaneseMuLanTextEncoderConfig(PretrainedConfig):
    """Configuration of the text encoder used by Japanese MuLan.

    Args:
        model_name: Name of the text backbone
            (default ``"pkshatech/GLuCoSE-base-ja"``).
        out_channels: Stored as ``out_channels``; presumably the encoder's
            output feature size — confirm against the model code.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "pkshatech/GLuCoSE-base-ja",
        out_channels: int = 128,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Tuple assignment binds left to right, so attribute order is unchanged.
        self.model_name, self.out_channels = model_name, out_channels
|
|
|
|
|
|
|
|
def _to_diff_dict(c: PretrainedConfig) -> Dict[str, Any]:
    """Return ``c.to_diff_dict()`` with any ``transformers_version`` key removed.

    Used by a composite config to serialize its nested sub-configs explicitly,
    instead of relying on ``PretrainedConfig.to_diff_dict()`` for them.

    Note:
        With transformers==4.38.1, ``PretrainedConfig.__repr__`` may fail to
        display configs that contain sub-configs; this helper works around that.
    """
    diff = c.to_diff_dict()
    # The version stamp is dropped from the nested dict; a missing key is fine.
    diff.pop("transformers_version", None)
    return diff
|
|
|