Spaces:
Runtime error
Runtime error
| import math | |
| from typing import Callable | |
| import torch | |
| import torch.nn as nn | |
| from torchlibrosa.stft import STFT | |
| from bytesep.models.pytorch_modules import Base | |
def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Mean absolute error between two tensors.

    Args:
        output: torch.Tensor, estimated values
        target: torch.Tensor, reference values
        **kwargs: accepted so every loss shares one call signature; ignored

    Returns:
        loss: torch.float, mean of the element-wise absolute difference
    """
    residual = output - target
    return torch.abs(residual).mean()
def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Time-domain L1 loss; a thin alias that delegates to ``l1``.

    Args:
        output: torch.Tensor, estimated waveform
        target: torch.Tensor, reference waveform
        **kwargs: accepted for signature compatibility; not forwarded

    Returns:
        loss: torch.float
    """
    time_domain_loss = l1(output, target)
    return time_domain_loss
class L1_Wav_L1_Sp(nn.Module, Base):
    r"""L1 loss in the time-domain plus L1 loss on the spectrogram.

    The spectrogram is computed with ``self.wav_to_spectrogram``, which is not
    defined in this class (presumably inherited from ``Base`` and driven by
    ``self.stft`` -- confirm against ``Base``).
    """

    def __init__(self):
        r"""L1 loss in the time-domain and L1 loss on the spectrogram."""
        super(L1_Wav_L1_Sp, self).__init__()

        # STFT analysis settings used for the spectrogram term.
        self.window_size = 2048
        hop_size = 441  # NOTE(review): looks like 10 ms at 44.1 kHz -- confirm
        center = True
        pad_mode = "reflect"
        window = "hann"

        # freeze_parameters=True is passed so the transform is not meant to
        # be trained along with the model.
        self.stft = STFT(
            n_fft=self.window_size,
            hop_length=hop_size,
            win_length=self.window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

    def forward(
        self, output: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        r"""L1 loss in the time-domain and on the spectrogram.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        that ``nn.Module.__call__`` dispatches here and registered hooks keep
        working; ``loss_fn(output, target)`` is still the way to call it.

        Args:
            output: torch.Tensor, estimated waveform
            target: torch.Tensor, reference waveform
            **kwargs: accepted for signature compatibility; ignored

        Returns:
            loss: torch.float, sum of the waveform and spectrogram L1 losses
        """
        # L1 loss in the time-domain.
        wav_loss = l1_wav(output, target)

        # L1 loss on the spectrogram. The same eps is passed for both inputs.
        sp_loss = l1(
            self.wav_to_spectrogram(output, eps=1e-8),
            self.wav_to_spectrogram(target, eps=1e-8),
        )

        # Total loss: both terms are summed with equal weight; no rescaling
        # of the spectrogram term is applied.
        return wav_loss + sp_loss
def get_loss_function(loss_type: str) -> Callable:
    r"""Get loss function.

    Args:
        loss_type: str, one of "l1_wav" or "l1_wav_l1_sp"

    Returns:
        loss function: Callable, either the ``l1_wav`` function or a newly
            constructed ``L1_Wav_L1_Sp`` instance

    Raises:
        NotImplementedError: if ``loss_type`` is not a supported name
    """
    if loss_type == "l1_wav":
        return l1_wav

    if loss_type == "l1_wav_l1_sp":
        # A fresh module instance is built on every call.
        return L1_Wav_L1_Sp()

    # Name the offending value so a misconfigured loss is easy to diagnose.
    raise NotImplementedError(
        "Unsupported loss_type: {!r}. Expected 'l1_wav' or "
        "'l1_wav_l1_sp'.".format(loss_type)
    )