File size: 4,065 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import Any, Callable, Optional

from typing_extensions import Literal

from torchmetrics.audio.pit import PermutationInvariantTraining
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
from torchmetrics.utilities.prints import _deprecated_root_import_class


class _PermutationInvariantTraining(PermutationInvariantTraining):
    """Deprecated-import shim around ``PermutationInvariantTraining``.

    On construction, ``_deprecated_root_import_class`` is called for this class name and the
    ``"audio"`` domain, after which every argument is handed unchanged to the parent class.

    Args:
        metric_func: callable forwarded to the parent as ``metric_func``.
        mode: ``"speaker-wise"`` or ``"permutation-wise"``; forwarded to the parent as ``mode``.
        eval_func: ``"max"`` or ``"min"``; forwarded to the parent as ``eval_func``.
        kwargs: any additional keyword arguments, passed through to the parent initializer.

    >>> import torch
    >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
    >>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> target = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-2.1065)

    """

    def __init__(
        self,
        metric_func: Callable,
        mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
        eval_func: Literal["max", "min"] = "max",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import path first, then defer to the real implementation.
        _deprecated_root_import_class("PermutationInvariantTraining", "audio")
        forwarded = {"metric_func": metric_func, "mode": mode, "eval_func": eval_func}
        super().__init__(**forwarded, **kwargs)


class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio):
    """Deprecated-import shim around ``ScaleInvariantSignalDistortionRatio``.

    On construction, ``_deprecated_root_import_class`` is called for this class name and the
    ``"audio"`` domain, after which the arguments are handed unchanged to the parent class.

    Args:
        zero_mean: boolean flag forwarded to the parent as ``zero_mean``.
        kwargs: any additional keyword arguments, passed through to the parent initializer.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_sdr = _ScaleInvariantSignalDistortionRatio()
    >>> si_sdr(preds, target)
    tensor(18.4030)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        # Flag the deprecated import path first, then defer to the real implementation.
        _deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")
        forwarded = {"zero_mean": zero_mean}
        super().__init__(**forwarded, **kwargs)


class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio):
    """Deprecated-import shim around ``ScaleInvariantSignalNoiseRatio``.

    On construction, ``_deprecated_root_import_class`` is called for this class name and the
    ``"audio"`` domain, after which the keyword arguments are handed unchanged to the parent class.

    Args:
        kwargs: keyword arguments passed through to the parent initializer.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_snr = _ScaleInvariantSignalNoiseRatio()
    >>> si_snr(preds, target)
    tensor(15.0918)

    """

    def __init__(self, **kwargs: Any) -> None:
        # Flag the deprecated import path first, then defer to the real implementation.
        _deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")
        super().__init__(**kwargs)


class _SignalDistortionRatio(SignalDistortionRatio):
    """Deprecated-import shim around ``SignalDistortionRatio``.

    On construction, ``_deprecated_root_import_class`` is called for this class name and the
    ``"audio"`` domain, after which every argument is handed unchanged to the parent class.

    Args:
        use_cg_iter: optional integer forwarded to the parent as ``use_cg_iter``.
        filter_length: integer forwarded to the parent as ``filter_length``.
        zero_mean: boolean flag forwarded to the parent as ``zero_mean``.
        load_diag: optional float forwarded to the parent as ``load_diag``.
        kwargs: any additional keyword arguments, passed through to the parent initializer.

    >>> import torch
    >>> preds = torch.randn(8000)
    >>> target = torch.randn(8000)
    >>> sdr = _SignalDistortionRatio()
    >>> sdr(preds, target)
    tensor(-11.9930)
    >>> # use with pit
    >>> from torchmetrics.functional import signal_distortion_ratio
    >>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
    >>> target = torch.randn(4, 2, 8000)
    >>> pit = _PermutationInvariantTraining(signal_distortion_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-11.7277)

    """

    def __init__(
        self,
        use_cg_iter: Optional[int] = None,
        filter_length: int = 512,
        zero_mean: bool = False,
        load_diag: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import path first, then defer to the real implementation.
        _deprecated_root_import_class("SignalDistortionRatio", "audio")
        forwarded = {
            "use_cg_iter": use_cg_iter,
            "filter_length": filter_length,
            "zero_mean": zero_mean,
            "load_diag": load_diag,
        }
        super().__init__(**forwarded, **kwargs)


class _SignalNoiseRatio(SignalNoiseRatio):
    """Deprecated-import shim around ``SignalNoiseRatio``.

    On construction, ``_deprecated_root_import_class`` is called for this class name and the
    ``"audio"`` domain, after which the arguments are handed unchanged to the parent class.

    Args:
        zero_mean: boolean flag forwarded to the parent as ``zero_mean``.
        kwargs: any additional keyword arguments, passed through to the parent initializer.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> snr = _SignalNoiseRatio()
    >>> snr(preds, target)
    tensor(16.1805)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        # Flag the deprecated import path first, then defer to the real implementation.
        _deprecated_root_import_class("SignalNoiseRatio", "audio")
        forwarded = {"zero_mean": zero_mean}
        super().__init__(**forwarded, **kwargs)