File size: 4,065 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from typing import Any, Callable, Optional
from typing_extensions import Literal
from torchmetrics.audio.pit import PermutationInvariantTraining
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
from torchmetrics.utilities.prints import _deprecated_root_import_class
class _PermutationInvariantTraining(PermutationInvariantTraining):
    """Wrapper for deprecated import.

    Thin subclass of ``PermutationInvariantTraining``. On construction it first calls
    ``_deprecated_root_import_class("PermutationInvariantTraining", "audio")`` (presumably to
    warn about the old import location -- confirm against the helper) and then hands every
    argument, unchanged, to the parent constructor.

    >>> import torch
    >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
    >>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> target = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-2.1065)

    """

    def __init__(
        self,
        metric_func: Callable,
        mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
        eval_func: Literal["max", "min"] = "max",
        **kwargs: Any,
    ) -> None:
        # The deprecation hook runs before the parent is initialised, so it fires
        # even if the parent constructor later rejects the arguments.
        _deprecated_root_import_class("PermutationInvariantTraining", "audio")
        explicit_args = {"metric_func": metric_func, "mode": mode, "eval_func": eval_func}
        super().__init__(**explicit_args, **kwargs)
class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio):
    """Wrapper for deprecated import.

    Thin subclass of ``ScaleInvariantSignalDistortionRatio``. Construction calls
    ``_deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")`` (presumably
    a deprecation notice -- confirm against the helper) and then forwards ``zero_mean`` and any
    extra keyword arguments to the parent constructor.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_sdr = _ScaleInvariantSignalDistortionRatio()
    >>> si_sdr(preds, target)
    tensor(18.4030)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        # Deprecation hook first, parent initialisation second.
        _deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")
        explicit_args = {"zero_mean": zero_mean}
        super().__init__(**explicit_args, **kwargs)
class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio):
    """Wrapper for deprecated import.

    Thin subclass of ``ScaleInvariantSignalNoiseRatio``. Construction calls
    ``_deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")`` (presumably a
    deprecation notice -- confirm against the helper) and then passes all keyword arguments
    through to the parent constructor.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_snr = _ScaleInvariantSignalNoiseRatio()
    >>> si_snr(preds, target)
    tensor(15.0918)

    """

    def __init__(self, **kwargs: Any) -> None:
        # Deprecation hook first, parent initialisation second.
        _deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")
        super().__init__(**kwargs)
class _SignalDistortionRatio(SignalDistortionRatio):
    """Wrapper for deprecated import.

    Thin subclass of ``SignalDistortionRatio``. Construction calls
    ``_deprecated_root_import_class("SignalDistortionRatio", "audio")`` (presumably a deprecation
    notice -- confirm against the helper) and then forwards ``use_cg_iter``, ``filter_length``,
    ``zero_mean``, ``load_diag`` and any extra keyword arguments to the parent constructor.

    >>> import torch
    >>> preds = torch.randn(8000)
    >>> target = torch.randn(8000)
    >>> sdr = _SignalDistortionRatio()
    >>> sdr(preds, target)
    tensor(-11.9930)
    >>> # use with pit
    >>> from torchmetrics.functional import signal_distortion_ratio
    >>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
    >>> target = torch.randn(4, 2, 8000)
    >>> pit = _PermutationInvariantTraining(signal_distortion_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-11.7277)

    """

    def __init__(
        self,
        use_cg_iter: Optional[int] = None,
        filter_length: int = 512,
        zero_mean: bool = False,
        load_diag: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        # Deprecation hook first, parent initialisation second.
        _deprecated_root_import_class("SignalDistortionRatio", "audio")
        explicit_args = {
            "use_cg_iter": use_cg_iter,
            "filter_length": filter_length,
            "zero_mean": zero_mean,
            "load_diag": load_diag,
        }
        super().__init__(**explicit_args, **kwargs)
class _SignalNoiseRatio(SignalNoiseRatio):
    """Wrapper for deprecated import.

    Thin subclass of ``SignalNoiseRatio``. Construction calls
    ``_deprecated_root_import_class("SignalNoiseRatio", "audio")`` (presumably a deprecation
    notice -- confirm against the helper) and then forwards ``zero_mean`` and any extra keyword
    arguments to the parent constructor.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> snr = _SignalNoiseRatio()
    >>> snr(preds, target)
    tensor(16.1805)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        # Deprecation hook first, parent initialisation second.
        _deprecated_root_import_class("SignalNoiseRatio", "audio")
        explicit_args = {"zero_mean": zero_mean}
        super().__init__(**explicit_args, **kwargs)
|