|
r"""Root package info.""" |
|
|
|
import logging as __logging |
|
import os |
|
|
|
from lightning_utilities.core.imports import package_available |
|
|
|
from torchmetrics.__about__ import * |
|
|
|
# Package-wide logger. It reports records at INFO level and above, and a
# dedicated stream handler is attached to it when this module is imported.
_logger = __logging.getLogger("torchmetrics")
_logger.setLevel(__logging.INFO)
_logger.addHandler(__logging.StreamHandler())
|
|
|
# Filesystem anchors: the directory that holds this file, and that
# directory's parent.
_PACKAGE_ROOT = os.path.split(__file__)[0]
_PROJECT_ROOT = os.path.split(_PACKAGE_ROOT)[0]
|
|
|
# Compatibility shims: each one re-creates a legacy attribute on an optional
# third-party package, and only when that package is installed and the
# attribute is missing.
if package_available("numpy"):
    import numpy

    # NumPy 2.0 dropped the ``numpy.Inf`` alias; point it back at ``numpy.inf``.
    # Guarded like the shims below so that a NumPy which still provides the
    # alias is left untouched.
    if not hasattr(numpy, "Inf"):
        numpy.Inf = numpy.inf

if package_available("PIL"):
    import PIL

    # Expose the installed version under the legacy ``PILLOW_VERSION`` name
    # when PIL does not define it.
    if not hasattr(PIL, "PILLOW_VERSION"):
        PIL.PILLOW_VERSION = PIL.__version__

if package_available("scipy"):
    import scipy.signal

    # Make the Hamming window reachable at its old ``scipy.signal.hamming``
    # location when only ``scipy.signal.windows.hamming`` is present.
    if not hasattr(scipy.signal, "hamming"):
        scipy.signal.hamming = scipy.signal.windows.hamming
|
|
|
from torchmetrics import functional |
|
from torchmetrics.aggregation import ( |
|
CatMetric, |
|
MaxMetric, |
|
MeanMetric, |
|
MinMetric, |
|
RunningMean, |
|
RunningSum, |
|
SumMetric, |
|
) |
|
from torchmetrics.audio._deprecated import _PermutationInvariantTraining as PermutationInvariantTraining |
|
from torchmetrics.audio._deprecated import ( |
|
_ScaleInvariantSignalDistortionRatio as ScaleInvariantSignalDistortionRatio, |
|
) |
|
from torchmetrics.audio._deprecated import ( |
|
_ScaleInvariantSignalNoiseRatio as ScaleInvariantSignalNoiseRatio, |
|
) |
|
from torchmetrics.audio._deprecated import _SignalDistortionRatio as SignalDistortionRatio |
|
from torchmetrics.audio._deprecated import _SignalNoiseRatio as SignalNoiseRatio |
|
from torchmetrics.classification import ( |
|
AUROC, |
|
ROC, |
|
Accuracy, |
|
AveragePrecision, |
|
CalibrationError, |
|
CohenKappa, |
|
ConfusionMatrix, |
|
Dice, |
|
ExactMatch, |
|
F1Score, |
|
FBetaScore, |
|
HammingDistance, |
|
HingeLoss, |
|
JaccardIndex, |
|
LogAUC, |
|
MatthewsCorrCoef, |
|
NegativePredictiveValue, |
|
Precision, |
|
PrecisionAtFixedRecall, |
|
PrecisionRecallCurve, |
|
Recall, |
|
RecallAtFixedPrecision, |
|
SensitivityAtSpecificity, |
|
Specificity, |
|
SpecificityAtSensitivity, |
|
StatScores, |
|
) |
|
from torchmetrics.collections import MetricCollection |
|
from torchmetrics.detection._deprecated import _ModifiedPanopticQuality as ModifiedPanopticQuality |
|
from torchmetrics.detection._deprecated import _PanopticQuality as PanopticQuality |
|
from torchmetrics.image._deprecated import ( |
|
_ErrorRelativeGlobalDimensionlessSynthesis as ErrorRelativeGlobalDimensionlessSynthesis, |
|
) |
|
from torchmetrics.image._deprecated import ( |
|
_MultiScaleStructuralSimilarityIndexMeasure as MultiScaleStructuralSimilarityIndexMeasure, |
|
) |
|
from torchmetrics.image._deprecated import _PeakSignalNoiseRatio as PeakSignalNoiseRatio |
|
from torchmetrics.image._deprecated import _RelativeAverageSpectralError as RelativeAverageSpectralError |
|
from torchmetrics.image._deprecated import ( |
|
_RootMeanSquaredErrorUsingSlidingWindow as RootMeanSquaredErrorUsingSlidingWindow, |
|
) |
|
from torchmetrics.image._deprecated import _SpectralAngleMapper as SpectralAngleMapper |
|
from torchmetrics.image._deprecated import _SpectralDistortionIndex as SpectralDistortionIndex |
|
from torchmetrics.image._deprecated import ( |
|
_StructuralSimilarityIndexMeasure as StructuralSimilarityIndexMeasure, |
|
) |
|
from torchmetrics.image._deprecated import _TotalVariation as TotalVariation |
|
from torchmetrics.image._deprecated import _UniversalImageQualityIndex as UniversalImageQualityIndex |
|
from torchmetrics.metric import Metric |
|
from torchmetrics.nominal import ( |
|
CramersV, |
|
FleissKappa, |
|
PearsonsContingencyCoefficient, |
|
TheilsU, |
|
TschuprowsT, |
|
) |
|
from torchmetrics.regression import ( |
|
ConcordanceCorrCoef, |
|
CosineSimilarity, |
|
CriticalSuccessIndex, |
|
ExplainedVariance, |
|
KendallRankCorrCoef, |
|
KLDivergence, |
|
LogCoshError, |
|
MeanAbsoluteError, |
|
MeanAbsolutePercentageError, |
|
MeanSquaredError, |
|
MeanSquaredLogError, |
|
MinkowskiDistance, |
|
NormalizedRootMeanSquaredError, |
|
PearsonCorrCoef, |
|
R2Score, |
|
RelativeSquaredError, |
|
SpearmanCorrCoef, |
|
SymmetricMeanAbsolutePercentageError, |
|
TweedieDevianceScore, |
|
WeightedMeanAbsolutePercentageError, |
|
) |
|
from torchmetrics.retrieval._deprecated import _RetrievalFallOut as RetrievalFallOut |
|
from torchmetrics.retrieval._deprecated import _RetrievalHitRate as RetrievalHitRate |
|
from torchmetrics.retrieval._deprecated import _RetrievalMAP as RetrievalMAP |
|
from torchmetrics.retrieval._deprecated import _RetrievalMRR as RetrievalMRR |
|
from torchmetrics.retrieval._deprecated import _RetrievalNormalizedDCG as RetrievalNormalizedDCG |
|
from torchmetrics.retrieval._deprecated import _RetrievalPrecision as RetrievalPrecision |
|
from torchmetrics.retrieval._deprecated import ( |
|
_RetrievalPrecisionRecallCurve as RetrievalPrecisionRecallCurve, |
|
) |
|
from torchmetrics.retrieval._deprecated import _RetrievalRecall as RetrievalRecall |
|
from torchmetrics.retrieval._deprecated import ( |
|
_RetrievalRecallAtFixedPrecision as RetrievalRecallAtFixedPrecision, |
|
) |
|
from torchmetrics.retrieval._deprecated import _RetrievalRPrecision as RetrievalRPrecision |
|
from torchmetrics.text._deprecated import _BLEUScore as BLEUScore |
|
from torchmetrics.text._deprecated import _CharErrorRate as CharErrorRate |
|
from torchmetrics.text._deprecated import _CHRFScore as CHRFScore |
|
from torchmetrics.text._deprecated import _ExtendedEditDistance as ExtendedEditDistance |
|
from torchmetrics.text._deprecated import _MatchErrorRate as MatchErrorRate |
|
from torchmetrics.text._deprecated import _Perplexity as Perplexity |
|
from torchmetrics.text._deprecated import _SacreBLEUScore as SacreBLEUScore |
|
from torchmetrics.text._deprecated import _SQuAD as SQuAD |
|
from torchmetrics.text._deprecated import _TranslationEditRate as TranslationEditRate |
|
from torchmetrics.text._deprecated import _WordErrorRate as WordErrorRate |
|
from torchmetrics.text._deprecated import _WordInfoLost as WordInfoLost |
|
from torchmetrics.text._deprecated import _WordInfoPreserved as WordInfoPreserved |
|
from torchmetrics.wrappers import ( |
|
BootStrapper, |
|
ClasswiseWrapper, |
|
MetricTracker, |
|
MinMaxMetric, |
|
MultioutputWrapper, |
|
MultitaskWrapper, |
|
) |
|
|
|
# Public API of the package: every name re-exported by the imports above.
# Entries keep their original order and are laid out one row per leading
# letter; the formatter is switched off so the rows stay compact.
# fmt: off
__all__ = [
    "AUROC", "ROC",
    "Accuracy", "AveragePrecision",
    "BLEUScore", "BootStrapper",
    "CHRFScore", "CalibrationError", "CatMetric", "CharErrorRate", "ClasswiseWrapper", "CohenKappa",
    "ConcordanceCorrCoef", "ConfusionMatrix", "CosineSimilarity", "CramersV", "CriticalSuccessIndex",
    "Dice",
    "ErrorRelativeGlobalDimensionlessSynthesis", "ExactMatch", "ExplainedVariance", "ExtendedEditDistance",
    "F1Score", "FBetaScore", "FleissKappa",
    "HammingDistance", "HingeLoss",
    "JaccardIndex",
    "KLDivergence", "KendallRankCorrCoef",
    "LogAUC", "LogCoshError",
    "MatchErrorRate", "MatthewsCorrCoef", "MaxMetric", "MeanAbsoluteError", "MeanAbsolutePercentageError",
    "MeanMetric", "MeanSquaredError", "MeanSquaredLogError", "Metric", "MetricCollection", "MetricTracker",
    "MinMaxMetric", "MinMetric", "MinkowskiDistance", "ModifiedPanopticQuality",
    "MultiScaleStructuralSimilarityIndexMeasure", "MultioutputWrapper", "MultitaskWrapper",
    "NegativePredictiveValue", "NormalizedRootMeanSquaredError",
    "PanopticQuality", "PeakSignalNoiseRatio", "PearsonCorrCoef", "PearsonsContingencyCoefficient",
    "PermutationInvariantTraining", "Perplexity", "Precision", "PrecisionAtFixedRecall", "PrecisionRecallCurve",
    "R2Score", "Recall", "RecallAtFixedPrecision", "RelativeAverageSpectralError", "RelativeSquaredError",
    "RetrievalFallOut", "RetrievalHitRate", "RetrievalMAP", "RetrievalMRR", "RetrievalNormalizedDCG",
    "RetrievalPrecision", "RetrievalPrecisionRecallCurve", "RetrievalRPrecision", "RetrievalRecall",
    "RetrievalRecallAtFixedPrecision", "RootMeanSquaredErrorUsingSlidingWindow", "RunningMean", "RunningSum",
    "SQuAD", "SacreBLEUScore", "ScaleInvariantSignalDistortionRatio", "ScaleInvariantSignalNoiseRatio",
    "SensitivityAtSpecificity", "SignalDistortionRatio", "SignalNoiseRatio", "SpearmanCorrCoef", "Specificity",
    "SpecificityAtSensitivity", "SpectralAngleMapper", "SpectralDistortionIndex", "StatScores",
    "StructuralSimilarityIndexMeasure", "SumMetric", "SymmetricMeanAbsolutePercentageError",
    "TheilsU", "TotalVariation", "TranslationEditRate", "TschuprowsT", "TweedieDevianceScore",
    "UniversalImageQualityIndex",
    "WeightedMeanAbsolutePercentageError", "WordErrorRate", "WordInfoLost", "WordInfoPreserved",
    "functional",
]
# fmt: on
|
|