|
from typing import Any, Callable, Optional |
|
|
|
from typing_extensions import Literal |
|
|
|
from torchmetrics.audio.pit import PermutationInvariantTraining |
|
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio |
|
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio |
|
from torchmetrics.utilities.prints import _deprecated_root_import_class |
|
|
|
|
|
class _PermutationInvariantTraining(PermutationInvariantTraining):
    """Shim kept for the deprecated import of ``PermutationInvariantTraining``.

    Creating an instance first calls
    ``_deprecated_root_import_class("PermutationInvariantTraining", "audio")`` and then passes
    every argument, unchanged, to the parent class initializer.

    NOTE(review): the example below draws random inputs and no seeding is visible here, so the
    printed value presumably relies on a seed fixed by the doctest harness -- confirm.

    >>> import torch
    >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
    >>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> target = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-2.1065)

    """

    def __init__(
        self,
        metric_func: Callable,
        mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
        eval_func: Literal["max", "min"] = "max",
        **kwargs: Any,
    ) -> None:
        # Notify about the deprecated import before the parent is initialised.
        _deprecated_root_import_class("PermutationInvariantTraining", "audio")
        # Explicitly named arguments are gathered once, then forwarded by keyword
        # together with any extra keyword arguments.
        forwarded = {"metric_func": metric_func, "mode": mode, "eval_func": eval_func}
        super().__init__(**forwarded, **kwargs)
|
|
|
|
|
class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio):
    """Shim kept for the deprecated import of ``ScaleInvariantSignalDistortionRatio``.

    Creating an instance first calls
    ``_deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")`` and then
    passes ``zero_mean`` and any extra keyword arguments to the parent class initializer.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_sdr = _ScaleInvariantSignalDistortionRatio()
    >>> si_sdr(preds, target)
    tensor(18.4030)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        # Notify about the deprecated import before the parent is initialised.
        _deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")
        forwarded = {"zero_mean": zero_mean}
        super().__init__(**forwarded, **kwargs)
|
|
|
|
|
class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio):
    """Shim kept for the deprecated import of ``ScaleInvariantSignalNoiseRatio``.

    Creating an instance first calls
    ``_deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")`` and then passes
    all keyword arguments, unchanged, to the parent class initializer.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_snr = _ScaleInvariantSignalNoiseRatio()
    >>> si_snr(preds, target)
    tensor(15.0918)

    """

    def __init__(self, **kwargs: Any) -> None:
        # Notify about the deprecated import before the parent is initialised.
        _deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")
        # This wrapper has no named options of its own; everything goes to the parent.
        super().__init__(**kwargs)
|
|
|
|
|
class _SignalDistortionRatio(SignalDistortionRatio):
    """Shim kept for the deprecated import of ``SignalDistortionRatio``.

    Creating an instance first calls
    ``_deprecated_root_import_class("SignalDistortionRatio", "audio")`` and then passes
    ``use_cg_iter``, ``filter_length``, ``zero_mean``, ``load_diag`` and any extra keyword
    arguments to the parent class initializer.

    NOTE(review): the examples below draw random inputs and no seeding is visible here, so the
    printed values presumably rely on a seed fixed by the doctest harness -- confirm.

    >>> import torch
    >>> preds = torch.randn(8000)
    >>> target = torch.randn(8000)
    >>> sdr = _SignalDistortionRatio()
    >>> sdr(preds, target)
    tensor(-11.9930)
    >>> # use with pit
    >>> from torchmetrics.functional import signal_distortion_ratio
    >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
    >>> target = torch.randn(4, 2, 8000)
    >>> pit = _PermutationInvariantTraining(signal_distortion_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-11.7277)

    """

    def __init__(
        self,
        use_cg_iter: Optional[int] = None,
        filter_length: int = 512,
        zero_mean: bool = False,
        load_diag: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        # Notify about the deprecated import before the parent is initialised.
        _deprecated_root_import_class("SignalDistortionRatio", "audio")
        # Explicitly named arguments are gathered once, then forwarded by keyword
        # together with any extra keyword arguments.
        forwarded = {
            "use_cg_iter": use_cg_iter,
            "filter_length": filter_length,
            "zero_mean": zero_mean,
            "load_diag": load_diag,
        }
        super().__init__(**forwarded, **kwargs)
|
|
|
|
|
class _SignalNoiseRatio(SignalNoiseRatio):
    """Shim kept for the deprecated import of ``SignalNoiseRatio``.

    Creating an instance first calls
    ``_deprecated_root_import_class("SignalNoiseRatio", "audio")`` and then passes ``zero_mean``
    and any extra keyword arguments to the parent class initializer.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> snr = _SignalNoiseRatio()
    >>> snr(preds, target)
    tensor(16.1805)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        # Notify about the deprecated import before the parent is initialised.
        _deprecated_root_import_class("SignalNoiseRatio", "audio")
        forwarded = {"zero_mean": zero_mean}
        super().__init__(**forwarded, **kwargs)
|
|