jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Optional, Union
from torch import Tensor, tensor
from torchmetrics.functional.audio.srmr import (
_srmr_arg_validate,
speech_reverberation_modulation_energy_ratio,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import (
_GAMMATONE_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE]):
__doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]
class SpeechReverberationModulationEnergyRatio(Metric):
"""Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).
SRMR is a non-intrusive metric for speech quality and intelligibility based on
a modulation spectral representation of the speech signal.
This code is translated from `SRMRToolbox`_ and `SRMRpy`_.
As input to ``forward`` and ``update`` the metric accepts the following input
- ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``
As output of `forward` and `compute` the metric returns the following output
- ``srmr`` (:class:`~torch.Tensor`): float scaler tensor
.. hint::
Using this metrics requires you to have ``gammatone`` and ``torchaudio`` installed.
Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
and ``pip install git+https://github.com/detly/gammatone``.
.. attention::
This implementation is experimental, and might not be consistent with the matlab
implementation `SRMRToolbox`_, especially the fast implementation.
The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have
a relatively small inconsistency.
Args:
fs: the sampling rate
n_cochlear_filters: Number of filters in the acoustic filterbank
low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
min_cf: Center frequency in Hz of the first modulation filter.
max_cf: Center frequency in Hz of the last modulation filter. If None is given,
then 30 Hz will be used for `norm==False`, otherwise 128 Hz will be used.
norm: Use modulation spectrum energy normalization
fast: Use the faster version based on the gammatonegram.
Note: this argument is inherited from `SRMRpy`_. As the translated code is based to pytorch,
setting `fast=True` may slow down the speed for calculating this metric on GPU.
Raises:
ModuleNotFoundError:
If ``gammatone`` or ``torchaudio`` package is not installed
Example:
>>> from torch import randn
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> preds = randn(8000)
>>> srmr = SpeechReverberationModulationEnergyRatio(8000)
>>> srmr(preds)
tensor(0.3191)
"""
msum: Tensor
total: Tensor
full_state_update: bool = False
is_differentiable: bool = True
higher_is_better: bool = True
plot_lower_bound: Optional[float] = None
plot_upper_bound: Optional[float] = None
def __init__(
self,
fs: int,
n_cochlear_filters: int = 23,
low_freq: float = 125,
min_cf: float = 4,
max_cf: Optional[float] = None,
norm: bool = False,
fast: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
raise ModuleNotFoundError(
"speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
" `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
"``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
)
_srmr_arg_validate(
fs=fs,
n_cochlear_filters=n_cochlear_filters,
low_freq=low_freq,
min_cf=min_cf,
max_cf=max_cf,
norm=norm,
fast=fast,
)
self.fs = fs
self.n_cochlear_filters = n_cochlear_filters
self.low_freq = low_freq
self.min_cf = min_cf
self.max_cf = max_cf
self.norm = norm
self.fast = fast
self.add_state("msum", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
def update(self, preds: Tensor) -> None:
"""Update state with predictions."""
metric_val_batch = speech_reverberation_modulation_energy_ratio(
preds, self.fs, self.n_cochlear_filters, self.low_freq, self.min_cf, self.max_cf, self.norm, self.fast
).to(self.msum.device)
self.msum += metric_val_batch.sum()
self.total += metric_val_batch.numel()
def compute(self) -> Tensor:
"""Compute metric."""
return self.msum / self.total
def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> metric = SpeechReverberationModulationEnergyRatio(8000)
>>> metric.update(torch.rand(8000))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> metric = SpeechReverberationModulationEnergyRatio(8000)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)