jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Code related main NISQA model definition are under the following copyright
# Copyright (c) 2021 Gabriel Mittag, Quality and Usability Lab
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import math
import os
import warnings
from functools import lru_cache
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import adaptive_max_pool2d, relu, softmax
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torchmetrics.utilities import rank_zero_info
from torchmetrics.utilities.imports import _LIBROSA_AVAILABLE, _REQUESTS_AVAILABLE
if _LIBROSA_AVAILABLE and _REQUESTS_AVAILABLE:
import librosa
import requests
else:
librosa, requests = None, None # type:ignore
__doctest_requires__ = {("non_intrusive_speech_quality_assessment",): ["librosa", "requests"]}
NISQA_DIR = "~/.torchmetrics/NISQA"
def non_intrusive_speech_quality_assessment(preds: Tensor, fs: int) -> Tensor:
"""`Non-Intrusive Speech Quality Assessment`_ (NISQA v2.0) [1], [2].
.. hint::
Usingsing this metric requires you to have ``librosa`` and ``requests`` installed. Install as
``pip install librosa requests``.
Args:
preds: float tensor with shape ``(...,time)``
fs: sampling frequency of input
Returns:
Float tensor with shape ``(...,5)`` corresponding to overall MOS, noisiness, discontinuity, coloration and
loudness in that order
Raises:
ModuleNotFoundError:
If ``librosa`` or ``requests`` are not installed
RuntimeError:
If the input is too short, causing the number of mel spectrogram windows to be zero
RuntimeError:
If the input is too long, causing the number of mel spectrogram windows to exceed the maximum allowed
Example:
>>> import torch
>>> from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(16000)
>>> non_intrusive_speech_quality_assessment(preds, 16000)
tensor([1.0433, 1.9545, 2.6087, 1.3460, 1.7117])
References:
- [1] G. Mittag and S. MΓΆller, "Non-intrusive speech quality assessment for super-wideband speech communication
networks", in Proc. ICASSP, 2019.
- [2] G. Mittag, B. Naderi, A. Chehadi and S. MΓΆller, "NISQA: A deep CNN-self-attention model for
multidimensional speech quality prediction with crowdsourced datasets", in Proc. INTERSPEECH, 2021.
"""
if not _LIBROSA_AVAILABLE or not _REQUESTS_AVAILABLE:
raise ModuleNotFoundError(
"NISQA metric requires that librosa and requests are installed. Install as `pip install librosa requests`."
)
model, args = _load_nisqa_model()
if not isinstance(fs, int) or fs <= 0:
raise ValueError(f"Argument `fs` expected to be a positive integer, but got {fs}")
model.eval()
x = preds.reshape(-1, preds.shape[-1])
x = _get_librosa_melspec(x.cpu().numpy(), fs, args)
x, n_wins = _segment_specs(torch.from_numpy(x), args)
with torch.no_grad():
x = model(x, n_wins.expand(x.shape[0]))
# ["mos_pred", "noi_pred", "dis_pred", "col_pred", "loud_pred"]
# the dimensions are always listed in the papers as MOS, noisiness, coloration, discontinuity and loudness
# but based on original code the actual model output order is MOS, noisiness, discontinuity, coloration, loudness
return x.reshape(preds.shape[:-1] + (5,))
@lru_cache
def _load_nisqa_model() -> tuple[nn.Module, dict[str, Any]]:
"""Load NISQA model and its parameters.
Returns:
Tuple ``(model,args)`` where ``model`` is the NISQA model and ``args`` is a dictionary with all its parameters
"""
model_path = os.path.expanduser(os.path.join(NISQA_DIR, "nisqa.tar"))
if not os.path.exists(model_path):
_download_weights()
checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
args = checkpoint["args"]
model = _NISQADIM(args)
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
return model, args
def _download_weights() -> None:
"""Download NISQA model weights."""
url = "https://github.com/gabrielmittag/NISQA/raw/refs/heads/master/weights/nisqa.tar"
nisqa_dir = os.path.expanduser(NISQA_DIR)
os.makedirs(nisqa_dir, exist_ok=True)
saveto = os.path.join(nisqa_dir, "nisqa.tar")
if os.path.exists(saveto):
return
rank_zero_info(f"downloading {url} to {saveto}")
myfile = requests.get(url)
with open(saveto, "wb") as f:
f.write(myfile.content)
class _NISQADIM(nn.Module):
# main NISQA model definition
# ported from https://github.com/gabrielmittag/NISQA
# Copyright (c) 2021 Gabriel Mittag, Quality and Usability Lab
# MIT License
def __init__(self, args: dict[str, Any]) -> None:
super().__init__()
self.cnn = _Framewise(args)
self.time_dependency = _TimeDependency(args)
pool = _Pooling(args)
self.pool_layers = _get_clones(pool, 5)
def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
x = self.cnn(x, n_wins)
x, n_wins = self.time_dependency(x, n_wins)
out = [mod(x, n_wins) for mod in self.pool_layers]
return torch.cat(out, dim=1)
class _Framewise(nn.Module):
# part of NISQA model definition
def __init__(self, args: dict[str, Any]) -> None:
super().__init__()
self.model = _AdaptCNN(args)
def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
x_packed = pack_padded_sequence(x, n_wins, batch_first=True, enforce_sorted=False)
x = self.model(x_packed.data.unsqueeze(1))
x = x_packed._replace(data=x)
x, _ = pad_packed_sequence(x, batch_first=True, padding_value=0.0, total_length=int(n_wins.max()))
return x
class _AdaptCNN(nn.Module):
# part of NISQA model definition
def __init__(self, args: dict[str, Any]) -> None:
super().__init__()
self.pool_1 = args["cnn_pool_1"]
self.pool_2 = args["cnn_pool_2"]
self.pool_3 = args["cnn_pool_3"]
self.dropout = nn.Dropout2d(p=args["cnn_dropout"])
cnn_pad = (1, 0) if args["cnn_kernel_size"][0] == 1 else (1, 1)
self.conv1 = nn.Conv2d(1, args["cnn_c_out_1"], args["cnn_kernel_size"], padding=cnn_pad)
self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
self.conv2 = nn.Conv2d(self.conv1.out_channels, args["cnn_c_out_2"], args["cnn_kernel_size"], padding=cnn_pad)
self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
self.conv3 = nn.Conv2d(self.conv2.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad)
self.bn3 = nn.BatchNorm2d(self.conv3.out_channels)
self.conv4 = nn.Conv2d(self.conv3.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad)
self.bn4 = nn.BatchNorm2d(self.conv4.out_channels)
self.conv5 = nn.Conv2d(self.conv4.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad)
self.bn5 = nn.BatchNorm2d(self.conv5.out_channels)
self.conv6 = nn.Conv2d(
self.conv5.out_channels,
args["cnn_c_out_3"],
(args["cnn_kernel_size"][0], args["cnn_pool_3"][1]),
padding=(1, 0),
)
self.bn6 = nn.BatchNorm2d(self.conv6.out_channels)
def forward(self, x: Tensor) -> Tensor:
x = relu(self.bn1(self.conv1(x)))
x = adaptive_max_pool2d(x, output_size=(self.pool_1))
x = relu(self.bn2(self.conv2(x)))
x = adaptive_max_pool2d(x, output_size=(self.pool_2))
x = self.dropout(x)
x = relu(self.bn3(self.conv3(x)))
x = self.dropout(x)
x = relu(self.bn4(self.conv4(x)))
x = adaptive_max_pool2d(x, output_size=(self.pool_3))
x = self.dropout(x)
x = relu(self.bn5(self.conv5(x)))
x = self.dropout(x)
x = relu(self.bn6(self.conv6(x)))
return x.view(-1, self.conv6.out_channels * self.pool_3[0])
class _TimeDependency(nn.Module):
# part of NISQA model definition
def __init__(self, args: dict[str, Any]) -> None:
super().__init__()
self.model = _SelfAttention(args)
def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
return self.model(x, n_wins)
class _SelfAttention(nn.Module):
# part of NISQA model definition
def __init__(self, args: dict[str, Any]) -> None:
super().__init__()
encoder_layer = _SelfAttentionLayer(args)
self.norm1 = nn.LayerNorm(args["td_sa_d_model"])
self.linear = nn.Linear(args["cnn_c_out_3"] * args["cnn_pool_3"][0], args["td_sa_d_model"])
self.layers = _get_clones(encoder_layer, args["td_sa_num_layers"])
self._reset_parameters()
def _reset_parameters(self) -> None:
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src: Tensor, n_wins: Tensor) -> tuple[Tensor, Tensor]:
src = self.linear(src)
output = src.transpose(1, 0)
output = self.norm1(output)
for mod in self.layers:
output, n_wins = mod(output, n_wins)
return output.transpose(1, 0), n_wins
class _SelfAttentionLayer(nn.Module):
# part of NISQA model definition
def __init__(self, args: dict[str, Any]) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(args["td_sa_d_model"], args["td_sa_nhead"], args["td_sa_dropout"])
self.linear1 = nn.Linear(args["td_sa_d_model"], args["td_sa_h"])
self.dropout = nn.Dropout(args["td_sa_dropout"])
self.linear2 = nn.Linear(args["td_sa_h"], args["td_sa_d_model"])
self.norm1 = nn.LayerNorm(args["td_sa_d_model"])
self.norm2 = nn.LayerNorm(args["td_sa_d_model"])
self.dropout1 = nn.Dropout(args["td_sa_dropout"])
self.dropout2 = nn.Dropout(args["td_sa_dropout"])
self.activation = relu
def forward(self, src: Tensor, n_wins: Tensor) -> tuple[Tensor, Tensor]:
mask = torch.arange(src.shape[0])[None, :] < n_wins[:, None]
src2 = self.self_attn(src, src, src, key_padding_mask=~mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src, n_wins
class _Pooling(nn.Module):
# part of NISQA model definition
def __init__(self, args: dict[str, Any]) -> None:
super().__init__()
self.model = _PoolAttFF(args)
def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
return self.model(x, n_wins)
class _PoolAttFF(torch.nn.Module):
# part of NISQA model definition
def __init__(self, args: dict[str, Any]) -> None:
super().__init__()
self.linear1 = nn.Linear(args["td_sa_d_model"], args["pool_att_h"])
self.linear2 = nn.Linear(args["pool_att_h"], 1)
self.linear3 = nn.Linear(args["td_sa_d_model"], 1)
self.activation = relu
self.dropout = nn.Dropout(args["pool_att_dropout"])
def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
att = self.linear2(self.dropout(self.activation(self.linear1(x))))
att = att.transpose(2, 1)
mask = torch.arange(att.shape[2])[None, :] < n_wins[:, None]
att[~mask.unsqueeze(1)] = float("-inf")
att = softmax(att, dim=2)
x = torch.bmm(att, x)
x = x.squeeze(1)
return self.linear3(x)
def _get_librosa_melspec(y: np.ndarray, sr: int, args: dict[str, Any]) -> np.ndarray:
"""Compute mel spectrogram from waveform using librosa.
Args:
y: waveform with shape ``(batch_size,time)``
sr: sampling rate
args: dictionary with all NISQA parameters
Returns:
Mel spectrogram with shape ``(batch_size,n_mels,n_frames)``
"""
hop_length = int(sr * args["ms_hop_length"])
win_length = int(sr * args["ms_win_length"])
with warnings.catch_warnings():
# ignore empty mel filter warning since this is expected when input signal is not fullband
# see https://github.com/gabrielmittag/NISQA/issues/6#issuecomment-838157571
warnings.filterwarnings("ignore", message="Empty filters detected in mel frequency basis")
melspec = librosa.feature.melspectrogram(
y=y,
sr=sr,
S=None,
n_fft=args["ms_n_fft"],
hop_length=hop_length,
win_length=win_length,
window="hann",
center=True,
pad_mode="reflect",
power=1.0,
n_mels=args["ms_n_mels"],
fmin=0.0,
fmax=args["ms_fmax"],
htk=False,
norm="slaney",
)
# batch processing of librosa.core.amplitude_to_db is not equivalent to individual processing due to top_db being
# relative to max value
# so process individually and then stack
return np.stack([librosa.amplitude_to_db(m, ref=1.0, amin=1e-4, top_db=80.0) for m in melspec])
def _segment_specs(x: Tensor, args: dict[str, Any]) -> tuple[Tensor, Tensor]:
"""Segment mel spectrogram into overlapping windows.
Args:
x: mel spectrogram with shape ``(batch_size,n_mels,n_frames)``
args: dictionary with all NISQA parameters
Returns:
Tuple ``(x_padded,n_wins)```, where ``x_padded`` is the segmented mel spectrogram with shape
``(batch_size,max_length,n_mels,seg_length)`` where the second dimension is the number of windows and was
padded to ``max_length``, and ``n_wins`` is the number of windows and is 0-dimensional
"""
seg_length = args["ms_seg_length"]
seg_hop = args["ms_seg_hop_length"]
max_length = args["ms_max_segments"]
n_wins = x.shape[2] - (seg_length - 1)
if n_wins < 1:
raise RuntimeError("Input signal is too short.")
idx1 = torch.arange(seg_length)
idx2 = torch.arange(n_wins)
idx3 = idx1.unsqueeze(0) + idx2.unsqueeze(1)
x = x.transpose(2, 1)[:, idx3, :].transpose(3, 2)
x = x[:, ::seg_hop]
n_wins = math.ceil(n_wins / seg_hop)
if max_length < n_wins:
raise RuntimeError("Maximum number of mel spectrogram windows exceeded. Use shorter audio.")
x_padded = torch.zeros((x.shape[0], max_length, x.shape[2], x.shape[3]))
x_padded[:, :n_wins] = x
return x_padded, torch.tensor(n_wins)
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
"""Create ``n`` copies of a module."""
return nn.ModuleList([copy.deepcopy(module) for i in range(n)])