import functools
import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as aCK
import torchaudio.functional as aF
import torchaudio.transforms as aT
from packaging import version
from transformers import (
    AutoFeatureExtractor,
    AutoProcessor,
    AutoTokenizer,
    BatchEncoding,
    BatchFeature,
    FeatureExtractionMixin,
    PreTrainedTokenizer,
    ProcessorMixin,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class MuLanProcessor(ProcessorMixin):
    """Processor that wraps a music feature extractor and a text tokenizer for MuLan."""

    attributes = ["music_processor", "text_processor"]

    music_processor: "AudioSpectrogramTransformerFeatureExtractor"
    text_processor: "GLuCoSETokenizer"

    music_processor_class = "AudioSpectrogramTransformerFeatureExtractor"
    text_processor_class = "GLuCoSETokenizer"

    def __init__(
        self,
        music_processor: "AudioSpectrogramTransformerFeatureExtractor",
        text_processor: "GLuCoSETokenizer",
        **kwargs,
    ) -> None:
        super().__init__(music_processor, text_processor, **kwargs)

    def __call__(
        self, audio: torch.Tensor, text: List[str], sample_rate: Optional[int] = None
    ) -> BatchFeature:
        music_features = self.get_music_feature(audio, sample_rate=sample_rate)
        text_features = self.get_text_feature(text)

        return BatchFeature(
            {
                "music": music_features,
                "text": text_features,
            }
        )

    def get_music_feature(
        self, audio: torch.Tensor, sample_rate: Optional[int] = None
    ) -> BatchFeature:
        """Get music feature from audio.

        Args:
            audio (torch.Tensor): Audio waveform of shape (batch_size, timesteps).
            sample_rate (int, optional): Sampling rate of audio.

        Returns:
            BatchFeature: Batched music feature.

        """
        spectrogram = self.music_processor(audio, sample_rate=sample_rate)
        music_features = BatchFeature({"spectrogram": spectrogram})

        return music_features

    def get_text_feature(self, text: List[str]) -> BatchEncoding:
        """Get text feature from text.

        Args:
            text (list): Text to be tokenized.

        Returns:
            BatchEncoding: Batch encoding of text feature.

        """
        text_features = self.text_processor(text)

        return text_features

    def train(self, mode: bool = True) -> None:
        """Set training mode."""
        if hasattr(self.music_processor, "train") and callable(
            self.music_processor.train
        ):
            self.music_processor.train(mode=mode)

        if hasattr(self.text_processor, "train") and callable(
            self.text_processor.train
        ):
            self.text_processor.train(mode=mode)

    def eval(self) -> None:
        """Set evaluation mode."""
        if hasattr(self.music_processor, "eval") and callable(
            self.music_processor.eval
        ):
            self.music_processor.eval()

        if hasattr(self.text_processor, "eval") and callable(self.text_processor.eval):
            self.text_processor.eval()


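# Example: composing MuLanProcessor (a minimal sketch; the hyperparameters
# below are illustrative assumptions, not values from a released checkpoint):
#
#   music_processor = AudioSpectrogramTransformerFeatureExtractor(
#       sample_rate=16000, duration=10.0, n_mels=128, n_frames=1024
#   )
#   text_processor = GLuCoSETokenizer(model_name_or_path="pkshatech/GLuCoSE-base-ja")
#   processor = MuLanProcessor(music_processor, text_processor)
#   waveform = torch.randn(2, 16000 * 10)  # (batch_size, timesteps)
#   batch = processor(waveform, ["piano jazz", "upbeat pop"], sample_rate=16000)
#   # batch["music"]["spectrogram"]: (batch_size, n_mels, n_frames)
#   # batch["text"]: BatchEncoding with input_ids and attention_mask
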
class AudioSpectrogramTransformerFeatureExtractor(FeatureExtractionMixin):
    """Audio processor for official implementation of audio spectrogram transformer.

    Args:
        sample_rate (int): Sampling rate.
        duration (float): Duration of audio in seconds.
        n_mels (int): Number of Mel-frequency bins.
        n_frames (int): Number of time frames.
        fbank_kwargs (dict): Keyword arguments given to ``torchaudio.compliance.kaldi.fbank``.
        spec_norm (bool): Whether to apply normalization to spectrogram. If ``True``,
            input spectrogram is normalized by ``spec_mean`` and ``spec_std``.
        spec_mean (float): Mean of spectrogram.
        spec_std (float): Standard deviation of spectrogram.
        freq_mask_param (int): Parameter given to ``torchaudio.transforms.FrequencyMasking``.
        time_mask_param (int): Parameter given to ``torchaudio.transforms.TimeMasking``.

    .. note::

        Frequency and time maskings are deactivated when ``self.training = False``.

    """

    def __init__(
        self,
        sample_rate: Optional[int],
        duration: Optional[float] = None,
        n_mels: Optional[int] = None,
        n_frames: Optional[int] = None,
        fbank_kwargs: Optional[Dict[str, Any]] = None,
        spec_norm: bool = False,
        spec_mean: float = 0,
        spec_std: float = 1,
        freq_mask_param: Optional[int] = None,
        time_mask_param: Optional[int] = None,
        **kwargs,
    ) -> None:
        if fbank_kwargs is None:
            fbank_kwargs = {}

        if freq_mask_param is None or freq_mask_param == 0:
            self.freq_mask = None
        else:
            self.freq_mask = aT.FrequencyMasking(freq_mask_param)

        if time_mask_param is None or time_mask_param == 0:
            self.time_mask = None
        else:
            self.time_mask = aT.TimeMasking(time_mask_param)

        # Keep sample_rate and fbank_kwargs["sample_frequency"] consistent.
        if sample_rate is None:
            sample_rate = fbank_kwargs.get("sample_frequency")
        else:
            if fbank_kwargs.get("sample_frequency") is not None:
                assert sample_rate == fbank_kwargs["sample_frequency"], (
                    "sample_rate should be equal to sample_frequency in fbank_kwargs."
                )
            else:
                fbank_kwargs["sample_frequency"] = sample_rate

        # Keep n_mels and fbank_kwargs["num_mel_bins"] consistent.
        if n_mels is None:
            n_mels = fbank_kwargs.get("num_mel_bins")
        else:
            if fbank_kwargs.get("num_mel_bins") is not None:
                assert n_mels == fbank_kwargs["num_mel_bins"], (
                    "n_mels should be equal to num_mel_bins in fbank_kwargs."
                )
            else:
                fbank_kwargs["num_mel_bins"] = n_mels

        super().__init__(
            sample_rate=sample_rate,
            duration=duration,
            n_mels=n_mels,
            n_frames=n_frames,
            fbank_kwargs=fbank_kwargs,
            spec_norm=spec_norm,
            spec_mean=spec_mean,
            spec_std=spec_std,
            **kwargs,
        )

        self.training = True

    def __call__(
        self, waveform: torch.Tensor, sample_rate: Optional[int] = None
    ) -> torch.Tensor:
        """Forward pass of AudioSpectrogramTransformerFeatureExtractor.

        Args:
            waveform (torch.Tensor): Waveform of shape (batch_size, timesteps).
            sample_rate (int, optional): Sampling rate of waveform. If it differs
                from ``self.sample_rate``, the waveform is resampled.

        Returns:
            torch.Tensor: Spectrogram of shape (batch_size, n_bins, n_frames).

        """
        required_sample_rate = self.sample_rate
        required_duration = self.duration
        fbank_kwargs = self.fbank_kwargs

        assert waveform.dim() == 2, "waveform should be of shape (batch_size, timesteps)."

        if sample_rate is None:
            warnings.warn("It is recommended to set sample_rate.", UserWarning)

            sample_rate = required_sample_rate
        else:
            if sample_rate != required_sample_rate:
                waveform = aF.resample(waveform, sample_rate, required_sample_rate)
                sample_rate = required_sample_rate

        if self.training:
            # During training, pad short waveforms and take a random crop of
            # required_duration from long ones. duration is assumed to be set
            # when training.
            timesteps = waveform.size(-1)
            required_timesteps = int(sample_rate * required_duration)

            if timesteps <= required_timesteps:
                waveform = F.pad(waveform, (0, required_timesteps - timesteps))
            else:
                start_idx = int(torch.randint(0, timesteps - required_timesteps, ()))
                sections = [
                    start_idx,
                    required_timesteps,
                    timesteps - required_timesteps - start_idx,
                ]
                _, waveform, _ = torch.split(waveform, sections, dim=-1)
        elif required_duration is not None:
            # During evaluation, zero-pad the tail so the length is a multiple
            # of required_timesteps, then split the waveform into fixed-length
            # segments, each treated as a batch entry.
            assert waveform.size(0) == 1, (
                "Only batch size = 1 is supported during evaluation if required_duration is given."
            )

            timesteps = waveform.size(-1)
            required_timesteps = int(sample_rate * required_duration)
            padding = required_timesteps - timesteps % required_timesteps
            padding = padding % required_timesteps
            valid_length = timesteps + padding
            num_segments = valid_length // required_timesteps

            waveform = F.pad(waveform, (0, padding))
            waveform = waveform.view(num_segments, required_timesteps)

        # Remove the DC offset per sample before computing filterbank features.
        waveform = waveform - waveform.mean(dim=-1, keepdim=True)

        if version.parse(torch.__version__) < version.parse("2.0.0"):
            spectrogram = self._sequential_fbank(waveform, **fbank_kwargs)
        else:
            spectrogram = self._parallel_fbank(waveform, **fbank_kwargs)

        if self.n_frames is not None:
            # A negative padding here truncates the spectrogram to n_frames.
            padding = self.n_frames - spectrogram.size(-1)
            spectrogram = F.pad(spectrogram, (0, padding))

        if self.freq_mask is not None and self.training:
            spectrogram = self.freq_mask(spectrogram)

        if self.time_mask is not None and self.training:
            spectrogram = self.time_mask(spectrogram)

        if self.spec_norm:
            # The factor of 2 on the standard deviation follows the official
            # audio spectrogram transformer implementation.
            spectrogram = (spectrogram - self.spec_mean) / (2 * self.spec_std)

        return spectrogram

    def train(self, mode: bool = True) -> None:
        """Set training mode."""
        self.training = mode

        if hasattr(self.freq_mask, "train") and callable(self.freq_mask.train):
            self.freq_mask.train(mode=mode)

        if hasattr(self.time_mask, "train") and callable(self.time_mask.train):
            self.time_mask.train(mode=mode)

    def eval(self) -> None:
        """Set evaluation mode."""
        self.training = False

        if hasattr(self.freq_mask, "eval") and callable(self.freq_mask.eval):
            self.freq_mask.eval()

        if hasattr(self.time_mask, "eval") and callable(self.time_mask.eval):
            self.time_mask.eval()

    def to_dict(self) -> Dict[str, Any]:
        serialized = super().to_dict()

        freq_mask: aT.FrequencyMasking = serialized.pop("freq_mask")
        time_mask: aT.TimeMasking = serialized.pop("time_mask")

        # The transform modules are not JSON-serializable, so only their mask
        # parameters are stored.
        if freq_mask is None:
            serialized["freq_mask"] = None
        else:
            serialized["freq_mask"] = {
                "mask_param": freq_mask.mask_param,
            }

        if time_mask is None:
            serialized["time_mask"] = None
        else:
            serialized["time_mask"] = {
                "mask_param": time_mask.mask_param,
            }

        return serialized

    @classmethod
    def from_dict(
        cls, feature_extractor_dict: Dict[str, Any], **kwargs
    ) -> "AudioSpectrogramTransformerFeatureExtractor":
        freq_mask_kwargs = feature_extractor_dict.pop("freq_mask")
        time_mask_kwargs = feature_extractor_dict.pop("time_mask")

        if freq_mask_kwargs is None:
            freq_mask_param = None
        else:
            freq_mask_param = freq_mask_kwargs["mask_param"]

        if time_mask_kwargs is None:
            time_mask_param = None
        else:
            time_mask_param = time_mask_kwargs["mask_param"]

        feature_extractor_dict["freq_mask_param"] = freq_mask_param
        feature_extractor_dict["time_mask_param"] = time_mask_param

        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        # Override serialized values with matching keyword arguments and drop
        # the consumed keys from kwargs.
        to_remove = []
        for key, value in kwargs.items():
            if key in feature_extractor_dict:
                feature_extractor_dict[key] = value
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        feature_extractor = cls(**feature_extractor_dict)

        logger.info(f"Feature extractor {feature_extractor}")

        if return_unused_kwargs:
            return feature_extractor, kwargs
        else:
            return feature_extractor

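    # Example: serialization round-trip (a sketch; ``save_pretrained`` and
    # ``from_pretrained`` from FeatureExtractionMixin route through the
    # to_dict and from_dict overrides above):
    #
    #   config = extractor.to_dict()
    #   # masks are stored as {"mask_param": ...} entries or None
    #   restored = AudioSpectrogramTransformerFeatureExtractor.from_dict(config)
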
    @staticmethod
    def _sequential_fbank(waveform: torch.Tensor, **kwargs) -> torch.Tensor:
        # Compute filterbank features sample by sample and stack the results.
        spectrogram = []

        for _waveform in waveform:
            _spectrogram = _fbank_fn(_waveform, **kwargs)
            spectrogram.append(_spectrogram)

        spectrogram = torch.stack(spectrogram, dim=0)

        return spectrogram

    @staticmethod
    def _parallel_fbank(waveform: torch.Tensor, **kwargs) -> torch.Tensor:
        # Vectorize the per-sample filterbank over the batch dimension.
        vfbank_fn = torch.vmap(functools.partial(_fbank_fn, **kwargs))
        spectrogram = vfbank_fn(waveform)

        return spectrogram


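# Example: AudioSpectrogramTransformerFeatureExtractor on its own (a minimal
# sketch; the parameter values are illustrative assumptions):
#
#   extractor = AudioSpectrogramTransformerFeatureExtractor(
#       sample_rate=16000,
#       duration=10.0,
#       n_mels=128,
#       n_frames=1024,
#       freq_mask_param=48,
#       time_mask_param=192,
#   )
#   extractor.eval()  # disables random cropping and the SpecAugment-style masks
#   spectrogram = extractor(torch.randn(1, 16000 * 25), sample_rate=16000)
#   # In eval mode the 25 s input is zero-padded to a multiple of 10 s and
#   # split into segments, so spectrogram has shape (3, n_mels, n_frames).
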
class GLuCoSETokenizer(PreTrainedTokenizer):
    """Processor to tokenize text for MuLan.

    Args:
        model_name_or_path (str): Name of tokenizer.
        padding (bool, optional): If ``True``, padding is applied to tokenization output.
        truncation (bool, optional): If ``True``, truncation is applied to tokenization output.
        tokenizer_kwargs (dict, optional): Keyword arguments given to tokenizer.

    """

    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        padding: Optional[bool] = None,
        truncation: Optional[bool] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        if model_name_or_path == "pkshatech/GLuCoSE-base-ja":
            # Padding and truncation are enabled by default for this model.
            if padding is None:
                padding = True

            if truncation is None:
                truncation = True
        else:
            raise NotImplementedError(
                f"{model_name_or_path} is not supported as model_name_or_path."
            )

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, **tokenizer_kwargs
        )

        super().__init__(
            model_name_or_path=model_name_or_path,
            padding=padding,
            truncation=truncation,
            **kwargs,
        )

        self.padding = padding
        self.truncation = truncation

        self.model_max_length = self.tokenizer.model_max_length

    def __call__(self, text: List[str], return_tensors: str = "pt") -> BatchEncoding:
        padding = self.padding
        truncation = self.truncation

        output = self.tokenizer(
            text, padding=padding, truncation=truncation, return_tensors=return_tensors
        )

        return output

    @property
    def vocab_size(self) -> int:
        return self.tokenizer.vocab_size

    def get_vocab(self) -> Dict[str, int]:
        return self.tokenizer.get_vocab()

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        return self.tokenizer.save_vocabulary(
            save_directory, filename_prefix=filename_prefix
        )

    def _tokenize(self, text, **kwargs) -> List[str]:
        return self.tokenizer._tokenize(text, **kwargs)

    def _convert_token_to_id(self, token: str) -> int:
        return self.tokenizer._convert_token_to_id(token)

    def _convert_id_to_token(self, index: int) -> str:
        return self.tokenizer._convert_id_to_token(index)


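# Example: GLuCoSETokenizer (a short sketch; only "pkshatech/GLuCoSE-base-ja"
# is accepted, for which padding and truncation default to True):
#
#   tokenizer = GLuCoSETokenizer(model_name_or_path="pkshatech/GLuCoSE-base-ja")
#   encoding = tokenizer(["piano jazz", "upbeat pop"])
#   # encoding["input_ids"]: padded LongTensor of shape (batch_size, max_seq_len)
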
def _fbank_fn(waveform: torch.Tensor, **kwargs) -> torch.Tensor:
    """Wrapper function of torchaudio.compliance.kaldi.fbank.

    Args:
        waveform (torch.Tensor): Waveform of shape (time,).
        kwargs: Keyword arguments given to torchaudio.compliance.kaldi.fbank.

    Returns:
        torch.Tensor: Spectrogram of shape (n_mels, n_frames).

    """
    # kaldi.fbank expects a channel-first input of shape (1, time).
    waveform = waveform.unsqueeze(dim=0)

    spectrogram = aCK.fbank(waveform, **kwargs)

    # kaldi.fbank returns (n_frames, n_mels); transpose to (n_mels, n_frames).
    spectrogram = spectrogram.transpose(1, 0)

    return spectrogram


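# Example: shape contract of _fbank_fn. With the kaldi defaults (25 ms window,
# 10 ms shift, snip_edges=True), 1 s of 16 kHz audio yields 98 frames:
#
#   spec = _fbank_fn(torch.randn(16000), sample_frequency=16000, num_mel_bins=128)
#   # spec: (n_mels, n_frames) == (128, 98)
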
AutoFeatureExtractor.register(
    "AudioSpectrogramTransformerFeatureExtractor",
    AudioSpectrogramTransformerFeatureExtractor,
)
AutoTokenizer.register("GLuCoSETokenizer", GLuCoSETokenizer)
AutoProcessor.register("MuLanProcessor", MuLanProcessor)