# japanese-mulan-base/processor.py
# Copyright 2025 LY Corporation
import functools
import warnings
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as aCK
import torchaudio.functional as aF
import torchaudio.transforms as aT
from packaging import version
from transformers import (
AutoFeatureExtractor,
AutoProcessor,
AutoTokenizer,
BatchEncoding,
BatchFeature,
FeatureExtractionMixin,
PreTrainedTokenizer,
ProcessorMixin,
)
from transformers.utils import logging
logger = logging.get_logger(__name__)


class MuLanProcessor(ProcessorMixin):
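    """Processor that pairs the MuLan music feature extractor with the text tokenizer.

    Example:
        A minimal sketch, assuming ``music_processor`` and ``text_processor``
        have already been instantiated as described in their respective
        classes below::

            >>> import torch
            >>> processor = MuLanProcessor(music_processor, text_processor)
            >>> waveform = torch.randn(2, 160000)  # (batch_size, timesteps)
            >>> batch = processor(
            ...     waveform,
            ...     ["ピアノのバラード", "アップテンポなロック"],
            ...     sample_rate=16000,
            ... )
            >>> spectrogram = batch["music"]["spectrogram"]
    """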
attributes = ["music_processor", "text_processor"]
music_processor: "AudioSpectrogramTransformerFeatureExtractor"
text_processor: "GLuCoSETokenizer"
music_processor_class = "AudioSpectrogramTransformerFeatureExtractor"
text_processor_class = "GLuCoSETokenizer"
def __init__(
self,
music_processor: "AudioSpectrogramTransformerFeatureExtractor",
text_processor: "GLuCoSETokenizer",
**kwargs,
) -> None:
super().__init__(music_processor, text_processor, **kwargs)
def __call__(
self, audio: torch.Tensor, text: List[str], sample_rate: Optional[int] = None
) -> BatchFeature:
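        """Process audio and text into a single batch.

        Args:
            audio (torch.Tensor): Audio waveform of shape (batch_size, timesteps).
            text (list): Texts to be tokenized.
            sample_rate (int, optional): Sampling rate of audio.

        Returns:
            BatchFeature: Batch containing ``music`` and ``text`` features.
        """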
music_features = self.get_music_feature(audio, sample_rate=sample_rate)
text_features = self.get_text_feature(text)
return BatchFeature(
{
"music": music_features,
"text": text_features,
}
)
def get_music_feature(
self, audio: torch.Tensor, sample_rate: Optional[int] = None
) -> BatchFeature:
"""Get music feature from audio.
Args:
audio (torch.Tensor): Audio waveform of shape (batch_size, timesteps).
sample_rate (int, optional): Sampling rate of audio.
Returns:
BatchFeature: Batched music feature.
"""
spectrogram = self.music_processor(audio, sample_rate=sample_rate)
music_features = BatchFeature({"spectrogram": spectrogram})
return music_features
def get_text_feature(self, text: List[str]) -> BatchEncoding:
"""Get text feature from text.
Args:
text (list): Text to be tokenized.
Returns:
BatchEncoding: Batch encoding of text feature.
"""
text_features = self.text_processor(text)
return text_features
def train(self, mode: bool = True) -> None:
"""Set training mode."""
if hasattr(self.music_processor, "train") and callable(
self.music_processor.train
):
self.music_processor.train(mode=mode)
if hasattr(self.text_processor, "train") and callable(
self.text_processor.train
):
self.text_processor.train(mode=mode)
def eval(self) -> None:
"""Set evaluation mode."""
if hasattr(self.music_processor, "eval") and callable(
self.music_processor.eval
):
self.music_processor.eval()
if hasattr(self.text_processor, "eval") and callable(self.text_processor.eval):
self.text_processor.eval()


class AudioSpectrogramTransformerFeatureExtractor(FeatureExtractionMixin):
"""Audio processor for official implementation of audio spectrogram transformer.
Args:
sample_rate (int): Sampling rate.
duration (float): Duration of audio in seconds.
n_mels (int): Number of Mel-frequency bins.
n_frames (int): Number of time frames.
        fbank_kwargs (dict): Keyword arguments given to ``torchaudio.compliance.kaldi.fbank``.
spec_norm (bool): Whether to apply normalization to spectrogram. If ``True``,
input spectrogram is normalized by ``spec_mean`` and ``spec_std``.
spec_mean (float): Mean of spectrogram.
spec_std (float): Standard deviation of spectrogram.
        freq_mask_param (int): Parameter given to ``torchaudio.transforms.FrequencyMasking``.
        time_mask_param (int): Parameter given to ``torchaudio.transforms.TimeMasking``.

.. note::
Frequency and time maskings are deactivated when ``self.training = False``.
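
    Example:
        A minimal sketch of standalone use (the parameter values below are
        illustrative, not this checkpoint's actual configuration)::

            >>> import torch
            >>> feature_extractor = AudioSpectrogramTransformerFeatureExtractor(
            ...     sample_rate=16000, duration=10.0, n_mels=128, n_frames=1024
            ... )
            >>> feature_extractor.eval()
            >>> waveform = torch.randn(1, 160000)  # (batch_size, timesteps)
            >>> spectrogram = feature_extractor(waveform)  # shape: (1, 128, 1024)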
"""
def __init__(
self,
        sample_rate: Optional[int],
        duration: Optional[float] = None,
        n_mels: Optional[int] = None,
        n_frames: Optional[int] = None,
        fbank_kwargs: Optional[Dict[str, Any]] = None,
        spec_norm: bool = False,
        spec_mean: float = 0,
        spec_std: float = 1,
        freq_mask_param: Optional[int] = None,
        time_mask_param: Optional[int] = None,
**kwargs,
) -> None:
if fbank_kwargs is None:
fbank_kwargs = {}
if freq_mask_param is None or freq_mask_param == 0:
self.freq_mask = None
else:
self.freq_mask = aT.FrequencyMasking(freq_mask_param)
if time_mask_param is None or time_mask_param == 0:
self.time_mask = None
else:
self.time_mask = aT.TimeMasking(time_mask_param)
if sample_rate is None:
sample_rate = fbank_kwargs.get("sample_frequency")
else:
if (
"sample_frequency" in fbank_kwargs.keys()
and fbank_kwargs["sample_frequency"] is not None
):
assert sample_rate == fbank_kwargs["sample_frequency"], (
"sample_rate should be equal to sample_frequency in fbank_kwargs."
)
else:
fbank_kwargs["sample_frequency"] = sample_rate
if n_mels is None:
n_mels = fbank_kwargs.get("num_mel_bins")
else:
if (
"num_mel_bins" in fbank_kwargs.keys()
and fbank_kwargs["num_mel_bins"] is not None
):
assert n_mels == fbank_kwargs["num_mel_bins"], (
"n_mels should be equal to num_mel_bins in fbank_kwargs."
)
else:
fbank_kwargs["num_mel_bins"] = n_mels
super().__init__(
sample_rate=sample_rate,
duration=duration,
n_mels=n_mels,
n_frames=n_frames,
fbank_kwargs=fbank_kwargs,
spec_norm=spec_norm,
spec_mean=spec_mean,
spec_std=spec_std,
**kwargs,
)
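        # Start in training mode so that masking-based augmentation is applied;
        # call eval() for deterministic feature extraction.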
self.training = True
def __call__(
self, waveform: torch.Tensor, sample_rate: Optional[int] = None
) -> torch.Tensor:
"""Forward pass of AudioSpectrogramTransformerFeatureExtractor.
Args:
waveform (torch.Tensor): Waveform of shape (batch_size, timesteps).
Returns:
torch.Tensor: Spectrogram of shape (batch_size, n_bins, n_frames).
"""
required_sample_rate = self.sample_rate
required_duration = self.duration
fbank_kwargs = self.fbank_kwargs
        assert waveform.dim() == 2, (
            "waveform should be 2D of shape (batch_size, timesteps)."
        )
if sample_rate is None:
            warnings.warn(
                f"sample_rate is not given; assuming {required_sample_rate}.",
                UserWarning,
            )
sample_rate = required_sample_rate
else:
if sample_rate != required_sample_rate:
waveform = aF.resample(waveform, sample_rate, required_sample_rate)
sample_rate = required_sample_rate
        if self.training:
            assert required_duration is not None, (
                "duration should be set when the feature extractor is in training mode."
            )
            timesteps = waveform.size(-1)
            required_timesteps = int(sample_rate * required_duration)
if timesteps <= required_timesteps:
waveform = F.pad(waveform, (0, required_timesteps - timesteps))
else:
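                # Randomly crop a segment of required_timesteps samples.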
start_idx = torch.randint(0, timesteps - required_timesteps, ())
sections = [
start_idx,
required_timesteps,
timesteps - required_timesteps - start_idx,
]
_, waveform, _ = torch.split(waveform, sections, dim=-1)
else:
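            # Evaluation mode: accept any input length, or split the single
            # input into fixed-length segments when duration is specified.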
if required_duration is None:
# Any shape of input is allowed.
pass
else:
assert waveform.size(0) == 1, (
"Only batch size = 1 is supported during evaluation if required_duration is given."
)
timesteps = waveform.size(-1)
required_timesteps = int(sample_rate * required_duration)
padding = required_timesteps - timesteps % required_timesteps
padding = padding % required_timesteps
valid_length = timesteps + padding
num_segments = valid_length // required_timesteps
waveform = F.pad(waveform, (0, padding))
waveform = waveform.view(num_segments, required_timesteps)
waveform = waveform - waveform.mean(dim=-1, keepdim=True)
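        # torch.vmap requires torch>=2.0; fall back to a Python loop otherwise.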
if version.parse(torch.__version__) < version.parse("2.0.0"):
spectrogram = self._sequential_fbank(waveform, **fbank_kwargs)
else:
spectrogram = self._parallel_fbank(waveform, **fbank_kwargs)
if self.n_frames is not None:
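            # F.pad accepts negative padding, so short spectrograms are padded
            # and long ones are trimmed to exactly n_frames.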
padding = self.n_frames - spectrogram.size(-1)
spectrogram = F.pad(spectrogram, (0, padding))
if self.freq_mask is not None and self.training:
spectrogram = self.freq_mask(spectrogram)
if self.time_mask is not None and self.training:
spectrogram = self.time_mask(spectrogram)
if self.spec_norm:
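            # Dividing by 2 * std matches the normalization used by the official
            # AST implementation.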
spectrogram = (spectrogram - self.spec_mean) / (2 * self.spec_std)
return spectrogram
def train(self, mode: bool = True) -> None:
"""Set training mode."""
self.training = mode
if hasattr(self.freq_mask, "train") and callable(self.freq_mask.train):
self.freq_mask.train(mode=mode)
if hasattr(self.time_mask, "train") and callable(self.time_mask.train):
self.time_mask.train(mode=mode)
def eval(self) -> None:
"""Set evaluation mode."""
self.training = False
if hasattr(self.freq_mask, "eval") and callable(self.freq_mask.eval):
self.freq_mask.eval()
if hasattr(self.time_mask, "eval") and callable(self.time_mask.eval):
self.time_mask.eval()
    def to_dict(self) -> Dict[str, Any]:
        serialized = super().to_dict()
        freq_mask: aT.FrequencyMasking = serialized.pop("freq_mask")
        time_mask: aT.TimeMasking = serialized.pop("time_mask")
        if freq_mask is None:
            serialized["freq_mask"] = None
        else:
            serialized["freq_mask"] = {
                "mask_param": freq_mask.mask_param,
            }
        if time_mask is None:
            serialized["time_mask"] = None
        else:
            serialized["time_mask"] = {
                "mask_param": time_mask.mask_param,
            }
        return serialized
    @classmethod
    def from_dict(
        cls, feature_extractor_dict: Dict[str, Any], **kwargs
    ) -> "AudioSpectrogramTransformerFeatureExtractor":
freq_mask_kwargs = feature_extractor_dict.pop("freq_mask")
time_mask_kwargs = feature_extractor_dict.pop("time_mask")
if freq_mask_kwargs is None:
freq_mask_param = None
else:
freq_mask_param = freq_mask_kwargs["mask_param"]
if time_mask_kwargs is None:
time_mask_param = None
else:
time_mask_param = time_mask_kwargs["mask_param"]
feature_extractor_dict["freq_mask_param"] = freq_mask_param
feature_extractor_dict["time_mask_param"] = time_mask_param
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# Update feature_extractor with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if key in feature_extractor_dict:
feature_extractor_dict[key] = value
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
feature_extractor = cls(**feature_extractor_dict)
logger.info(f"Feature extractor {feature_extractor}")
if return_unused_kwargs:
return feature_extractor, kwargs
else:
return feature_extractor
@staticmethod
def _sequential_fbank(waveform: torch.Tensor, **kwargs) -> torch.Tensor:
spectrogram = []
for _waveform in waveform:
_spectrogram = _fbank_fn(_waveform, **kwargs)
spectrogram.append(_spectrogram)
spectrogram = torch.stack(spectrogram, dim=0)
return spectrogram
@staticmethod
def _parallel_fbank(waveform: torch.Tensor, **kwargs) -> torch.Tensor:
vfbank_fn = torch.vmap(functools.partial(_fbank_fn, **kwargs))
spectrogram = vfbank_fn(waveform)
return spectrogram


class GLuCoSETokenizer(PreTrainedTokenizer):
"""Processor to tokenize text for MuLan.
Args:
        model_name_or_path (str): Name or path of the pretrained tokenizer.
padding (bool, optional): If ``True``, padding is applied to tokenization output.
truncation (bool, optional): If ``True``, truncation is applied to tokenization output.
        tokenizer_kwargs (dict, optional): Keyword arguments given to the tokenizer.
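
    Example:
        A minimal sketch; constructing the tokenizer downloads
        pkshatech/GLuCoSE-base-ja from the Hugging Face Hub::

            >>> tokenizer = GLuCoSETokenizer(model_name_or_path="pkshatech/GLuCoSE-base-ja")
            >>> encoding = tokenizer(["ピアノのバラード", "静かなジャズ"])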
"""
def __init__(
self,
        model_name_or_path: Optional[str] = None,
padding: Optional[bool] = None,
truncation: Optional[bool] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
if model_name_or_path == "pkshatech/GLuCoSE-base-ja":
if padding is None:
padding = True
if truncation is None:
truncation = True
else:
raise NotImplementedError(
f"{model_name_or_path} is not supported as model_name_or_path."
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, **tokenizer_kwargs
)
super().__init__(
model_name_or_path=model_name_or_path,
padding=padding,
truncation=truncation,
**kwargs,
)
self.padding = padding
self.truncation = truncation
self.model_max_length = self.tokenizer.model_max_length
def __call__(self, text: List[str], return_tensors: str = "pt") -> BatchEncoding:
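        """Tokenize a batch of texts.

        Args:
            text (list): Texts to be tokenized.
            return_tensors (str): Tensor type of the output. Defaults to ``"pt"``.

        Returns:
            BatchEncoding: Batch encoding of the tokenized texts.
        """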
padding = self.padding
truncation = self.truncation
output = self.tokenizer(
text, padding=padding, truncation=truncation, return_tensors=return_tensors
)
return output
@property
def vocab_size(self) -> int:
return self.tokenizer.vocab_size
def get_vocab(self) -> Dict[str, int]:
return self.tokenizer.get_vocab()
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
return self.tokenizer.save_vocabulary(
save_directory, filename_prefix=filename_prefix
)
    def _tokenize(self, text: str, **kwargs) -> List[str]:
return self.tokenizer._tokenize(text, **kwargs)
def _convert_token_to_id(self, token: str) -> int:
return self.tokenizer._convert_token_to_id(token)
def _convert_id_to_token(self, index: int) -> str:
return self.tokenizer._convert_id_to_token(index)


def _fbank_fn(waveform: torch.Tensor, **kwargs) -> torch.Tensor:
"""Wrapper function of torchaudio.compliance.kaldi.fbank.
Args:
waveform (torch.Tensor): Waveform of shape (time,).
kwargs: Keyword arguments given to torchaudio.compliance.kaldi.fbank.
Returns:
torch.Tensor: Spectrogram of shape (n_mels, n_frames).
"""
# torchaudio.compliance.kaldi.fbank accepts tensor
# of shape (n_channels, time).
waveform = waveform.unsqueeze(dim=0)
spectrogram = aCK.fbank(waveform, **kwargs)
# (n_frames, n_mels) -> (n_mels, n_frames)
spectrogram = spectrogram.transpose(1, 0)
return spectrogram
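

# Register the custom classes so that the corresponding Auto* factories can
# resolve them, e.g. when this file is loaded from the Hub with
# trust_remote_code=True.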
AutoFeatureExtractor.register(
"AudioSpectrogramTransformerFeatureExtractor",
AudioSpectrogramTransformerFeatureExtractor,
)
AutoTokenizer.register("GLuCoSETokenizer", GLuCoSETokenizer)
AutoProcessor.register("MuLanProcessor", MuLanProcessor)
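
# A typical loading sketch once this file ships with a Hub checkpoint (the
# repository id below is a hypothetical placeholder):
#
#     processor = AutoProcessor.from_pretrained("<repo_id>", trust_remote_code=True)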