File size: 5,002 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio
from torchmetrics.utilities.checks import _check_same_shape
def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
    r"""Compute the `Signal-to-noise ratio`_ (SNR_) metric for evaluating the quality of audio.

    .. math::
        \text{SNR} = 10 \log_{10} \frac{P_{signal}}{P_{noise}}

    where :math:`P` denotes the power of each signal. The signal is ``target`` and the noise is the
    residual ``target - preds``. SNR compares the level of the desired signal to the level of background
    noise, so a high value means that the audio is clear. The result is expressed in decibels.

    Args:
        preds: float tensor with shape ``(...,time)``
        target: float tensor with shape ``(...,time)``
        zero_mean: if ``True``, subtract the mean over the last dimension from both ``target`` and ``preds``
            before computing the metric

    Returns:
        Float tensor with shape ``(...,)`` of SNR values per sample

    Raises:
        RuntimeError:
            If ``preds`` and ``target`` does not have the same shape

    Example:
        >>> import torch
        >>> from torchmetrics.functional.audio import signal_noise_ratio
        >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
        >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
        >>> signal_noise_ratio(preds, target)
        tensor(16.1805)

    """
    _check_same_shape(preds, target)

    # Machine epsilon of the input dtype; added to both energies so that neither
    # the division nor the logarithm ever sees an exact zero.
    tiny = torch.finfo(preds.dtype).eps

    if zero_mean:
        preds = preds - preds.mean(dim=-1, keepdim=True)
        target = target - target.mean(dim=-1, keepdim=True)

    residual = target - preds
    signal_energy = target.pow(2).sum(dim=-1) + tiny
    noise_energy = residual.pow(2).sum(dim=-1) + tiny

    # Convert the energy ratio to decibels.
    return 10 * torch.log10(signal_energy / noise_energy)
def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
    """Compute the `Scale-invariant signal-to-noise ratio`_ (SI-SNR).

    This is a thin wrapper that forwards to ``scale_invariant_signal_distortion_ratio`` with
    ``zero_mean=True``.

    Args:
        preds: float tensor with shape ``(...,time)``
        target: float tensor with shape ``(...,time)``

    Returns:
        Float tensor with shape ``(...,)`` of SI-SNR values per sample

    Raises:
        RuntimeError:
            If ``preds`` and ``target`` does not have the same shape

    Example:
        >>> import torch
        >>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
        >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
        >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
        >>> scale_invariant_signal_noise_ratio(preds, target)
        tensor(15.0918)

    """
    # SI-SNR is the SI-SDR computation with mean removal always switched on.
    si_snr = scale_invariant_signal_distortion_ratio(target=target, preds=preds, zero_mean=True)
    return si_snr
def complex_scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
    """`Complex scale-invariant signal-to-noise ratio`_ (C-SI-SNR).

    Complex inputs are first viewed as real tensors with a trailing dimension of size 2. The last three
    dimensions ``(frequency, time, 2)`` are then flattened into one and the result is passed to
    ``scale_invariant_signal_distortion_ratio``.

    Args:
        preds: real float tensor with shape ``(...,frequency,time,2)`` or complex float tensor with
            shape ``(..., frequency,time)``
        target: real float tensor with shape ``(...,frequency,time,2)`` or complex float tensor with
            shape ``(..., frequency,time)``
        zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics

    Returns:
        Float tensor with shape ``(...,)`` of C-SI-SNR values per sample

    Raises:
        RuntimeError:
            If ``preds`` or ``target`` is not of the shape ``(...,frequency,time,2)`` (after being converted
            to real if it is complex).
        RuntimeError:
            If ``preds`` and ``target`` does not have the same shape (after being converted to real if
            they are complex).

    Example:
        >>> from torch import randn
        >>> from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio
        >>> preds = randn((1,257,100,2))
        >>> target = randn((1,257,100,2))
        >>> complex_scale_invariant_signal_noise_ratio(preds, target)
        tensor([-38.8832])

    """
    # Bring complex inputs to the real layout (..., frequency, time, 2).
    if preds.is_complex():
        preds = torch.view_as_real(preds)
    if target.is_complex():
        target = torch.view_as_real(target)

    if (preds.ndim < 3 or preds.shape[-1] != 2) or (target.ndim < 3 or target.shape[-1] != 2):
        raise RuntimeError(
            "Predictions and targets are expected to have the shape (..., frequency, time, 2),"
            f" but got {preds.shape} and {target.shape}."
        )

    # The shapes must be compared BEFORE flattening: once the last three dimensions are merged,
    # e.g. (1, 4, 6, 2) and (1, 6, 4, 2) both become (1, 48) and the mismatch can no longer be detected.
    if preds.shape != target.shape:
        raise RuntimeError(
            f"Predictions and targets are expected to have the same shape, but got {preds.shape} and {target.shape}."
        )

    # Merge (frequency, time, 2) into a single trailing dimension.
    preds = preds.reshape(*preds.shape[:-3], -1)
    target = target.reshape(*target.shape[:-3], -1)

    return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=zero_mean)
|