jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import Tensor, tensor
def _compute_bef(x: Tensor, block_size: int = 8) -> Tensor:
"""Compute block effect.
Args:
x: input image
block_size: integer indication the block size
Returns:
Computed block effect
Raises:
ValueError:
If the image is not a grayscale image
"""
(
_,
channels,
height,
width,
) = x.shape
if channels > 1:
raise ValueError(f"`psnrb` metric expects grayscale images, but got images with {channels} channels.")
h = torch.arange(width - 1)
h_b = torch.tensor(range(block_size - 1, width - 1, block_size))
h_bc = torch.tensor(list(set(h.tolist()).symmetric_difference(h_b.tolist())))
v = torch.arange(height - 1)
v_b = torch.tensor(range(block_size - 1, height - 1, block_size))
v_bc = torch.tensor(list(set(v.tolist()).symmetric_difference(v_b.tolist())))
d_b = (x[:, :, :, h_b] - x[:, :, :, h_b + 1]).pow(2.0).sum()
d_bc = (x[:, :, :, h_bc] - x[:, :, :, h_bc + 1]).pow(2.0).sum()
d_b += (x[:, :, v_b, :] - x[:, :, v_b + 1, :]).pow(2.0).sum()
d_bc += (x[:, :, v_bc, :] - x[:, :, v_bc + 1, :]).pow(2.0).sum()
n_hb = height * (width / block_size) - 1
n_hbc = (height * (width - 1)) - n_hb
n_vb = width * (height / block_size) - 1
n_vbc = (width * (height - 1)) - n_vb
d_b /= n_hb + n_vb
d_bc /= n_hbc + n_vbc
t = math.log2(block_size) / math.log2(min(height, width)) if d_b > d_bc else 0
return t * (d_b - d_bc)
def _psnrb_compute(
sum_squared_error: Tensor,
bef: Tensor,
num_obs: Tensor,
data_range: Tensor,
) -> Tensor:
"""Computes peak signal-to-noise ratio.
Args:
sum_squared_error: Sum of square of errors over all observations
bef: block effect
num_obs: Number of predictions or observations
data_range: the range of the data. If None, it is determined from the data (max - min).
"""
sum_squared_error = sum_squared_error / num_obs + bef
if data_range > 2:
return 10 * torch.log10(data_range**2 / sum_squared_error)
return 10 * torch.log10(1.0 / sum_squared_error)
def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> tuple[Tensor, Tensor, Tensor]:
    """Gather the per-batch statistics that the PSNR-B computation consumes.

    Args:
        preds: Predicted tensor; its blocking effect factor is the one measured
        target: Ground truth tensor
        block_size: Integer indicating the block size

    Returns:
        A ``(sum_squared_error, bef, num_obs)`` tuple: the summed squared difference between ``preds`` and
        ``target``, the blocking effect factor of ``preds``, and the element count of ``target`` as a
        tensor on ``target``'s device.

    """
    squared_residuals = (preds - target).pow(2)
    element_count = tensor(target.numel(), device=target.device)
    blocking_effect = _compute_bef(preds, block_size=block_size)
    return squared_residuals.sum(), blocking_effect, element_count
def peak_signal_noise_ratio_with_blocked_effect(
    preds: Tensor,
    target: Tensor,
    block_size: int = 8,
) -> Tensor:
    r"""Compute the `Peak Signal to Noise Ratio With Blocked Effect` (PSNRB) metric.

    .. math::
        \text{PSNRB}(I, J) = 10 \cdot \log_{10} \left(\frac{L^2}{\text{MSE}(I, J) + \text{BEF}(I)}\right)

    Where :math:`\text{MSE}` denotes the `mean-squared-error`_ function, :math:`I` is the prediction,
    :math:`J` is the target and :math:`\text{BEF}(I)` is the blocking effect factor of the prediction.
    :math:`L` is the data range of the target, :math:`\max(J) - \min(J)`, when that range exceeds 2;
    otherwise :math:`L = 1`.

    Args:
        preds: estimated signal, a single-channel batch unpacked as ``(batch, channels, height, width)``
        target: ground truth signal
        block_size: integer indicating the block size

    Return:
        Tensor with PSNRB score (in dB)

    Raises:
        ValueError:
            If ``preds`` has more than one channel

    Example:
        >>> from torch import rand
        >>> from torchmetrics.functional.image import peak_signal_noise_ratio_with_blocked_effect
        >>> preds = rand(1, 1, 28, 28)
        >>> target = rand(1, 1, 28, 28)
        >>> peak_signal_noise_ratio_with_blocked_effect(preds, target)
        tensor(7.8402)

    """
    # The peak value is derived from the target before any statistics are accumulated.
    peak_to_peak = target.max() - target.min()
    sse, bef, n_obs = _psnrb_update(preds, target, block_size=block_size)
    return _psnrb_compute(sum_squared_error=sse, bef=bef, num_obs=n_obs, data_range=peak_to_peak)