|
from collections.abc import Sequence |
|
from typing import Any, Optional, Union |
|
|
|
from typing_extensions import Literal |
|
|
|
from torchmetrics.image.d_lambda import SpectralDistortionIndex |
|
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis |
|
from torchmetrics.image.psnr import PeakSignalNoiseRatio |
|
from torchmetrics.image.rase import RelativeAverageSpectralError |
|
from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow |
|
from torchmetrics.image.sam import SpectralAngleMapper |
|
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure |
|
from torchmetrics.image.tv import TotalVariation |
|
from torchmetrics.image.uqi import UniversalImageQualityIndex |
|
from torchmetrics.utilities.prints import _deprecated_root_import_class |
|
|
|
|
|
class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionlessSynthesis):
    """Backward-compatibility shim around ``ErrorRelativeGlobalDimensionlessSynthesis``.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then hands every argument, unchanged, to
    the parent metric.

    >>> from torch import rand
    >>> preds = rand([16, 1, 16, 16])
    >>> target = preds * 0.75
    >>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis()
    >>> ergas(preds, target).round()
    tensor(10.)

    """

    def __init__(
        self,
        ratio: float = 4,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("ErrorRelativeGlobalDimensionlessSynthesis", "image")
        # Named arguments and extra keywords are forwarded to the parent as-is.
        parent_args = {"ratio": ratio, "reduction": reduction, **kwargs}
        super().__init__(**parent_args)
|
|
|
|
|
class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarityIndexMeasure):
    """Backward-compatibility shim around ``MultiScaleStructuralSimilarityIndexMeasure``.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then hands every argument, unchanged, to
    the parent metric.

    >>> from torch import rand
    >>> preds = rand([3, 3, 256, 256])
    >>> target = preds * 0.75
    >>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
    >>> ms_ssim(preds, target)
    tensor(0.9628)

    """

    def __init__(
        self,
        gaussian_kernel: bool = True,
        kernel_size: Union[int, Sequence[int]] = 11,
        sigma: Union[float, Sequence[float]] = 1.5,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        data_range: Optional[Union[float, tuple[float, float]]] = None,
        k1: float = 0.01,
        k2: float = 0.03,
        betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
        normalize: Literal["relu", "simple", None] = "relu",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("MultiScaleStructuralSimilarityIndexMeasure", "image")
        # Collect the explicitly named settings once, then forward them together
        # with any extra keywords to the parent constructor.
        named_settings = {
            "gaussian_kernel": gaussian_kernel,
            "kernel_size": kernel_size,
            "sigma": sigma,
            "reduction": reduction,
            "data_range": data_range,
            "k1": k1,
            "k2": k2,
            "betas": betas,
            "normalize": normalize,
        }
        super().__init__(**named_settings, **kwargs)
|
|
|
|
|
class _PeakSignalNoiseRatio(PeakSignalNoiseRatio):
    """Backward-compatibility shim around ``PeakSignalNoiseRatio``.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then hands every argument, unchanged, to
    the parent metric.

    >>> from torch import tensor
    >>> psnr = _PeakSignalNoiseRatio()
    >>> preds = tensor([[0.0, 1.0], [2.0, 3.0]])
    >>> target = tensor([[3.0, 2.0], [1.0, 0.0]])
    >>> psnr(preds, target)
    tensor(2.5527)

    """

    def __init__(
        self,
        data_range: Optional[Union[float, tuple[float, float]]] = None,
        base: float = 10.0,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        dim: Optional[Union[int, tuple[int, ...]]] = None,
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("PeakSignalNoiseRatio", "image")
        # Everything the caller supplied goes straight through to the parent.
        super().__init__(
            data_range=data_range,
            base=base,
            reduction=reduction,
            dim=dim,
            **kwargs,
        )
|
|
|
|
|
class _RelativeAverageSpectralError(RelativeAverageSpectralError):
    """Wrapper for deprecated import.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then forwards ``window_size`` and any
    extra keyword arguments, unchanged, to ``RelativeAverageSpectralError``.

    >>> from torch import rand
    >>> preds = rand(4, 3, 16, 16)
    >>> target = rand(4, 3, 16, 16)
    >>> rase = _RelativeAverageSpectralError()
    >>> rase(preds, target)
    tensor(5326.40...)

    """

    def __init__(
        self,
        window_size: int = 8,
        # A ``**kwargs`` annotation describes each individual value, so ``Any`` is
        # correct here (``dict[str, Any]`` would claim every extra keyword is a dict).
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("RelativeAverageSpectralError", "image")
        super().__init__(window_size=window_size, **kwargs)
|
|
|
|
|
class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWindow):
    """Wrapper for deprecated import.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then forwards ``window_size`` and any
    extra keyword arguments, unchanged, to ``RootMeanSquaredErrorUsingSlidingWindow``.

    The example instantiates this wrapper (not the parent class) so that the
    wrapper itself is what the doctest exercises.

    >>> from torch import rand
    >>> preds = rand(4, 3, 16, 16)
    >>> target = rand(4, 3, 16, 16)
    >>> rmse_sw = _RootMeanSquaredErrorUsingSlidingWindow()
    >>> rmse_sw(preds, target)
    tensor(0.4158)

    """

    def __init__(
        self,
        window_size: int = 8,
        # A ``**kwargs`` annotation describes each individual value, so ``Any`` is
        # correct here (``dict[str, Any]`` would claim every extra keyword is a dict).
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("RootMeanSquaredErrorUsingSlidingWindow", "image")
        super().__init__(window_size=window_size, **kwargs)
|
|
|
|
|
class _SpectralAngleMapper(SpectralAngleMapper):
    """Backward-compatibility shim around ``SpectralAngleMapper``.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then hands every argument, unchanged, to
    the parent metric.

    >>> from torch import rand
    >>> preds = rand([16, 3, 16, 16])
    >>> target = rand([16, 3, 16, 16])
    >>> sam = _SpectralAngleMapper()
    >>> sam(preds, target)
    tensor(0.5914)

    """

    def __init__(
        self,
        reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("SpectralAngleMapper", "image")
        # ``reduction`` and any extra keywords are forwarded to the parent as-is.
        parent_args = {"reduction": reduction, **kwargs}
        super().__init__(**parent_args)
|
|
|
|
|
class _SpectralDistortionIndex(SpectralDistortionIndex):
    """Backward-compatibility shim around ``SpectralDistortionIndex``.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then hands every argument, unchanged, to
    the parent metric.

    >>> from torch import rand
    >>> preds = rand([16, 3, 16, 16])
    >>> target = rand([16, 3, 16, 16])
    >>> sdi = _SpectralDistortionIndex()
    >>> sdi(preds, target)
    tensor(0.0234)

    """

    def __init__(
        self,
        p: int = 1,
        reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("SpectralDistortionIndex", "image")
        # Everything the caller supplied goes straight through to the parent.
        super().__init__(
            p=p,
            reduction=reduction,
            **kwargs,
        )
|
|
|
|
|
class _StructuralSimilarityIndexMeasure(StructuralSimilarityIndexMeasure):
    """Backward-compatibility shim around ``StructuralSimilarityIndexMeasure``.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then hands every argument, unchanged, to
    the parent metric.

    >>> import torch
    >>> preds = torch.rand([3, 3, 256, 256])
    >>> target = preds * 0.75
    >>> ssim = _StructuralSimilarityIndexMeasure(data_range=1.0)
    >>> ssim(preds, target)
    tensor(0.9219)

    """

    def __init__(
        self,
        gaussian_kernel: bool = True,
        sigma: Union[float, Sequence[float]] = 1.5,
        kernel_size: Union[int, Sequence[int]] = 11,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        data_range: Optional[Union[float, tuple[float, float]]] = None,
        k1: float = 0.01,
        k2: float = 0.03,
        return_full_image: bool = False,
        return_contrast_sensitivity: bool = False,
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("StructuralSimilarityIndexMeasure", "image")
        # Collect the explicitly named settings once, then forward them together
        # with any extra keywords to the parent constructor.
        named_settings = {
            "gaussian_kernel": gaussian_kernel,
            "sigma": sigma,
            "kernel_size": kernel_size,
            "reduction": reduction,
            "data_range": data_range,
            "k1": k1,
            "k2": k2,
            "return_full_image": return_full_image,
            "return_contrast_sensitivity": return_contrast_sensitivity,
        }
        super().__init__(**named_settings, **kwargs)
|
|
|
|
|
class _TotalVariation(TotalVariation):
    """Backward-compatibility shim around ``TotalVariation``.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then hands every argument, unchanged, to
    the parent metric.

    >>> from torch import rand
    >>> tv = _TotalVariation()
    >>> img = rand(5, 3, 28, 28)
    >>> tv(img)
    tensor(7546.8018)

    """

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none", None] = "sum",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("TotalVariation", "image")
        # ``reduction`` and any extra keywords are forwarded to the parent as-is.
        parent_args = {"reduction": reduction, **kwargs}
        super().__init__(**parent_args)
|
|
|
|
|
class _UniversalImageQualityIndex(UniversalImageQualityIndex):
    """Backward-compatibility shim around ``UniversalImageQualityIndex``.

    Only the constructor is overridden: it calls ``_deprecated_root_import_class``
    to flag the deprecated import and then hands every argument, unchanged, to
    the parent metric.

    >>> import torch
    >>> preds = torch.rand([16, 1, 16, 16])
    >>> target = preds * 0.75
    >>> uqi = _UniversalImageQualityIndex()
    >>> uqi(preds, target)
    tensor(0.9216)

    """

    def __init__(
        self,
        kernel_size: Sequence[int] = (11, 11),
        sigma: Sequence[float] = (1.5, 1.5),
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before the real metric is set up.
        _deprecated_root_import_class("UniversalImageQualityIndex", "image")
        # Everything the caller supplied goes straight through to the parent.
        super().__init__(
            kernel_size=kernel_size,
            sigma=sigma,
            reduction=reduction,
            **kwargs,
        )
|
|