# Copyright 2025 LY Corporation import functools import warnings from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F import torchaudio.compliance.kaldi as aCK import torchaudio.functional as aF import torchaudio.transforms as aT from packaging import version from transformers import ( AutoFeatureExtractor, AutoProcessor, AutoTokenizer, BatchEncoding, BatchFeature, FeatureExtractionMixin, PreTrainedTokenizer, ProcessorMixin, ) from transformers.utils import logging logger = logging.get_logger(__name__) class MuLanProcessor(ProcessorMixin): attributes = ["music_processor", "text_processor"] music_processor: "AudioSpectrogramTransformerFeatureExtractor" text_processor: "GLuCoSETokenizer" music_processor_class = "AudioSpectrogramTransformerFeatureExtractor" text_processor_class = "GLuCoSETokenizer" def __init__( self, music_processor: "AudioSpectrogramTransformerFeatureExtractor", text_processor: "GLuCoSETokenizer", **kwargs, ) -> None: super().__init__(music_processor, text_processor, **kwargs) def __call__( self, audio: torch.Tensor, text: List[str], sample_rate: Optional[int] = None ) -> BatchFeature: music_features = self.get_music_feature(audio, sample_rate=sample_rate) text_features = self.get_text_feature(text) return BatchFeature( { "music": music_features, "text": text_features, } ) def get_music_feature( self, audio: torch.Tensor, sample_rate: Optional[int] = None ) -> BatchFeature: """Get music feature from audio. Args: audio (torch.Tensor): Audio waveform of shape (batch_size, timesteps). sample_rate (int, optional): Sampling rate of audio. Returns: BatchFeature: Batched music feature. """ spectrogram = self.music_processor(audio, sample_rate=sample_rate) music_features = BatchFeature({"spectrogram": spectrogram}) return music_features def get_text_feature(self, text: List[str]) -> BatchEncoding: """Get text feature from text. Args: text (list): Text to be tokenized. Returns: BatchEncoding: Batch encoding of text feature. """ text_features = self.text_processor(text) return text_features def train(self, mode: bool = True) -> None: """Set training mode.""" if hasattr(self.music_processor, "train") and callable( self.music_processor.train ): self.music_processor.train(mode=mode) if hasattr(self.text_processor, "train") and callable( self.text_processor.train ): self.text_processor.train(mode=mode) def eval(self) -> None: """Set evaluation mode.""" if hasattr(self.music_processor, "eval") and callable( self.music_processor.eval ): self.music_processor.eval() if hasattr(self.text_processor, "eval") and callable(self.text_processor.eval): self.text_processor.eval() class AudioSpectrogramTransformerFeatureExtractor(FeatureExtractionMixin): """Audio processor for official implementation of audio spectrogram transformer. Args: sample_rate (int): Sampling rate. duration (float): Duration of audio in seconds. n_mels (int): Number of Mel-frequency bins. n_frames (int): Number of time frames. fbank_kwargs (int): Keyword arguments given to ``torchaudio.compliance.kaldi.fbank``. spec_norm (bool): Whether to apply normalization to spectrogram. If ``True``, input spectrogram is normalized by ``spec_mean`` and ``spec_std``. spec_mean (float): Mean of spectrogram. spec_std (float): Standard deviation of spectrogram. freq_mask_param (int): Parameter given to torchaudio.FrequencyMasking. time_mask_param (int): Parameter given to torchaudio.TimeMasking. .. note:: Frequency and time maskings are deactivated when ``self.training = False``. """ def __init__( self, sample_rate: int, duration: float = None, n_mels: int = None, n_frames: int = None, fbank_kwargs: Dict[str, Any] = None, spec_norm: bool = False, spec_mean: float = 0, spec_std: float = 1, freq_mask_param: int = None, time_mask_param: int = None, **kwargs, ) -> None: super().__init__(**kwargs) if fbank_kwargs is None: fbank_kwargs = {} if freq_mask_param is None or freq_mask_param == 0: self.freq_mask = None else: self.freq_mask = aT.FrequencyMasking(freq_mask_param) if time_mask_param is None or time_mask_param == 0: self.time_mask = None else: self.time_mask = aT.TimeMasking(time_mask_param) if sample_rate is None: sample_rate = fbank_kwargs.get("sample_frequency") else: if ( "sample_frequency" in fbank_kwargs.keys() and fbank_kwargs["sample_frequency"] is not None ): assert sample_rate == fbank_kwargs["sample_frequency"], ( "sample_rate should be equal to sample_frequency in fbank_kwargs." ) else: fbank_kwargs["sample_frequency"] = sample_rate if n_mels is None: n_mels = fbank_kwargs.get("num_mel_bins") else: if ( "num_mel_bins" in fbank_kwargs.keys() and fbank_kwargs["num_mel_bins"] is not None ): assert n_mels == fbank_kwargs["num_mel_bins"], ( "n_mels should be equal to num_mel_bins in fbank_kwargs." ) else: fbank_kwargs["num_mel_bins"] = n_mels super().__init__( sample_rate=sample_rate, duration=duration, n_mels=n_mels, n_frames=n_frames, fbank_kwargs=fbank_kwargs, spec_norm=spec_norm, spec_mean=spec_mean, spec_std=spec_std, **kwargs, ) self.training = True def __call__( self, waveform: torch.Tensor, sample_rate: Optional[int] = None ) -> torch.Tensor: """Forward pass of AudioSpectrogramTransformerFeatureExtractor. Args: waveform (torch.Tensor): Waveform of shape (batch_size, timesteps). Returns: torch.Tensor: Spectrogram of shape (batch_size, n_bins, n_frames). """ required_sample_rate = self.sample_rate required_duration = self.duration fbank_kwargs = self.fbank_kwargs assert waveform.dim() == 2 if sample_rate is None: warnings.warn("It is recommended to set sample_rate.", UserWarning) sample_rate = required_sample_rate else: if sample_rate != required_sample_rate: waveform = aF.resample(waveform, sample_rate, required_sample_rate) sample_rate = required_sample_rate if self.training: timesteps = waveform.size(-1) required_timesteps = int(sample_rate * required_duration) if timesteps <= required_timesteps: waveform = F.pad(waveform, (0, required_timesteps - timesteps)) else: start_idx = torch.randint(0, timesteps - required_timesteps, ()) sections = [ start_idx, required_timesteps, timesteps - required_timesteps - start_idx, ] _, waveform, _ = torch.split(waveform, sections, dim=-1) else: if required_duration is None: # Any shape of input is allowed. pass else: assert waveform.size(0) == 1, ( "Only batch size = 1 is supported during evaluation if required_duration is given." ) timesteps = waveform.size(-1) required_timesteps = int(sample_rate * required_duration) padding = required_timesteps - timesteps % required_timesteps padding = padding % required_timesteps valid_length = timesteps + padding num_segments = valid_length // required_timesteps waveform = F.pad(waveform, (0, padding)) waveform = waveform.view(num_segments, required_timesteps) waveform = waveform - waveform.mean(dim=-1, keepdim=True) if version.parse(torch.__version__) < version.parse("2.0.0"): spectrogram = self._sequential_fbank(waveform, **fbank_kwargs) else: spectrogram = self._parallel_fbank(waveform, **fbank_kwargs) if self.n_frames is not None: padding = self.n_frames - spectrogram.size(-1) spectrogram = F.pad(spectrogram, (0, padding)) if self.freq_mask is not None and self.training: spectrogram = self.freq_mask(spectrogram) if self.time_mask is not None and self.training: spectrogram = self.time_mask(spectrogram) if self.spec_norm: spectrogram = (spectrogram - self.spec_mean) / (2 * self.spec_std) return spectrogram def train(self, mode: bool = True) -> None: """Set training mode.""" self.training = mode if hasattr(self.freq_mask, "train") and callable(self.freq_mask.train): self.freq_mask.train(mode=mode) if hasattr(self.time_mask, "train") and callable(self.time_mask.train): self.time_mask.train(mode=mode) def eval(self) -> None: """Set evaluation mode.""" self.training = False if hasattr(self.freq_mask, "eval") and callable(self.freq_mask.eval): self.freq_mask.eval() if hasattr(self.time_mask, "eval") and callable(self.time_mask.eval): self.time_mask.eval() def to_dict(self) -> Dict[str, Any]: selialized = super().to_dict() freq_mask: aT.FrequencyMasking = selialized.pop("freq_mask") time_mask: aT.TimeMasking = selialized.pop("time_mask") if freq_mask is None: selialized["freq_mask"] = None else: selialized["freq_mask"] = { "mask_param": freq_mask.mask_param, } if time_mask is None: selialized["time_mask"] = None else: selialized["time_mask"] = { "mask_param": time_mask.mask_param, } return selialized @classmethod def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> None: freq_mask_kwargs = feature_extractor_dict.pop("freq_mask") time_mask_kwargs = feature_extractor_dict.pop("time_mask") if freq_mask_kwargs is None: freq_mask_param = None else: freq_mask_param = freq_mask_kwargs["mask_param"] if time_mask_kwargs is None: time_mask_param = None else: time_mask_param = time_mask_kwargs["mask_param"] feature_extractor_dict["freq_mask_param"] = freq_mask_param feature_extractor_dict["time_mask_param"] = time_mask_param return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) # Update feature_extractor with kwargs if needed to_remove = [] for key, value in kwargs.items(): if key in feature_extractor_dict: feature_extractor_dict[key] = value to_remove.append(key) for key in to_remove: kwargs.pop(key, None) feature_extractor = cls(**feature_extractor_dict) logger.info(f"Feature extractor {feature_extractor}") if return_unused_kwargs: return feature_extractor, kwargs else: return feature_extractor @staticmethod def _sequential_fbank(waveform: torch.Tensor, **kwargs) -> torch.Tensor: spectrogram = [] for _waveform in waveform: _spectrogram = _fbank_fn(_waveform, **kwargs) spectrogram.append(_spectrogram) spectrogram = torch.stack(spectrogram, dim=0) return spectrogram @staticmethod def _parallel_fbank(waveform: torch.Tensor, **kwargs) -> torch.Tensor: vfbank_fn = torch.vmap(functools.partial(_fbank_fn, **kwargs)) spectrogram = vfbank_fn(waveform) return spectrogram class GLuCoSETokenizer(PreTrainedTokenizer): """Processor to tokenize text for MuLan. Args: model_name_or_path (sr): Name of tokenizer. padding (bool, optional): If ``True``, padding is applied to tokenization output. truncation (bool, optional): If ``True``, truncation is applied to tokenization output. tokenizer_kwargs (dict, optional): Keyward arguments given to tokenizer. """ def __init__( self, model_name_or_path: str = None, padding: Optional[bool] = None, truncation: Optional[bool] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: if tokenizer_kwargs is None: tokenizer_kwargs = {} if model_name_or_path == "pkshatech/GLuCoSE-base-ja": if padding is None: padding = True if truncation is None: truncation = True else: raise NotImplementedError( f"{model_name_or_path} is not supported as model_name_or_path." ) self.tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, **tokenizer_kwargs ) super().__init__( model_name_or_path=model_name_or_path, padding=padding, truncation=truncation, **kwargs, ) self.padding = padding self.truncation = truncation self.model_max_length = self.tokenizer.model_max_length def __call__(self, text: List[str], return_tensors: str = "pt") -> BatchEncoding: padding = self.padding truncation = self.truncation output = self.tokenizer( text, padding=padding, truncation=truncation, return_tensors=return_tensors ) return output @property def vocab_size(self) -> int: return self.tokenizer.vocab_size def get_vocab(self) -> Dict[str, int]: return self.tokenizer.get_vocab() def save_vocabulary( self, save_directory: str, filename_prefix: Optional[str] = None ) -> tuple[str]: return self.tokenizer.save_vocabulary( save_directory, filename_prefix=filename_prefix ) def _tokenize(self, text, **kwargs) -> Dict[str, int]: return self.tokenizer._tokenize(text, **kwargs) def _convert_token_to_id(self, token: str) -> int: return self.tokenizer._convert_token_to_id(token) def _convert_id_to_token(self, index: int) -> str: return self.tokenizer._convert_id_to_token(index) def _fbank_fn(waveform: torch.Tensor, **kwargs) -> torch.Tensor: """Wrapper function of torchaudio.compliance.kaldi.fbank. Args: waveform (torch.Tensor): Waveform of shape (time,). kwargs: Keyword arguments given to torchaudio.compliance.kaldi.fbank. Returns: torch.Tensor: Spectrogram of shape (n_mels, n_frames). """ # torchaudio.compliance.kaldi.fbank accepts tensor # of shape (n_channels, time). waveform = waveform.unsqueeze(dim=0) spectrogram = aCK.fbank(waveform, **kwargs) # (n_frames, n_mels) -> (n_mels, n_frames) spectrogram = spectrogram.transpose(1, 0) return spectrogram AutoFeatureExtractor.register( "AudioSpectrogramTransformerFeatureExtractor", AudioSpectrogramTransformerFeatureExtractor, ) AutoTokenizer.register("GLuCoSETokenizer", GLuCoSETokenizer) AutoProcessor.register("MuLanProcessor", MuLanProcessor)