# Copyright 2025 LY Corporation
# ported from https://huggingface.co/line-corporation/clip-japanese-base/blob/main/configuration_clyp.py
from typing import Any, Dict, Optional

from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from transformers import PretrainedConfig


class JapaneseMuLanConfig(PretrainedConfig):
    """Top-level configuration holding a music-encoder and a text-encoder sub-config.

    Args:
        music_encoder_config: Keyword arguments forwarded to
            ``JapaneseMuLanMusicEncoderConfig``. ``None`` is treated as an
            empty dict, i.e. the sub-config is built with its defaults.
        text_encoder_config: Keyword arguments forwarded to
            ``JapaneseMuLanTextEncoderConfig``. ``None`` is treated as an
            empty dict, i.e. the sub-config is built with its defaults.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # NOTE(review): PretrainedConfig subclasses conventionally declare
    # `model_type`; confirm whether `config_class` is the intended name here.
    config_class = "japanese-mulan"

    def __init__(
        self,
        music_encoder_config: Optional[Dict[str, Any]] = None,
        text_encoder_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        music_kwargs = {} if music_encoder_config is None else music_encoder_config
        text_kwargs = {} if text_encoder_config is None else text_encoder_config

        # The sub-configs are stored as config objects, not as the raw dicts.
        self.music_encoder_config = JapaneseMuLanMusicEncoderConfig(**music_kwargs)
        self.text_encoder_config = JapaneseMuLanTextEncoderConfig(**text_kwargs)

    def to_diff_dict(self) -> Dict[str, Any]:
        """Return the parent's diff dict with both sub-configs serialized explicitly.

        Each sub-config is converted with the module-level ``_to_diff_dict``
        helper, passed through ``dict_torch_dtype_to_str``, and then written
        over the corresponding keys of the dict produced by the parent class.
        """
        diff = super().to_diff_dict()

        nested = {
            key: _to_diff_dict(getattr(self, key))
            for key in ("music_encoder_config", "text_encoder_config")
        }
        self.dict_torch_dtype_to_str(nested)

        diff.update(nested)
        return diff


class JapaneseMuLanMusicEncoderConfig(PretrainedConfig):
    """Configuration values for the music encoder.

    Every argument is stored on the instance under the same name.

    Args:
        model_name: Stored as ``model_name``.
        out_channels: Stored as ``out_channels``.
        stride: Normalized with ``torch.nn.modules.utils._pair`` before being
            stored as ``stride``.
        n_bins: Stored as ``n_bins``.
        n_frames: Stored as ``n_frames``.
        n_pretrained_frames: Stored as ``n_pretrained_frames``; falls back to
            ``n_frames`` when ``None``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "ast-base384",
        out_channels: int = 128,
        stride: _size_2_t = 10,
        n_bins: int = 128,
        n_frames: int = 1024,
        n_pretrained_frames: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.model_name = model_name
        self.out_channels = out_channels
        self.stride = _pair(stride)
        self.n_bins = n_bins
        self.n_frames = n_frames
        # An unspecified pretrained length defaults to the configured length.
        self.n_pretrained_frames = (
            n_frames if n_pretrained_frames is None else n_pretrained_frames
        )


class JapaneseMuLanTextEncoderConfig(PretrainedConfig):
    """Configuration values for the text encoder.

    Args:
        model_name: Stored as ``model_name``.
        out_channels: Stored as ``out_channels``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "pkshatech/GLuCoSE-base-ja",
        out_channels: int = 128,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.model_name = model_name
        self.out_channels = out_channels


def _to_diff_dict(c: PretrainedConfig) -> Dict[str, Any]:
    """Return ``c.to_diff_dict()`` without the ``"transformers_version"`` entry.

    NOTE
    ----
    In transformers==4.38.1, PretrainedConfig.__repr__ may not be able to show
    configs that have sub-configs, hence this helper for serializing them.
    """
    diff = c.to_diff_dict()
    # Drop the version stamp if present; a missing key is not an error.
    diff.pop("transformers_version", None)
    return diff