# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Union

import torch
from torch import Tensor, tensor
from torch.nn.functional import conv2d, pad
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> tuple[Tensor, Tensor, Tensor]:
    """Validate the inputs and prepare the tensors needed to compute Spatial Correlation Coefficient.

    Args:
        preds: Predicted tensor of shape ``(B, C, H, W)`` or ``(B, H, W)``
        target: Ground truth tensor with the same shape as ``preds``
        hp_filter: 2D high-pass filter tensor
        window_size: Local window size integer

    Return:
        Tuple of (preds, target, hp_filter) where ``preds`` and ``target`` are float32 tensors with
        an explicit channel dimension and ``hp_filter`` has two leading singleton dimensions added
        and lives on the same device/dtype as ``preds``.

    Raises:
        ValueError:
            If ``preds`` and ``target`` are neither 3- nor 4-dimensional
        ValueError:
            If ``window_size`` is not a positive integer
        ValueError:
            If ``window_size`` is greater than the height or width of the image

    Note:
        Shape agreement between ``preds`` and ``target`` is delegated to ``_check_same_shape``,
        which raises its own exception on mismatch.

    """
    # NOTE(review): casting ``target`` to the dtype of ``preds`` before both are promoted to float32 can
    # lose information (e.g. integer ``preds`` with float ``target``). Kept as-is to preserve behaviour.
    if preds.dtype != target.dtype:
        target = target.to(preds.dtype)
    _check_same_shape(preds, target)
    if preds.ndim not in (3, 4):
        raise ValueError(
            "Expected `preds` and `target` to have batch of colored images with BxCxHxW shape"
            " or batch of grayscale images of BxHxW shape."
            f" Got preds: {preds.shape} and target: {target.shape}."
        )

    # Grayscale batches get an explicit channel axis so the rest of the pipeline can assume BxCxHxW.
    if preds.ndim == 3:
        preds = preds.unsqueeze(1)
        target = target.unsqueeze(1)

    if not window_size > 0:
        raise ValueError(f"Expected `window_size` to be a positive integer. Got {window_size}.")

    if window_size > preds.size(2) or window_size > preds.size(3):
        raise ValueError(
            "Expected `window_size` to be less than or equal to the size of the image."
            f" Got window_size: {window_size} and image size: {preds.size(2)}x{preds.size(3)}."
        )

    preds = preds.to(torch.float32)
    target = target.to(torch.float32)
    # Add the (out_channels, in_channels) axes expected by ``conv2d`` weights.
    hp_filter = hp_filter[None, None, :].to(dtype=preds.dtype, device=preds.device)
    return preds, target, hp_filter


def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, tuple[int, ...]]) -> Tensor:
    """Applies symmetric padding to the 2D image tensor input using ``reflect`` mode (d c b a | a b c d | d c b a).

    Args:
        input_img: 4D tensor; the last two dimensions are padded (dim 3 first, then dim 2).
        pad: Either a single integer applied to all four sides, or a ``(left, right, top, bottom)`` tuple.

    Return:
        The padded tensor.

    Raises:
        ValueError:
            If ``pad`` is a sequence whose length is not 4

    """
    if isinstance(pad, int):
        pad = (pad, pad, pad, pad)
    if len(pad) != 4:
        raise ValueError(f"Expected padding to have length 4, but got {len(pad)}")

    # The trailing slices use an explicit, clamped start index instead of ``-pad[i]:``. With a negative
    # index a pad of 0 becomes ``-0:`` which selects the *whole* axis rather than nothing.
    width = input_img.size(3)
    left_pad = input_img[:, :, :, : pad[0]].flip(dims=[3])
    right_pad = input_img[:, :, :, max(width - pad[1], 0) :].flip(dims=[3])
    padded = torch.cat([left_pad, input_img, right_pad], dim=3)

    height = padded.size(2)
    top_pad = padded[:, :, : pad[2], :].flip(dims=[2])
    bottom_pad = padded[:, :, max(height - pad[3], 0) :, :].flip(dims=[2])
    return torch.cat([top_pad, padded, bottom_pad], dim=2)


def _signal_convolve_2d(input_img: Tensor, kernel: Tensor) -> Tensor:
    """Applies 2D signal convolution to the input tensor with the given kernel.

    The input is symmetrically padded so that the output keeps the spatial size of ``input_img``, and the
    kernel is flipped along both spatial axes before being handed to ``conv2d``.

    Args:
        input_img: 4D input tensor.
        kernel: 4D kernel tensor; dims 2 and 3 are its height and width.

    Return:
        The convolved tensor.

    """
    # For even-sized kernels the extra padding element goes to the right / bottom.
    left_padding = math.floor((kernel.size(3) - 1) / 2)
    right_padding = math.ceil((kernel.size(3) - 1) / 2)
    top_padding = math.floor((kernel.size(2) - 1) / 2)
    bottom_padding = math.ceil((kernel.size(2) - 1) / 2)

    padded = _symmetric_reflect_pad_2d(input_img, pad=(left_padding, right_padding, top_padding, bottom_padding))
    # ``conv2d`` does not flip its weight, so flip the kernel here.
    kernel = kernel.flip([2, 3])
    return conv2d(padded, kernel, stride=1, padding=0)


def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor:
    """Applies 2-D Laplace filter to the input tensor with the given high pass filter.

    The filter response is scaled by a factor of 2.

    """
    return _signal_convolve_2d(input_img, kernel) * 2.0


def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """Computes local variance and covariance of the input tensors.

    Args:
        preds: 4D tensor with a single channel.
        target: 4D tensor with the same shape as ``preds``.
        window: 4D averaging window used as ``conv2d`` weight.

    Return:
        Tuple of (variance of ``preds``, variance of ``target``, covariance of ``target`` and ``preds``),
        each computed per local window.

    """
    # This code is inspired by
    # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.

    # Zero-pad so the windowed statistics keep the spatial size of the input. For even-sized windows the
    # extra padding element goes to the left / top.
    left_padding = math.ceil((window.size(3) - 1) / 2)
    right_padding = math.floor((window.size(3) - 1) / 2)
    top_padding = math.ceil((window.size(2) - 1) / 2)
    bottom_padding = math.floor((window.size(2) - 1) / 2)
    padding = (left_padding, right_padding, top_padding, bottom_padding)

    preds = pad(preds, padding)
    target = pad(target, padding)

    preds_mean = conv2d(preds, window, stride=1, padding=0)
    target_mean = conv2d(target, window, stride=1, padding=0)

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y], with E taken over the local window.
    preds_var = conv2d(preds**2, window, stride=1, padding=0) - preds_mean**2
    target_var = conv2d(target**2, window, stride=1, padding=0) - target_mean**2
    target_preds_cov = conv2d(target * preds, window, stride=1, padding=0) - target_mean * preds_mean

    return preds_var, target_var, target_preds_cov


def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tensor:
    """Computes per channel Spatial Correlation Coefficient.

    Args:
        preds: estimated image of Bx1xHxW shape.
        target: ground truth image of Bx1xHxW shape.
        hp_filter: 2D high-pass filter.
        window_size: size of window for local mean calculation.

    Return:
        Tensor with Spatial Correlation Coefficient score

    """
    dtype = preds.dtype
    device = preds.device

    # This code is inspired by
    # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.

    # Uniform averaging window.
    window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device) / (window_size**2)

    preds_hp = _hp_2d_laplacian(preds, hp_filter)
    target_hp = _hp_2d_laplacian(target, hp_filter)

    preds_var, target_var, target_preds_cov = _local_variance_covariance(preds_hp, target_hp, window)

    # E[x^2] - E[x]^2 can come out slightly negative; clip so the square root is defined.
    preds_var = torch.clamp(preds_var, min=0)
    target_var = torch.clamp(target_var, min=0)

    den = torch.sqrt(target_var) * torch.sqrt(preds_var)
    # Where the denominator vanishes the coefficient is defined as 0; divide by 1 there to avoid inf/nan.
    zero_den = den == 0
    safe_den = torch.where(zero_den, torch.ones_like(den), den)
    scc = target_preds_cov / safe_den
    return torch.where(zero_den, torch.zeros_like(scc), scc)


def spatial_correlation_coefficient(
    preds: Tensor,
    target: Tensor,
    hp_filter: Optional[Tensor] = None,
    window_size: int = 8,
    reduction: Optional[Literal["mean", "none", None]] = "mean",
) -> Tensor:
    """Compute Spatial Correlation Coefficient (SCC_).

    Args:
        preds: predicted images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
        target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
        hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]])
        window_size: Local window size integer. default: 8,
        reduction: Reduction method for output tensor. If ``None`` or ``"none"``, returns a tensor with the per sample
            results. default: ``"mean"``.

    Return:
        Tensor with scc score

    Raises:
        ValueError:
            If ``reduction`` is not one of ``"mean"``, ``"none"`` or ``None``

    Example:
        >>> from torch import randn
        >>> from torchmetrics.functional.image import spatial_correlation_coefficient as scc
        >>> x = randn(5, 3, 16, 16)
        >>> scc(x, x)
        tensor(1.)
        >>> x = randn(5, 16, 16)
        >>> scc(x, x)
        tensor(1.)
        >>> x = randn(5, 3, 16, 16)
        >>> y = randn(5, 3, 16, 16)
        >>> scc(x, y, reduction="none")
        tensor([0.0223, 0.0256, 0.0616, 0.0159, 0.0170])

    """
    if hp_filter is None:
        hp_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
    if reduction is None:
        reduction = "none"
    if reduction not in ("mean", "none"):
        raise ValueError(f"Expected reduction to be 'mean' or 'none', but got {reduction}")

    preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size)

    # The score is computed independently for each channel; ``i : i + 1`` keeps the channel axis.
    per_channel = [
        _scc_per_channel_compute(preds[:, i : i + 1, :, :], target[:, i : i + 1, :, :], hp_filter, window_size)
        for i in range(preds.size(1))
    ]
    scc_map = torch.cat(per_channel, dim=1)

    if reduction == "none":
        # One score per sample: average over channels and spatial positions.
        return torch.mean(scc_map, dim=[1, 2, 3])
    # Only "mean" remains after the validation above.
    return reduce(scc_map, reduction="elementwise_mean")