File size: 9,512 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
r"""Root package info."""

import logging as __logging
import os

from lightning_utilities.core.imports import package_available

from torchmetrics.__about__ import *  # noqa: F403

# Package-level logger registered under the "torchmetrics" name. Its threshold is set to INFO and a
# ``StreamHandler`` (default stream: stderr) is attached, so INFO-and-above records from the package
# are shown without any logging configuration by the user.
_logger = __logging.getLogger(name="torchmetrics")
_logger.setLevel(level=__logging.INFO)
_logger.addHandler(__logging.StreamHandler())

# Directory holding this package's ``__init__`` module, and the directory one level above it.
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

if package_available("numpy"):
    # Compatibility shim: the ``np.Inf`` alias was removed in the NumPy 2.0 release (``np.inf``
    # remains), which raises AttributeError in code that still uses it. Restore the alias only when
    # it is missing, mirroring the ``hasattr`` guards used by the Pillow and SciPy shims below.
    import numpy

    if not hasattr(numpy, "Inf"):
        numpy.Inf = numpy.inf


if package_available("PIL"):
    # Compatibility shim: expose ``PIL.PILLOW_VERSION`` as an alias of ``PIL.__version__`` for code
    # that still reads the old attribute, when the installed Pillow no longer defines it.
    import PIL

    if not hasattr(PIL, "PILLOW_VERSION"):
        PIL.PILLOW_VERSION = PIL.__version__

if package_available("scipy"):
    import scipy.signal

    # Back-compatibility shim: SRMRpy calls ``scipy.signal.hamming``. When the installed SciPy no
    # longer provides it at that path, point it at ``scipy.signal.windows.hamming``.
    if not hasattr(scipy.signal, "hamming"):
        scipy.signal.hamming = scipy.signal.windows.hamming

from torchmetrics import functional  # noqa: E402
from torchmetrics.aggregation import (  # noqa: E402
    CatMetric,
    MaxMetric,
    MeanMetric,
    MinMetric,
    RunningMean,
    RunningSum,
    SumMetric,
)
from torchmetrics.audio._deprecated import _PermutationInvariantTraining as PermutationInvariantTraining  # noqa: E402
from torchmetrics.audio._deprecated import (  # noqa: E402
    _ScaleInvariantSignalDistortionRatio as ScaleInvariantSignalDistortionRatio,
)
from torchmetrics.audio._deprecated import (  # noqa: E402
    _ScaleInvariantSignalNoiseRatio as ScaleInvariantSignalNoiseRatio,
)
from torchmetrics.audio._deprecated import _SignalDistortionRatio as SignalDistortionRatio  # noqa: E402
from torchmetrics.audio._deprecated import _SignalNoiseRatio as SignalNoiseRatio  # noqa: E402
from torchmetrics.classification import (  # noqa: E402
    AUROC,
    ROC,
    Accuracy,
    AveragePrecision,
    CalibrationError,
    CohenKappa,
    ConfusionMatrix,
    Dice,
    ExactMatch,
    F1Score,
    FBetaScore,
    HammingDistance,
    HingeLoss,
    JaccardIndex,
    LogAUC,
    MatthewsCorrCoef,
    NegativePredictiveValue,
    Precision,
    PrecisionAtFixedRecall,
    PrecisionRecallCurve,
    Recall,
    RecallAtFixedPrecision,
    SensitivityAtSpecificity,
    Specificity,
    SpecificityAtSensitivity,
    StatScores,
)
from torchmetrics.collections import MetricCollection  # noqa: E402
from torchmetrics.detection._deprecated import _ModifiedPanopticQuality as ModifiedPanopticQuality  # noqa: E402
from torchmetrics.detection._deprecated import _PanopticQuality as PanopticQuality  # noqa: E402
from torchmetrics.image._deprecated import (  # noqa: E402
    _ErrorRelativeGlobalDimensionlessSynthesis as ErrorRelativeGlobalDimensionlessSynthesis,
)
from torchmetrics.image._deprecated import (  # noqa: E402
    _MultiScaleStructuralSimilarityIndexMeasure as MultiScaleStructuralSimilarityIndexMeasure,
)
from torchmetrics.image._deprecated import _PeakSignalNoiseRatio as PeakSignalNoiseRatio  # noqa: E402
from torchmetrics.image._deprecated import _RelativeAverageSpectralError as RelativeAverageSpectralError  # noqa: E402
from torchmetrics.image._deprecated import (  # noqa: E402
    _RootMeanSquaredErrorUsingSlidingWindow as RootMeanSquaredErrorUsingSlidingWindow,
)
from torchmetrics.image._deprecated import _SpectralAngleMapper as SpectralAngleMapper  # noqa: E402
from torchmetrics.image._deprecated import _SpectralDistortionIndex as SpectralDistortionIndex  # noqa: E402
from torchmetrics.image._deprecated import (  # noqa: E402
    _StructuralSimilarityIndexMeasure as StructuralSimilarityIndexMeasure,
)
from torchmetrics.image._deprecated import _TotalVariation as TotalVariation  # noqa: E402
from torchmetrics.image._deprecated import _UniversalImageQualityIndex as UniversalImageQualityIndex  # noqa: E402
from torchmetrics.metric import Metric  # noqa: E402
from torchmetrics.nominal import (  # noqa: E402
    CramersV,
    FleissKappa,
    PearsonsContingencyCoefficient,
    TheilsU,
    TschuprowsT,
)
from torchmetrics.regression import (  # noqa: E402
    ConcordanceCorrCoef,
    CosineSimilarity,
    CriticalSuccessIndex,
    ExplainedVariance,
    KendallRankCorrCoef,
    KLDivergence,
    LogCoshError,
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
    MeanSquaredError,
    MeanSquaredLogError,
    MinkowskiDistance,
    NormalizedRootMeanSquaredError,
    PearsonCorrCoef,
    R2Score,
    RelativeSquaredError,
    SpearmanCorrCoef,
    SymmetricMeanAbsolutePercentageError,
    TweedieDevianceScore,
    WeightedMeanAbsolutePercentageError,
)
from torchmetrics.retrieval._deprecated import _RetrievalFallOut as RetrievalFallOut  # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalHitRate as RetrievalHitRate  # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalMAP as RetrievalMAP  # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalMRR as RetrievalMRR  # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalNormalizedDCG as RetrievalNormalizedDCG  # noqa: E402
from torchmetrics.retrieval._deprecated import _RetrievalPrecision as RetrievalPrecision  # noqa: E402
from torchmetrics.retrieval._deprecated import (  # noqa: E402
    _RetrievalPrecisionRecallCurve as RetrievalPrecisionRecallCurve,
)
from torchmetrics.retrieval._deprecated import _RetrievalRecall as RetrievalRecall  # noqa: E402
from torchmetrics.retrieval._deprecated import (  # noqa: E402
    _RetrievalRecallAtFixedPrecision as RetrievalRecallAtFixedPrecision,
)
from torchmetrics.retrieval._deprecated import _RetrievalRPrecision as RetrievalRPrecision  # noqa: E402
from torchmetrics.text._deprecated import _BLEUScore as BLEUScore  # noqa: E402
from torchmetrics.text._deprecated import _CharErrorRate as CharErrorRate  # noqa: E402
from torchmetrics.text._deprecated import _CHRFScore as CHRFScore  # noqa: E402
from torchmetrics.text._deprecated import _ExtendedEditDistance as ExtendedEditDistance  # noqa: E402
from torchmetrics.text._deprecated import _MatchErrorRate as MatchErrorRate  # noqa: E402
from torchmetrics.text._deprecated import _Perplexity as Perplexity  # noqa: E402
from torchmetrics.text._deprecated import _SacreBLEUScore as SacreBLEUScore  # noqa: E402
from torchmetrics.text._deprecated import _SQuAD as SQuAD  # noqa: E402
from torchmetrics.text._deprecated import _TranslationEditRate as TranslationEditRate  # noqa: E402
from torchmetrics.text._deprecated import _WordErrorRate as WordErrorRate  # noqa: E402
from torchmetrics.text._deprecated import _WordInfoLost as WordInfoLost  # noqa: E402
from torchmetrics.text._deprecated import _WordInfoPreserved as WordInfoPreserved  # noqa: E402
from torchmetrics.wrappers import (  # noqa: E402
    BootStrapper,
    ClasswiseWrapper,
    MetricTracker,
    MinMaxMetric,
    MultioutputWrapper,
    MultitaskWrapper,
)

# Public names re-exported at the package root; this is exactly what ``from torchmetrics import *``
# binds. It includes the deprecated aliases imported above from the ``*._deprecated`` submodules and
# the ``functional`` submodule, so every public name imported above must be listed here.
# Entries are kept in ASCII order with all-caps names first (presumably enforced by ruff's RUF022 -
# confirm before reordering). Keep this a plain list literal so static analysers can read it.
__all__ = [
    "AUROC",
    "ROC",
    "Accuracy",
    "AveragePrecision",
    "BLEUScore",
    "BootStrapper",
    "CHRFScore",
    "CalibrationError",
    "CatMetric",
    "CharErrorRate",
    "ClasswiseWrapper",
    "CohenKappa",
    "ConcordanceCorrCoef",
    "ConfusionMatrix",
    "CosineSimilarity",
    "CramersV",
    "CriticalSuccessIndex",
    "Dice",
    "ErrorRelativeGlobalDimensionlessSynthesis",
    "ExactMatch",
    "ExplainedVariance",
    "ExtendedEditDistance",
    "F1Score",
    "FBetaScore",
    "FleissKappa",
    "HammingDistance",
    "HingeLoss",
    "JaccardIndex",
    "KLDivergence",
    "KendallRankCorrCoef",
    "LogAUC",
    "LogCoshError",
    "MatchErrorRate",
    "MatthewsCorrCoef",
    "MaxMetric",
    "MeanAbsoluteError",
    "MeanAbsolutePercentageError",
    "MeanMetric",
    "MeanSquaredError",
    "MeanSquaredLogError",
    "Metric",
    "MetricCollection",
    "MetricTracker",
    "MinMaxMetric",
    "MinMetric",
    "MinkowskiDistance",
    "ModifiedPanopticQuality",
    "MultiScaleStructuralSimilarityIndexMeasure",
    "MultioutputWrapper",
    "MultitaskWrapper",
    "NegativePredictiveValue",
    "NormalizedRootMeanSquaredError",
    "PanopticQuality",
    "PeakSignalNoiseRatio",
    "PearsonCorrCoef",
    "PearsonsContingencyCoefficient",
    "PermutationInvariantTraining",
    "Perplexity",
    "Precision",
    "PrecisionAtFixedRecall",
    "PrecisionRecallCurve",
    "R2Score",
    "Recall",
    "RecallAtFixedPrecision",
    "RelativeAverageSpectralError",
    "RelativeSquaredError",
    "RetrievalFallOut",
    "RetrievalHitRate",
    "RetrievalMAP",
    "RetrievalMRR",
    "RetrievalNormalizedDCG",
    "RetrievalPrecision",
    "RetrievalPrecisionRecallCurve",
    "RetrievalRPrecision",
    "RetrievalRecall",
    "RetrievalRecallAtFixedPrecision",
    "RootMeanSquaredErrorUsingSlidingWindow",
    "RunningMean",
    "RunningSum",
    "SQuAD",
    "SacreBLEUScore",
    "ScaleInvariantSignalDistortionRatio",
    "ScaleInvariantSignalNoiseRatio",
    "SensitivityAtSpecificity",
    "SignalDistortionRatio",
    "SignalNoiseRatio",
    "SpearmanCorrCoef",
    "Specificity",
    "SpecificityAtSensitivity",
    "SpectralAngleMapper",
    "SpectralDistortionIndex",
    "StatScores",
    "StructuralSimilarityIndexMeasure",
    "SumMetric",
    "SymmetricMeanAbsolutePercentageError",
    "TheilsU",
    "TotalVariation",
    "TranslationEditRate",
    "TschuprowsT",
    "TweedieDevianceScore",
    "UniversalImageQualityIndex",
    "WeightedMeanAbsolutePercentageError",
    "WordErrorRate",
    "WordInfoLost",
    "WordInfoPreserved",
    "functional",
]