japanese-mulan-base / configuration_mulan.py
tky823's picture
Upload configuration_mulan.py with huggingface_hub
4ca828d verified
raw
history blame
2.85 kB
# Copyright 2025 LY Corporation
# ported from https://huggingface.co/line-corporation/clip-japanese-base/blob/main/configuration_clyp.py
from typing import Any, Dict, Optional
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from transformers import PretrainedConfig
class JapaneseMuLanConfig(PretrainedConfig):
    """Top-level Japanese MuLan configuration holding both encoder configs.

    Parameters
    ----------
    music_encoder_config : dict, optional
        Keyword arguments forwarded to ``JapaneseMuLanMusicEncoderConfig``.
        ``None`` is treated as an empty dict, i.e. all defaults are used.
    text_encoder_config : dict, optional
        Keyword arguments forwarded to ``JapaneseMuLanTextEncoderConfig``.
        ``None`` is treated as an empty dict, i.e. all defaults are used.
    **kwargs
        Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    # ``PretrainedConfig`` identifies a configuration by its ``model_type``
    # class attribute (it is what gets serialized and what ``AutoConfig``
    # looks up). The original code only set ``config_class``, which is an
    # attribute of ``PreTrainedModel`` and has no effect on a config.
    model_type = "japanese-mulan"
    # Retained so that any code reading the old attribute keeps working.
    config_class = "japanese-mulan"

    def __init__(
        self,
        music_encoder_config: Optional[Dict[str, Any]] = None,
        text_encoder_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Avoid mutable default arguments: ``None`` stands for "no overrides".
        if music_encoder_config is None:
            music_encoder_config = {}

        if text_encoder_config is None:
            text_encoder_config = {}

        # Sub-configs are stored as config objects, not as raw dicts.
        self.music_encoder_config = JapaneseMuLanMusicEncoderConfig(
            **music_encoder_config
        )
        self.text_encoder_config = JapaneseMuLanTextEncoderConfig(**text_encoder_config)

    def to_diff_dict(self) -> Dict[str, Any]:
        """Return the parent's diff dict with the two sub-config entries replaced.

        The ``"music_encoder_config"`` and ``"text_encoder_config"`` entries are
        rebuilt with ``_to_diff_dict`` (which drops ``"transformers_version"``)
        and passed through ``dict_torch_dtype_to_str`` before being merged in.
        """
        serializable_config_dict = super().to_diff_dict()

        sub_serializable_config_dict = {
            "music_encoder_config": _to_diff_dict(self.music_encoder_config),
            "text_encoder_config": _to_diff_dict(self.text_encoder_config),
        }
        self.dict_torch_dtype_to_str(sub_serializable_config_dict)

        # Overwrite whatever the parent produced for the two sub-config keys.
        serializable_config_dict.update(sub_serializable_config_dict)

        return serializable_config_dict
class JapaneseMuLanMusicEncoderConfig(PretrainedConfig):
    """Configuration values for the music encoder.

    Parameters
    ----------
    model_name : str
        Stored unchanged as ``self.model_name``.
    out_channels : int
        Stored unchanged as ``self.out_channels``.
    stride : int or pair of int
        Normalized with ``torch``'s ``_pair`` before being stored, so a single
        integer is expanded to a 2-tuple.
    n_bins : int
        Stored unchanged as ``self.n_bins``.
    n_frames : int
        Stored unchanged as ``self.n_frames``.
    n_pretrained_frames : int, optional
        Falls back to ``n_frames`` when ``None``.
    **kwargs
        Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "ast-base384",
        out_channels: int = 128,
        stride: _size_2_t = 10,
        n_bins: int = 128,
        n_frames: int = 1024,
        n_pretrained_frames: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.model_name = model_name
        self.out_channels = out_channels
        self.stride = _pair(stride)
        self.n_bins = n_bins
        self.n_frames = n_frames
        # An omitted pretrained length defaults to the configured frame count.
        self.n_pretrained_frames = (
            n_frames if n_pretrained_frames is None else n_pretrained_frames
        )
class JapaneseMuLanTextEncoderConfig(PretrainedConfig):
    """Configuration values for the text encoder.

    Parameters
    ----------
    model_name : str
        Stored unchanged as ``self.model_name``.
    out_channels : int
        Stored unchanged as ``self.out_channels``.
    **kwargs
        Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "pkshatech/GLuCoSE-base-ja",
        out_channels: int = 128,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Assigned left to right, i.e. ``model_name`` first.
        self.model_name, self.out_channels = model_name, out_channels
def _to_diff_dict(c: PretrainedConfig) -> Dict[str, Any]:
    """Return ``c.to_diff_dict()`` without its ``"transformers_version"`` entry.

    Used in place of calling ``PretrainedConfig.to_diff_dict()`` directly on a
    sub-config.

    NOTE
    ----
    In transformers==4.38.1,
    PretrainedConfig.__repr__ may not be able to show configs that have
    sub-configs.
    """
    diff = c.to_diff_dict()
    # Drop the version stamp if present; a missing key is not an error.
    diff.pop("transformers_version", None)

    return diff