from collections.abc import Sequence from typing import Any, Optional, Union from typing_extensions import Literal from torchmetrics.image.d_lambda import SpectralDistortionIndex from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis from torchmetrics.image.psnr import PeakSignalNoiseRatio from torchmetrics.image.rase import RelativeAverageSpectralError from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow from torchmetrics.image.sam import SpectralAngleMapper from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure from torchmetrics.image.tv import TotalVariation from torchmetrics.image.uqi import UniversalImageQualityIndex from torchmetrics.utilities.prints import _deprecated_root_import_class class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionlessSynthesis): """Wrapper for deprecated import. >>> from torch import rand >>> preds = rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis() >>> ergas(preds, target).round() tensor(10.) """ def __init__( self, ratio: float = 4, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", **kwargs: Any, ) -> None: _deprecated_root_import_class("ErrorRelativeGlobalDimensionlessSynthesis", "image") super().__init__(ratio=ratio, reduction=reduction, **kwargs) class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarityIndexMeasure): """Wrapper for deprecated import. >>> from torch import rand >>> preds = rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) >>> ms_ssim(preds, target) tensor(0.9628) """ def __init__( self, gaussian_kernel: bool = True, kernel_size: Union[int, Sequence[int]] = 11, sigma: Union[float, Sequence[float]] = 1.5, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize: Literal["relu", "simple", None] = "relu", **kwargs: Any, ) -> None: _deprecated_root_import_class("MultiScaleStructuralSimilarityIndexMeasure", "image") super().__init__( gaussian_kernel=gaussian_kernel, kernel_size=kernel_size, sigma=sigma, reduction=reduction, data_range=data_range, k1=k1, k2=k2, betas=betas, normalize=normalize, **kwargs, ) class _PeakSignalNoiseRatio(PeakSignalNoiseRatio): """Wrapper for deprecated import. >>> from torch import tensor >>> psnr = _PeakSignalNoiseRatio() >>> preds = tensor([[0.0, 1.0], [2.0, 3.0]]) >>> target = tensor([[3.0, 2.0], [1.0, 0.0]]) >>> psnr(preds, target) tensor(2.5527) """ def __init__( self, data_range: Optional[Union[float, tuple[float, float]]] = None, base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", dim: Optional[Union[int, tuple[int, ...]]] = None, **kwargs: Any, ) -> None: _deprecated_root_import_class("PeakSignalNoiseRatio", "image") super().__init__(data_range=data_range, base=base, reduction=reduction, dim=dim, **kwargs) class _RelativeAverageSpectralError(RelativeAverageSpectralError): """Wrapper for deprecated import. >>> from torch import rand >>> preds = rand(4, 3, 16, 16) >>> target = rand(4, 3, 16, 16) >>> rase = _RelativeAverageSpectralError() >>> rase(preds, target) tensor(5326.40...) """ def __init__( self, window_size: int = 8, **kwargs: dict[str, Any], ) -> None: _deprecated_root_import_class("RelativeAverageSpectralError", "image") super().__init__(window_size=window_size, **kwargs) class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWindow): """Wrapper for deprecated import. >>> from torch import rand >>> preds = rand(4, 3, 16, 16) >>> target = rand(4, 3, 16, 16) >>> rmse_sw = RootMeanSquaredErrorUsingSlidingWindow() >>> rmse_sw(preds, target) tensor(0.4158) """ def __init__( self, window_size: int = 8, **kwargs: dict[str, Any], ) -> None: _deprecated_root_import_class("RootMeanSquaredErrorUsingSlidingWindow", "image") super().__init__(window_size=window_size, **kwargs) class _SpectralAngleMapper(SpectralAngleMapper): """Wrapper for deprecated import. >>> from torch import rand >>> preds = rand([16, 3, 16, 16]) >>> target = rand([16, 3, 16, 16]) >>> sam = _SpectralAngleMapper() >>> sam(preds, target) tensor(0.5914) """ def __init__( self, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any, ) -> None: _deprecated_root_import_class("SpectralAngleMapper", "image") super().__init__(reduction=reduction, **kwargs) class _SpectralDistortionIndex(SpectralDistortionIndex): """Wrapper for deprecated import. >>> from torch import rand >>> preds = rand([16, 3, 16, 16]) >>> target = rand([16, 3, 16, 16]) >>> sdi = _SpectralDistortionIndex() >>> sdi(preds, target) tensor(0.0234) """ def __init__( self, p: int = 1, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any ) -> None: _deprecated_root_import_class("SpectralDistortionIndex", "image") super().__init__(p=p, reduction=reduction, **kwargs) class _StructuralSimilarityIndexMeasure(StructuralSimilarityIndexMeasure): """Wrapper for deprecated import. >>> import torch >>> preds = torch.rand([3, 3, 256, 256]) >>> target = preds * 0.75 >>> ssim = _StructuralSimilarityIndexMeasure(data_range=1.0) >>> ssim(preds, target) tensor(0.9219) """ def __init__( self, gaussian_kernel: bool = True, sigma: Union[float, Sequence[float]] = 1.5, kernel_size: Union[int, Sequence[int]] = 11, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", data_range: Optional[Union[float, tuple[float, float]]] = None, k1: float = 0.01, k2: float = 0.03, return_full_image: bool = False, return_contrast_sensitivity: bool = False, **kwargs: Any, ) -> None: _deprecated_root_import_class("StructuralSimilarityIndexMeasure", "image") super().__init__( gaussian_kernel=gaussian_kernel, sigma=sigma, kernel_size=kernel_size, reduction=reduction, data_range=data_range, k1=k1, k2=k2, return_full_image=return_full_image, return_contrast_sensitivity=return_contrast_sensitivity, **kwargs, ) class _TotalVariation(TotalVariation): """Wrapper for deprecated import. >>> from torch import rand >>> tv = _TotalVariation() >>> img = rand(5, 3, 28, 28) >>> tv(img) tensor(7546.8018) """ def __init__(self, reduction: Literal["mean", "sum", "none", None] = "sum", **kwargs: Any) -> None: _deprecated_root_import_class("TotalVariation", "image") super().__init__(reduction=reduction, **kwargs) class _UniversalImageQualityIndex(UniversalImageQualityIndex): """Wrapper for deprecated import. >>> import torch >>> preds = torch.rand([16, 1, 16, 16]) >>> target = preds * 0.75 >>> uqi = _UniversalImageQualityIndex() >>> uqi(preds, target) tensor(0.9216) """ def __init__( self, kernel_size: Sequence[int] = (11, 11), sigma: Sequence[float] = (1.5, 1.5), reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", **kwargs: Any, ) -> None: _deprecated_root_import_class("UniversalImageQualityIndex", "image") super().__init__(kernel_size=kernel_size, sigma=sigma, reduction=reduction, **kwargs)