# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: unless otherwise mentioned, the functions in this file are mainly translated from
# the SRMRpy package and adapted for batched processing with PyTorch
from functools import lru_cache
from math import ceil, pi
from typing import Optional
import torch
from torch import Tensor
from torch.nn.functional import pad
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import (
_GAMMATONE_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
)
if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
__doctest_skip__ = ["speech_reverberation_modulation_energy_ratio"]
@lru_cache(maxsize=100)
def _calc_erbs(low_freq: float, fs: int, n_filters: int, device: torch.device) -> Tensor:
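    """Compute the equivalent rectangular bandwidth (ERB) of each gammatone centre frequency (Glasberg & Moore)."""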
from gammatone.filters import centre_freqs
ear_q = 9.26449 # Glasberg and Moore Parameters
min_bw = 24.7
order = 1
erbs = ((centre_freqs(fs, n_filters, low_freq) / ear_q) ** order + min_bw**order) ** (1 / order)
return torch.tensor(erbs, device=device)
@lru_cache(maxsize=100)
def _make_erb_filters(fs: int, num_freqs: int, cutoff: float, device: torch.device) -> Tensor:
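    """Compute gammatone (ERB) filter coefficients via the ``gammatone`` package as a [num_freqs, 10] tensor."""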
from gammatone.filters import centre_freqs, make_erb_filters
cfs = centre_freqs(fs, num_freqs, cutoff)
fcoefs = make_erb_filters(fs, cfs)
return torch.tensor(fcoefs, device=device)
@lru_cache(maxsize=100)
def _compute_modulation_filterbank_and_cutoffs(
min_cf: float, max_cf: float, n: int, fs: float, q: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
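    """Build ``n`` 2nd-order bandpass modulation filters (quality factor ``q``) and their 3 dB cutoff frequencies."""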
    # this function is translated from the SRMRpy package
spacing_factor = (max_cf / min_cf) ** (1.0 / (n - 1))
cfs = torch.zeros(n, dtype=torch.float64)
cfs[0] = min_cf
for k in range(1, n):
cfs[k] = cfs[k - 1] * spacing_factor
def _make_modulation_filter(w0: Tensor, q: int) -> Tensor:
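        # 2nd-order digital bandpass centred at w0 with quality factor q (bilinear-transform style design,
        # frequency prewarped with tan(w0 / 2))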
w0 = torch.tan(w0 / 2)
b0 = w0 / q
b = torch.tensor([b0, 0, -b0], dtype=torch.float64)
a = torch.tensor([(1 + b0 + w0**2), (2 * w0**2 - 2), (1 - b0 + w0**2)], dtype=torch.float64)
return torch.stack([b, a], dim=0)
mfb = torch.stack([_make_modulation_filter(w0, q) for w0 in 2 * pi * cfs / fs], dim=0)
def _calc_cutoffs(cfs: Tensor, fs: float, q: int) -> tuple[Tensor, Tensor]:
        # Calculate the 3 dB cutoff frequencies of the 2nd-order bandpass filters
w0 = 2 * pi * cfs / fs
b0 = torch.tan(w0 / 2) / q
ll = cfs - (b0 * fs / (2 * pi))
rr = cfs + (b0 * fs / (2 * pi))
return ll, rr
cfs = cfs.to(device=device)
mfb = mfb.to(device=device)
ll, rr = _calc_cutoffs(cfs, fs, q)
return cfs, mfb, ll, rr
def _hilbert(x: Tensor, n: Optional[int] = None) -> Tensor:
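    """Compute the analytic signal of ``x`` along the last dim via the FFT, as in ``scipy.signal.hilbert``."""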
if x.is_complex():
raise ValueError("x must be real.")
if n is None:
n = x.shape[-1]
# Make N multiple of 16 to make sure the transform will be fast
if n % 16:
n = ceil(n / 16) * 16
if n <= 0:
raise ValueError("N must be positive.")
x_fft = torch.fft.fft(x, n=n, dim=-1)
h = torch.zeros(n, dtype=x.dtype, device=x.device, requires_grad=False)
if n % 2 == 0:
h[0] = h[n // 2] = 1
h[1 : n // 2] = 2
else:
h[0] = 1
h[1 : (n + 1) // 2] = 2
y = torch.fft.ifft(x_fft * h, dim=-1)
return y[..., : x.shape[-1]]
def _erb_filterbank(wave: Tensor, coefs: Tensor) -> Tensor:
"""Translated from gammatone package.
Args:
wave: shape [B, time]
coefs: shape [N, 10]
Returns:
Tensor: shape [B, N, time]
"""
from torchaudio.functional.filtering import lfilter
num_batch, time = wave.shape
    wave = wave.to(dtype=coefs.dtype).reshape(num_batch, 1, time) # [B, 1, time]
wave = wave.expand(-1, coefs.shape[0], -1) # [B, N, time]
gain = coefs[:, 9]
as1 = coefs[:, (0, 1, 5)] # A0, A11, A2
as2 = coefs[:, (0, 2, 5)] # A0, A12, A2
as3 = coefs[:, (0, 3, 5)] # A0, A13, A2
as4 = coefs[:, (0, 4, 5)] # A0, A14, A2
bs = coefs[:, 6:9] # B0, B1, B2
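    # apply the gammatone filter as a cascade of four 2nd-order IIR sections, then divide by the overall gain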
y1 = lfilter(wave, bs, as1, batching=True)
y2 = lfilter(y1, bs, as2, batching=True)
y3 = lfilter(y2, bs, as3, batching=True)
y4 = lfilter(y3, bs, as4, batching=True)
return y4 / gain.reshape(1, -1, 1)
def _normalize_energy(energy: Tensor, drange: float = 30.0) -> Tensor:
"""Normalize energy to a dynamic range of 30 dB.
Args:
energy: shape [B, N_filters, 8, n_frames]
drange: dynamic range in dB
"""
peak_energy = torch.mean(energy, dim=1, keepdim=True).max(dim=2, keepdim=True).values
peak_energy = peak_energy.max(dim=3, keepdim=True).values
min_energy = peak_energy * 10.0 ** (-drange / 10.0)
energy = torch.where(energy < min_energy, min_energy, energy)
return torch.where(energy > peak_energy, peak_energy, energy)
def _cal_srmr_score(bw: Tensor, avg_energy: Tensor, cutoffs: Tensor) -> Tensor:
"""Calculate srmr score."""
if (cutoffs[4] <= bw) and (cutoffs[5] > bw):
kstar = 5
elif (cutoffs[5] <= bw) and (cutoffs[6] > bw):
kstar = 6
elif (cutoffs[6] <= bw) and (cutoffs[7] > bw):
kstar = 7
elif cutoffs[7] <= bw:
kstar = 8
else:
raise ValueError("Something wrong with the cutoffs compared to bw values.")
return torch.sum(avg_energy[:, :4]) / torch.sum(avg_energy[:, 4:kstar])
def speech_reverberation_modulation_energy_ratio(
preds: Tensor,
fs: int,
n_cochlear_filters: int = 23,
low_freq: float = 125,
min_cf: float = 4,
max_cf: Optional[float] = None,
norm: bool = False,
fast: bool = False,
) -> Tensor:
"""Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).
SRMR is a non-intrusive metric for speech quality and intelligibility based on
a modulation spectral representation of the speech signal.
This code is translated from `SRMRToolbox`_ and `SRMRpy`_.
Args:
preds: shape ``(..., time)``
fs: the sampling rate
n_cochlear_filters: Number of filters in the acoustic filterbank
low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
min_cf: Center frequency in Hz of the first modulation filter.
max_cf: Center frequency in Hz of the last modulation filter. If None is given,
            then 30 Hz will be used for `norm==True`, otherwise 128 Hz will be used.
norm: Use modulation spectrum energy normalization
fast: Use the faster version based on the gammatonegram.
            Note: this argument is inherited from `SRMRpy`_. As the translated code is based on pytorch,
            setting `fast=True` may slow down the computation of this metric on GPU.
.. hint::
        Using this metric requires you to have ``gammatone`` and ``torchaudio`` installed.
Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
and ``pip install git+https://github.com/detly/gammatone``.
.. attention::
        This implementation is experimental, and might not be consistent with the MATLAB
        implementation `SRMRToolbox`_, especially the fast implementation.
        The slow versions, a) fast=False, norm=False, max_cf=128 and b) fast=False, norm=True, max_cf=30, show
        only a relatively small inconsistency.
Returns:
        Tensor with the SRMR value(s), with shape ``(...)``
Raises:
ModuleNotFoundError:
If ``gammatone`` or ``torchaudio`` package is not installed
Example:
>>> from torch import randn
>>> from torchmetrics.functional.audio import speech_reverberation_modulation_energy_ratio
>>> preds = randn(8000)
>>> speech_reverberation_modulation_energy_ratio(preds, 8000)
tensor([0.3191], dtype=torch.float64)
"""
if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
raise ModuleNotFoundError(
"speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
" `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
"``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
)
from gammatone.fftweight import fft_gtgram
from torchaudio.functional.filtering import lfilter
_srmr_arg_validate(
fs=fs,
n_cochlear_filters=n_cochlear_filters,
low_freq=low_freq,
min_cf=min_cf,
max_cf=max_cf,
norm=norm,
fast=fast,
)
shape = preds.shape
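    # flatten all leading dimensions into a single batch dimension; the original shape is restored before returning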
preds = preds.reshape(1, -1) if len(shape) == 1 else preds.reshape(-1, shape[-1])
num_batch, time = preds.shape
# convert int type to float
if not torch.is_floating_point(preds):
        preds = preds.to(torch.float64) / torch.iinfo(preds.dtype).max
# norm values in preds to [-1, 1], as lfilter requires an input in this range
max_vals = preds.abs().max(dim=-1, keepdim=True).values
val_norm = torch.where(
max_vals > 1,
max_vals,
torch.tensor(1.0, dtype=max_vals.dtype, device=max_vals.device),
)
preds = preds / val_norm
w_length_s = 0.256
w_inc_s = 0.064
# Computing gammatone envelopes
if fast:
rank_zero_warn("`fast=True` may slow down the speed of SRMR metric on GPU.")
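        # the gammatonegram is computed on CPU with numpy, one batch element at a time; its 2.5 ms hop matches
        # the 400 Hz modulation sampling rate set below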
mfs = 400.0
temp = []
preds_np = preds.detach().cpu().numpy()
for b in range(num_batch):
gt_env_b = fft_gtgram(preds_np[b], fs, 0.010, 0.0025, n_cochlear_filters, low_freq)
temp.append(torch.tensor(gt_env_b))
gt_env = torch.stack(temp, dim=0).to(device=preds.device)
else:
fcoefs = _make_erb_filters(fs, n_cochlear_filters, low_freq, device=preds.device) # [N_filters, 10]
gt_env = torch.abs(_hilbert(_erb_filterbank(preds, fcoefs))) # [B, N_filters, time]
mfs = fs
w_length = ceil(w_length_s * mfs)
w_inc = ceil(w_inc_s * mfs)
# Computing modulation filterbank with Q = 2 and 8 channels
if max_cf is None:
max_cf = 30 if norm else 128
_, mf, cutoffs, _ = _compute_modulation_filterbank_and_cutoffs(
min_cf, max_cf, n=8, fs=mfs, q=2, device=preds.device
)
num_frames = int(1 + (time - w_length) // w_inc)
w = torch.hamming_window(w_length + 1, dtype=torch.float64, device=preds.device)[:-1]
mod_out = lfilter(
gt_env.unsqueeze(-2).expand(-1, -1, mf.shape[0], -1), mf[:, 1, :], mf[:, 0, :], clamp=False, batching=True
) # [B, N_filters, 8, time]
    # pad the signal if it is shorter than the window or its length is not a multiple of w_inc
padding = (0, max(ceil(time / w_inc) * w_inc - time, w_length - time))
mod_out_pad = pad(mod_out, pad=padding, mode="constant", value=0)
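    # split into overlapping frames (~256 ms window, ~64 ms hop) and compute the Hamming-windowed energy per frame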
mod_out_frame = mod_out_pad.unfold(-1, w_length, w_inc)
energy = ((mod_out_frame[..., :num_frames, :] * w) ** 2).sum(dim=-1) # [B, N_filters, 8, n_frames]
if norm:
energy = _normalize_energy(energy)
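    # flip the per-channel quantities, accumulate energy across cochlear channels and find the channel at which 90%
    # of the total energy is reached; the ERB of that channel gives the signal bandwidth `bw` used to select kstar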
erbs = torch.flipud(_calc_erbs(low_freq, fs, n_cochlear_filters, device=preds.device))
avg_energy = torch.mean(energy, dim=-1)
total_energy = torch.sum(avg_energy.reshape(num_batch, -1), dim=-1)
ac_energy = torch.sum(avg_energy, dim=2)
ac_perc = ac_energy * 100 / total_energy.reshape(-1, 1)
ac_perc_cumsum = ac_perc.flip(-1).cumsum(-1)
k90perc_idx = torch.nonzero((ac_perc_cumsum > 90).cumsum(-1) == 1)[:, 1]
bw = erbs[k90perc_idx]
temp = []
for b in range(num_batch):
score = _cal_srmr_score(bw[b], avg_energy[b], cutoffs=cutoffs)
temp.append(score)
score = torch.stack(temp)
return score.reshape(*shape[:-1]) if len(shape) > 1 else score # recover original shape
def _srmr_arg_validate(
fs: int,
n_cochlear_filters: int = 23,
low_freq: float = 125,
min_cf: float = 4,
max_cf: Optional[float] = 128,
norm: bool = False,
fast: bool = False,
) -> None:
"""Validate the arguments for speech_reverberation_modulation_energy_ratio.
Args:
fs: the sampling rate
n_cochlear_filters: Number of filters in the acoustic filterbank
low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
min_cf: Center frequency in Hz of the first modulation filter.
        max_cf: Center frequency in Hz of the last modulation filter. If None is given,
            then 30 Hz will be used for `norm==True`, otherwise 128 Hz will be used.
norm: Use modulation spectrum energy normalization
fast: Use the faster version based on the gammatonegram.
"""
if not (isinstance(fs, int) and fs > 0):
raise ValueError(f"Expected argument `fs` to be an int larger than 0, but got {fs}")
if not (isinstance(n_cochlear_filters, int) and n_cochlear_filters > 0):
raise ValueError(
f"Expected argument `n_cochlear_filters` to be an int larger than 0, but got {n_cochlear_filters}"
)
if not ((isinstance(low_freq, (float, int))) and low_freq > 0):
raise ValueError(f"Expected argument `low_freq` to be a float larger than 0, but got {low_freq}")
if not ((isinstance(min_cf, (float, int))) and min_cf > 0):
raise ValueError(f"Expected argument `min_cf` to be a float larger than 0, but got {min_cf}")
if max_cf is not None and not ((isinstance(max_cf, (float, int))) and max_cf > 0):
raise ValueError(f"Expected argument `max_cf` to be a float larger than 0, but got {max_cf}")
if not isinstance(norm, bool):
raise ValueError("Expected argument `norm` to be a bool value")
if not isinstance(fast, bool):
raise ValueError("Expected argument `fast` to be a bool value")