jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
from typing import Any, Callable, Optional
from typing_extensions import Literal
from torchmetrics.audio.pit import PermutationInvariantTraining
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
from torchmetrics.utilities.prints import _deprecated_root_import_class
class _PermutationInvariantTraining(PermutationInvariantTraining):
    """Shim that keeps the deprecated import of ``PermutationInvariantTraining`` working.

    On construction it calls ``_deprecated_root_import_class`` for this metric in the
    ``"audio"`` domain (presumably to flag the legacy import path -- see that helper),
    and then passes every argument on to the parent constructor unchanged.

    >>> import torch
    >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
    >>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> target = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-2.1065)

    """

    def __init__(
        self,
        metric_func: Callable,
        mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
        eval_func: Literal["max", "min"] = "max",
        **kwargs: Any,
    ) -> None:
        # The deprecation helper runs first, before the parent constructor is invoked.
        _deprecated_root_import_class("PermutationInvariantTraining", "audio")
        # Named options are forwarded by keyword, exactly as received.
        parent_options = {"metric_func": metric_func, "mode": mode, "eval_func": eval_func}
        super().__init__(**parent_options, **kwargs)
class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio):
    """Shim that keeps the deprecated import of ``ScaleInvariantSignalDistortionRatio`` working.

    On construction it calls ``_deprecated_root_import_class`` for this metric in the
    ``"audio"`` domain (presumably to flag the legacy import path -- see that helper),
    and then passes ``zero_mean`` and any extra keyword arguments on to the parent
    constructor unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_sdr = _ScaleInvariantSignalDistortionRatio()
    >>> si_sdr(preds, target)
    tensor(18.4030)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        # The deprecation helper runs first, before the parent constructor is invoked.
        _deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")
        parent_options = {"zero_mean": zero_mean}
        super().__init__(**parent_options, **kwargs)
class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio):
    """Shim that keeps the deprecated import of ``ScaleInvariantSignalNoiseRatio`` working.

    On construction it calls ``_deprecated_root_import_class`` for this metric in the
    ``"audio"`` domain (presumably to flag the legacy import path -- see that helper),
    and then passes all keyword arguments on to the parent constructor unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_snr = _ScaleInvariantSignalNoiseRatio()
    >>> si_snr(preds, target)
    tensor(15.0918)

    """

    def __init__(self, **kwargs: Any) -> None:
        # The deprecation helper runs first, before the parent constructor is invoked.
        _deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")
        super().__init__(**kwargs)
class _SignalDistortionRatio(SignalDistortionRatio):
    """Shim that keeps the deprecated import of ``SignalDistortionRatio`` working.

    On construction it calls ``_deprecated_root_import_class`` for this metric in the
    ``"audio"`` domain (presumably to flag the legacy import path -- see that helper),
    and then passes ``use_cg_iter``, ``filter_length``, ``zero_mean``, ``load_diag`` and
    any extra keyword arguments on to the parent constructor unchanged.

    >>> import torch
    >>> preds = torch.randn(8000)
    >>> target = torch.randn(8000)
    >>> sdr = _SignalDistortionRatio()
    >>> sdr(preds, target)
    tensor(-11.9930)
    >>> # use with pit
    >>> from torchmetrics.functional import signal_distortion_ratio
    >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
    >>> target = torch.randn(4, 2, 8000)
    >>> pit = _PermutationInvariantTraining(signal_distortion_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-11.7277)

    """

    def __init__(
        self,
        use_cg_iter: Optional[int] = None,
        filter_length: int = 512,
        zero_mean: bool = False,
        load_diag: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        # The deprecation helper runs first, before the parent constructor is invoked.
        _deprecated_root_import_class("SignalDistortionRatio", "audio")
        # Named options are forwarded by keyword, exactly as received.
        parent_options = {
            "use_cg_iter": use_cg_iter,
            "filter_length": filter_length,
            "zero_mean": zero_mean,
            "load_diag": load_diag,
        }
        super().__init__(**parent_options, **kwargs)
class _SignalNoiseRatio(SignalNoiseRatio):
    """Shim that keeps the deprecated import of ``SignalNoiseRatio`` working.

    On construction it calls ``_deprecated_root_import_class`` for this metric in the
    ``"audio"`` domain (presumably to flag the legacy import path -- see that helper),
    and then passes ``zero_mean`` and any extra keyword arguments on to the parent
    constructor unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> snr = _SignalNoiseRatio()
    >>> snr(preds, target)
    tensor(16.1805)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        # The deprecation helper runs first, before the parent constructor is invoked.
        _deprecated_root_import_class("SignalNoiseRatio", "audio")
        parent_options = {"zero_mean": zero_mean}
        super().__init__(**parent_options, **kwargs)