# Uploaded by jamtur01 via huggingface_hub ("Upload folder using huggingface_hub", commit 9c6594c, verified).
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing_extensions import Literal
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
from torchmetrics.image.psnr import PeakSignalNoiseRatio
from torchmetrics.image.rase import RelativeAverageSpectralError
from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow
from torchmetrics.image.sam import SpectralAngleMapper
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure
from torchmetrics.image.tv import TotalVariation
from torchmetrics.image.uqi import UniversalImageQualityIndex
from torchmetrics.utilities.prints import _deprecated_root_import_class
class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionlessSynthesis):
    """Deprecated root-level alias of ``ErrorRelativeGlobalDimensionlessSynthesis``.

    On construction it calls ``_deprecated_root_import_class`` to flag the deprecated import
    path. It then hands every argument, unchanged, to the parent class.

    >>> from torch import rand
    >>> preds = rand([16, 1, 16, 16])
    >>> target = preds * 0.75
    >>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis()
    >>> ergas(preds, target).round()
    tensor(10.)

    """

    def __init__(
        self,
        ratio: float = 4,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import first, before the parent constructor runs.
        _deprecated_root_import_class("ErrorRelativeGlobalDimensionlessSynthesis", "image")
        super().__init__(
            ratio=ratio,
            reduction=reduction,
            **kwargs,
        )
class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarityIndexMeasure):
    """Deprecated root-level alias of ``MultiScaleStructuralSimilarityIndexMeasure``.

    Constructing it calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    Every argument, including extra ``kwargs``, is then forwarded by keyword to the parent class.

    >>> from torch import rand
    >>> preds = rand([3, 3, 256, 256])
    >>> target = preds * 0.75
    >>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
    >>> ms_ssim(preds, target)
    tensor(0.9628)

    """

    def __init__(
        self,
        gaussian_kernel: bool = True,
        kernel_size: Union[int, Sequence[int]] = 11,
        sigma: Union[float, Sequence[float]] = 1.5,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        data_range: Optional[Union[float, tuple[float, float]]] = None,
        k1: float = 0.01,
        k2: float = 0.03,
        betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
        normalize: Literal["relu", "simple", None] = "relu",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import before delegating construction to the parent.
        _deprecated_root_import_class("MultiScaleStructuralSimilarityIndexMeasure", "image")
        super().__init__(
            gaussian_kernel=gaussian_kernel, kernel_size=kernel_size, sigma=sigma,
            reduction=reduction, data_range=data_range,
            k1=k1, k2=k2, betas=betas, normalize=normalize,
            **kwargs,
        )
class _PeakSignalNoiseRatio(PeakSignalNoiseRatio):
    """Deprecated root-level alias of ``PeakSignalNoiseRatio``.

    Instantiation calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    It then passes ``data_range``, ``base``, ``reduction``, ``dim`` and any extra keyword
    arguments straight to the parent class.

    >>> from torch import tensor
    >>> psnr = _PeakSignalNoiseRatio()
    >>> preds = tensor([[0.0, 1.0], [2.0, 3.0]])
    >>> target = tensor([[3.0, 2.0], [1.0, 0.0]])
    >>> psnr(preds, target)
    tensor(2.5527)

    """

    def __init__(
        self,
        data_range: Optional[Union[float, tuple[float, float]]] = None,
        base: float = 10.0,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        dim: Optional[Union[int, tuple[int, ...]]] = None,
        **kwargs: Any,
    ) -> None:
        # The deprecation notice comes first; construction itself is the parent's job.
        _deprecated_root_import_class("PeakSignalNoiseRatio", "image")
        super().__init__(
            data_range=data_range,
            base=base,
            reduction=reduction,
            dim=dim,
            **kwargs,
        )
class _RelativeAverageSpectralError(RelativeAverageSpectralError):
    """Deprecated root-level alias of ``RelativeAverageSpectralError``.

    Construction calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    It then forwards every argument unchanged to the parent class.

    Args:
        window_size: Forwarded to the parent as ``window_size``.
        kwargs: Additional keyword arguments, forwarded to the parent unchanged.

    >>> from torch import rand
    >>> preds = rand(4, 3, 16, 16)
    >>> target = rand(4, 3, 16, 16)
    >>> rase = _RelativeAverageSpectralError()
    >>> rase(preds, target)
    tensor(5326.40...)

    """

    def __init__(
        self,
        window_size: int = 8,
        **kwargs: Any,
    ) -> None:
        # ``**kwargs: Any`` types each keyword *value* (PEP 484). These values are passed
        # through untouched, so they need not be dicts. This matches the sibling wrappers.
        _deprecated_root_import_class("RelativeAverageSpectralError", "image")
        super().__init__(window_size=window_size, **kwargs)
class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWindow):
    """Deprecated root-level alias of ``RootMeanSquaredErrorUsingSlidingWindow``.

    Construction calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    It then forwards every argument unchanged to the parent class. The example below
    instantiates this wrapper itself, not the parent, so the doctest covers the deprecated
    alias it documents.

    Args:
        window_size: Forwarded to the parent as ``window_size``.
        kwargs: Additional keyword arguments, forwarded to the parent unchanged.

    >>> from torch import rand
    >>> preds = rand(4, 3, 16, 16)
    >>> target = rand(4, 3, 16, 16)
    >>> rmse_sw = _RootMeanSquaredErrorUsingSlidingWindow()
    >>> rmse_sw(preds, target)
    tensor(0.4158)

    """

    def __init__(
        self,
        window_size: int = 8,
        **kwargs: Any,
    ) -> None:
        # ``**kwargs: Any`` types each keyword *value* (PEP 484). Values are passed through
        # untouched, so ``dict[str, Any]`` per value would be wrong.
        _deprecated_root_import_class("RootMeanSquaredErrorUsingSlidingWindow", "image")
        super().__init__(window_size=window_size, **kwargs)
class _SpectralAngleMapper(SpectralAngleMapper):
    """Deprecated root-level alias of ``SpectralAngleMapper``.

    Instantiation calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    It then passes ``reduction`` and any extra keyword arguments on to the parent class.

    >>> from torch import rand
    >>> preds = rand([16, 3, 16, 16])
    >>> target = rand([16, 3, 16, 16])
    >>> sam = _SpectralAngleMapper()
    >>> sam(preds, target)
    tensor(0.5914)

    """

    def __init__(
        self, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any
    ) -> None:
        # Notice first, then full delegation to the parent constructor.
        _deprecated_root_import_class("SpectralAngleMapper", "image")
        super().__init__(
            reduction=reduction,
            **kwargs,
        )
class _SpectralDistortionIndex(SpectralDistortionIndex):
    """Deprecated root-level alias of ``SpectralDistortionIndex``.

    Constructing it calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    It then forwards ``p``, ``reduction`` and any extra keyword arguments to the parent class.

    >>> from torch import rand
    >>> preds = rand([16, 3, 16, 16])
    >>> target = rand([16, 3, 16, 16])
    >>> sdi = _SpectralDistortionIndex()
    >>> sdi(preds, target)
    tensor(0.0234)

    """

    def __init__(
        self,
        p: int = 1,
        reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Emit the deprecated-import notice before the parent sets itself up.
        _deprecated_root_import_class("SpectralDistortionIndex", "image")
        super().__init__(
            p=p,
            reduction=reduction,
            **kwargs,
        )
class _StructuralSimilarityIndexMeasure(StructuralSimilarityIndexMeasure):
    """Deprecated root-level alias of ``StructuralSimilarityIndexMeasure``.

    Construction calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    Every argument is then forwarded by keyword to the parent class. Note that ``sigma``
    precedes ``kernel_size`` positionally, matching this signature.

    >>> import torch
    >>> preds = torch.rand([3, 3, 256, 256])
    >>> target = preds * 0.75
    >>> ssim = _StructuralSimilarityIndexMeasure(data_range=1.0)
    >>> ssim(preds, target)
    tensor(0.9219)

    """

    def __init__(
        self,
        gaussian_kernel: bool = True,
        sigma: Union[float, Sequence[float]] = 1.5,
        kernel_size: Union[int, Sequence[int]] = 11,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        data_range: Optional[Union[float, tuple[float, float]]] = None,
        k1: float = 0.01,
        k2: float = 0.03,
        return_full_image: bool = False,
        return_contrast_sensitivity: bool = False,
        **kwargs: Any,
    ) -> None:
        # Notice first; the parent constructor does all of the real work.
        _deprecated_root_import_class("StructuralSimilarityIndexMeasure", "image")
        super().__init__(
            gaussian_kernel=gaussian_kernel, sigma=sigma, kernel_size=kernel_size,
            reduction=reduction, data_range=data_range, k1=k1, k2=k2,
            return_full_image=return_full_image,
            return_contrast_sensitivity=return_contrast_sensitivity,
            **kwargs,
        )
class _TotalVariation(TotalVariation):
    """Deprecated root-level alias of ``TotalVariation``.

    Instantiation calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    It then forwards ``reduction`` (default ``"sum"``) and any extra keyword arguments to the
    parent class.

    >>> from torch import rand
    >>> tv = _TotalVariation()
    >>> img = rand(5, 3, 28, 28)
    >>> tv(img)
    tensor(7546.8018)

    """

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none", None] = "sum",
        **kwargs: Any,
    ) -> None:
        # Flag the deprecated import, then delegate to the parent constructor.
        _deprecated_root_import_class("TotalVariation", "image")
        super().__init__(
            reduction=reduction,
            **kwargs,
        )
class _UniversalImageQualityIndex(UniversalImageQualityIndex):
    """Deprecated root-level alias of ``UniversalImageQualityIndex``.

    Constructing it calls ``_deprecated_root_import_class`` to flag the deprecated import path.
    It then forwards ``kernel_size``, ``sigma``, ``reduction`` and any extra keyword arguments
    to the parent class.

    >>> import torch
    >>> preds = torch.rand([16, 1, 16, 16])
    >>> target = preds * 0.75
    >>> uqi = _UniversalImageQualityIndex()
    >>> uqi(preds, target)
    tensor(0.9216)

    """

    def __init__(
        self,
        kernel_size: Sequence[int] = (11, 11),
        sigma: Sequence[float] = (1.5, 1.5),
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # The deprecation notice precedes parent construction.
        _deprecated_root_import_class("UniversalImageQualityIndex", "image")
        super().__init__(
            kernel_size=kernel_size,
            sigma=sigma,
            reduction=reduction,
            **kwargs,
        )