# Copyright 2025 LY Corporation
# ported from https://huggingface.co/line-corporation/clip-japanese-base/blob/main/modeling_clyp.py
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from torch.nn.modules.utils import _pair
from transformers import PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding

from .configuration_mulan import (
    JapaneseMuLanConfig,
    JapaneseMuLanMusicEncoderConfig,
    JapaneseMuLanTextEncoderConfig,
)
from .modeling_ast import (
    AudioSpectrogramTransformer,
    HeadTokenAggregator,
    PositionalPatchEmbedding,
)


class MuLanPreTrainedModel(PreTrainedModel):
    config_class = JapaneseMuLanConfig

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def _init_weights(self, module: Any) -> None:
        # No custom weight initialization is performed.
        pass


class MuLanModel(MuLanPreTrainedModel):
    def __init__(self, config: JapaneseMuLanConfig) -> None:
        super().__init__(config)

        self.music_encoder = create_music_encoder(config.music_encoder_config)
        self.text_encoder = create_text_encoder(config.text_encoder_config)

    def get_music_features(
        self, spectrogram: torch.Tensor, batch_mean: Optional[bool] = True
    ) -> torch.Tensor:
        if batch_mean is None:
            # Fall back to per-sample embeddings during training and a
            # batch-averaged embedding during evaluation.
            if self.training:
                batch_mean = False
            else:
                batch_mean = True

        music_embedding = self.music_encoder(spectrogram, batch_mean=batch_mean)

        return music_embedding

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        text_embedding = self.text_encoder(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
            },
            batch_mean=False,
        )

        return text_embedding


class ModalEncoderWrapper(nn.Module):
    """Wrapper class for a single-modality encoder tower (music or text)."""

    def __init__(
        self,
        backbone: nn.Module,
        out_channels: int,
        hidden_channels: Optional[int] = None,
        freeze_backbone: bool = False,
    ) -> None:
        super().__init__()

        self.backbone = backbone

        if hidden_channels is None:
            # Infer the backbone output dimension when it is not given explicitly.
            if isinstance(backbone, AudioSpectrogramTransformer):
                backbone: AudioSpectrogramTransformer
                hidden_channels = backbone.embedding.embedding_dim
            elif isinstance(backbone, SentenceTransformer):
                backbone: SentenceTransformer
                hidden_channels = backbone[-1].word_embedding_dimension
            else:
                raise NotImplementedError(
                    f"{type(backbone)} is not supported as backbone network."
                )

        self.linear = nn.Linear(hidden_channels, out_channels)

        self.freeze_backbone = freeze_backbone

        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        self.out_channels = out_channels

    def forward(
        self, *args, batch_mean: Optional[bool] = None, **kwargs
    ) -> torch.Tensor:
        """Forward pass of tower wrapper.

        Args:
            args (tuple): Positional arguments given to backbone.
            batch_mean (bool, optional): If ``True``, average the embeddings over
                the batch dimension. Only allowed in evaluation mode; defaults to
                ``False`` when unspecified.
            kwargs (dict): Keyword arguments given to backbone.

        Returns:
            torch.Tensor: Embedding of shape (*, out_channels).

        """
        embed = self.backbone(*args, **kwargs)

        if isinstance(self.backbone, SentenceTransformer):
            if isinstance(embed, (dict, BatchEncoding)):
                embed = embed["sentence_embedding"]
            else:
                raise ValueError(
                    f"Invalid type {type(embed)} is detected as sentence transformer output."
                )
        else:
            assert isinstance(embed, torch.Tensor), (
                f"Invalid type {type(embed)} is detected."
            )

        # Project to the shared embedding space and L2-normalize.
        x = self.linear(embed)
        output = F.normalize(x, p=2, dim=-1)

        if self.training:
            # Batch averaging is only allowed at inference time.
            assert not batch_mean
        else:
            if batch_mean is None:
                batch_mean = False

        if batch_mean:
            output = output.mean(dim=0, keepdim=True)

        return output


class MusicEncoder(ModalEncoderWrapper):
    """Alias of ModalEncoderWrapper for the music modality."""


class TextEncoder(ModalEncoderWrapper):
    """Alias of ModalEncoderWrapper for the text modality."""


def create_music_encoder(config: JapaneseMuLanMusicEncoderConfig) -> MusicEncoder:
    stride = _pair(config.stride)
    n_bins = config.n_bins
    n_frames = config.n_pretrained_frames
    model_name = config.model_name
    out_channels = config.out_channels

    ast_prefix = "ast-"

    if model_name.startswith(ast_prefix):
        model_size = model_name[len(ast_prefix):]

        assert model_size == "base384", "Only base384 is supported as model_size."

        # Transformer hyperparameters of the base384 AST variant.
        kernel_size = (16, 16)
        embedding_dim = 768
        nhead = 12
        dim_feedforward = 3072
        activation = "gelu"
        num_layers = 12
        layer_norm_eps = 1e-6

        embedding = PositionalPatchEmbedding(
            embedding_dim=embedding_dim,
            kernel_size=kernel_size,
            stride=stride,
            insert_cls_token=True,
            insert_dist_token=True,
            n_bins=n_bins,
            n_frames=n_frames,
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            activation=activation,
            batch_first=True,
            norm_first=True,
            layer_norm_eps=layer_norm_eps,
        )
        norm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps)
        backbone = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, norm=norm
        )
        # Aggregate the token sequence by selecting the head token (position 0).
        aggregator = HeadTokenAggregator(position=0)
        backbone = AudioSpectrogramTransformer(
            embedding,
            backbone,
            aggregator=aggregator,
        )
    else:
        raise NotImplementedError(
            f"{model_name} is not supported as model_name of MusicEncoder."
        )

    return MusicEncoder(backbone, out_channels)


def create_text_encoder(config: JapaneseMuLanTextEncoderConfig) -> TextEncoder:
    model_name = config.model_name
    out_channels = config.out_channels

    if model_name == "pkshatech/GLuCoSE-base-ja":
        # NOTE: hack to avoid a meta tensor error: instantiate on the meta
        # device first, then materialize empty weights on CPU.
        backbone = SentenceTransformer(
            model_name_or_path=model_name,
            device="meta",
        )
        backbone.to_empty(device="cpu")
    else:
        raise NotImplementedError(
            f"{model_name} is not supported as model_name of TextEncoder."
        )

    return TextEncoder(backbone, out_channels)
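

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the ported module): querying both
# towers for text-to-music retrieval. The checkpoint id, the spectrogram
# shape, and the feature-extraction step are assumptions; only
# ``get_music_features`` / ``get_text_features`` above are defined in this file.
#
#   from transformers import AutoModel, AutoTokenizer
#
#   repo_id = "<japanese-mulan-checkpoint>"  # hypothetical checkpoint id
#   model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
#   model.eval()
#
#   # Log-mel spectrogram batch; (batch_size, n_bins, n_frames) is assumed here,
#   # and the actual feature extraction lives outside this module.
#   spectrogram = torch.randn(1, 128, 1024)
#   text = tokenizer(["落ち着いたピアノ曲"], padding=True, return_tensors="pt")  # "a calm piano piece"
#
#   with torch.no_grad():
#       music_embedding = model.get_music_features(spectrogram, batch_mean=False)
#       text_embedding = model.get_text_features(
#           input_ids=text["input_ids"], attention_mask=text["attention_mask"]
#       )
#
#   # Both towers return L2-normalized vectors, so cosine similarity reduces to
#   # a dot product.
#   similarity = music_embedding @ text_embedding.transpose(0, 1)
# ---------------------------------------------------------------------------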