# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Content copied from
# https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
# and
# https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py
# and with adjustments from
# https://github.com/richzhang/PerceptualSimilarity/pull/114/files
# due to package no longer being maintained
# Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
# All rights reserved.
# License under BSD 2-clause
import inspect
import os
from typing import List, NamedTuple, Optional, Union

import torch
from torch import Tensor, nn
from typing_extensions import Literal

from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE

# Maps the torchvision model constructor name to the name of its weights enum.
_weight_map = {
    "squeezenet1_1": "SqueezeNet1_1_Weights",
    "alexnet": "AlexNet_Weights",
    "vgg16": "VGG16_Weights",
}

if not _TORCHVISION_AVAILABLE:
    __doctest_skip__ = ["learned_perceptual_image_patch_similarity"]


def _get_net(net: str, pretrained: bool) -> nn.modules.container.Sequential:
    """Get the ``features`` part of a torchvision network.

    Args:
        net: Name of network, must be a key of ``_weight_map``
        pretrained: If pretrained (``IMAGENET1K_V1``) weights should be used

    Raises:
        ModuleNotFoundError: If torchvision is not installed.

    """
    # Check availability *before* importing so the user gets an actionable message
    # instead of a bare import failure (or an unbound local further down).
    if not _TORCHVISION_AVAILABLE:
        raise ModuleNotFoundError(
            "LPIPS requires that `torchvision` is installed. Install it with `pip install torchvision`."
        )
    from torchvision import models as tv

    weights = getattr(tv, _weight_map[net]).IMAGENET1K_V1 if pretrained else None
    return getattr(tv, net)(weights=weights).features


def _make_slice(features: nn.Sequential, indices: range) -> nn.Sequential:
    """Build a sub-network from ``features``, naming each layer by its original index."""
    block = nn.Sequential()
    for idx in indices:
        block.add_module(str(idx), features[idx])
    return block


# The output containers are defined once at module level rather than inside ``forward``,
# so that no new class object has to be created on every forward pass.
class _SqueezeOutput(NamedTuple):
    """Activations returned by :class:`SqueezeNet`, one per slice."""

    relu1: Tensor
    relu2: Tensor
    relu3: Tensor
    relu4: Tensor
    relu5: Tensor
    relu6: Tensor
    relu7: Tensor


class _AlexnetOutputs(NamedTuple):
    """Activations returned by :class:`Alexnet`, one per slice."""

    relu1: Tensor
    relu2: Tensor
    relu3: Tensor
    relu4: Tensor
    relu5: Tensor


class _VGGOutputs(NamedTuple):
    """Activations returned by :class:`Vgg16`, one per slice."""

    relu1_2: Tensor
    relu2_2: Tensor
    relu3_3: Tensor
    relu4_3: Tensor
    relu5_3: Tensor


class SqueezeNet(torch.nn.Module):
    """SqueezeNet implementation."""

    def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
        """Split the torchvision SqueezeNet 1.1 feature extractor into seven consecutive slices.

        Args:
            requires_grad: If ``False`` all parameters are frozen
            pretrained: If pretrained weights should be used

        """
        super().__init__()
        pretrained_features = _get_net("squeezenet1_1", pretrained)

        self.N_slices = 7
        feature_ranges = [range(2), range(2, 5), range(5, 8), range(8, 10), range(10, 11), range(11, 12), range(12, 13)]
        self.slices = nn.ModuleList(_make_slice(pretrained_features, rng) for rng in feature_ranges)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x: Tensor) -> NamedTuple:
        """Run the input through all slices and return the output of every slice."""
        relus = []
        for slice_ in self.slices:
            x = slice_(x)
            relus.append(x)
        return _SqueezeOutput(*relus)


class Alexnet(torch.nn.Module):
    """Alexnet implementation."""

    def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
        """Split the torchvision AlexNet feature extractor into five consecutive slices.

        Args:
            requires_grad: If ``False`` all parameters are frozen
            pretrained: If pretrained weights should be used

        """
        super().__init__()
        alexnet_pretrained_features = _get_net("alexnet", pretrained)

        # Attribute names are kept as ``slice1`` ... ``slice5`` so state-dict keys stay unchanged.
        self.slice1 = _make_slice(alexnet_pretrained_features, range(2))
        self.slice2 = _make_slice(alexnet_pretrained_features, range(2, 5))
        self.slice3 = _make_slice(alexnet_pretrained_features, range(5, 8))
        self.slice4 = _make_slice(alexnet_pretrained_features, range(8, 10))
        self.slice5 = _make_slice(alexnet_pretrained_features, range(10, 12))
        self.N_slices = 5

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x: Tensor) -> NamedTuple:
        """Run the input through all slices and return the output of every slice."""
        h_relu1 = self.slice1(x)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        return _AlexnetOutputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)


class Vgg16(torch.nn.Module):
    """Vgg16 implementation."""

    def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
        """Split the torchvision VGG16 feature extractor into five consecutive slices.

        Args:
            requires_grad: If ``False`` all parameters are frozen
            pretrained: If pretrained weights should be used

        """
        super().__init__()
        vgg_pretrained_features = _get_net("vgg16", pretrained)

        # Attribute names are kept as ``slice1`` ... ``slice5`` so state-dict keys stay unchanged.
        self.slice1 = _make_slice(vgg_pretrained_features, range(4))
        self.slice2 = _make_slice(vgg_pretrained_features, range(4, 9))
        self.slice3 = _make_slice(vgg_pretrained_features, range(9, 16))
        self.slice4 = _make_slice(vgg_pretrained_features, range(16, 23))
        self.slice5 = _make_slice(vgg_pretrained_features, range(23, 30))
        self.N_slices = 5

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x: Tensor) -> NamedTuple:
        """Run the input through all slices and return the output of every slice."""
        h_relu1_2 = self.slice1(x)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        return _VGGOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)


def _spatial_average(in_tens: Tensor, keep_dim: bool = True) -> Tensor:
    """Spatial averaging over height and width of images."""
    return in_tens.mean([2, 3], keepdim=keep_dim)


def _upsample(in_tens: Tensor, out_hw: tuple[int, ...] = (64, 64)) -> Tensor:
    """Upsample input with bilinear interpolation."""
    return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens)


def _normalize_tensor(in_feat: Tensor, eps: float = 1e-8) -> Tensor:
    """Normalize input tensor to unit L2 norm along the channel dimension (``dim=1``)."""
    norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True))
    return in_feat / norm_factor


def _resize_tensor(x: Tensor, size: int = 64) -> Tensor:
    """https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/sample_similarity_lpips.py#L127C22-L132."""
    # Downscaling in both dimensions uses area interpolation, everything else bilinear.
    if x.shape[-1] > size and x.shape[-2] > size:
        return torch.nn.functional.interpolate(x, (size, size), mode="area")
    return torch.nn.functional.interpolate(x, (size, size), mode="bilinear", align_corners=False)


class ScalingLayer(nn.Module):
    """Scaling layer that shifts and scales each of the three input channels with fixed constants."""

    shift: Tensor
    scale: Tensor

    def __init__(self) -> None:
        super().__init__()
        # Non-persistent buffers: they follow ``.to(...)`` but are not part of the state dict.
        self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None], persistent=False)
        self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None], persistent=False)

    def forward(self, inp: Tensor) -> Tensor:
        """Process input."""
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv."""

    def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) -> None:
        """Build the layer.

        Args:
            chn_in: Number of input channels
            chn_out: Number of output channels
            use_dropout: If a dropout layer should be placed in front of the convolution

        """
        super().__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers += [
            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),  # type: ignore[list-item]
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Process input."""
        return self.model(x)


class _LPIPS(nn.Module):
    def __init__(
        self,
        pretrained: bool = True,
        net: Literal["alex", "vgg", "squeeze"] = "alex",
        spatial: bool = False,
        pnet_rand: bool = False,
        pnet_tune: bool = False,
        use_dropout: bool = True,
        model_path: Optional[str] = None,
        eval_mode: bool = True,
        resize: Optional[int] = None,
    ) -> None:
        """Initializes a perceptual loss torch.nn.Module.

        Args:
            pretrained: This flag controls if the linear layers should be the pretrained version or random
            net: Indicate backbone to use, choose between ['alex','vgg','squeeze']
            spatial: If ``True`` the per-layer distance maps are upsampled to the input height and width
                instead of being averaged over the spatial dimensions
            pnet_rand: If backbone should be random or use imagenet pre-trained weights
            pnet_tune: If backprop should be enabled for both backbone and linear layers
            use_dropout: If dropout layers should be added
            model_path: Model path to load pretrained models from
            eval_mode: If network should be in evaluation mode
            resize: If input should be resized to this size

        Raises:
            ValueError: If ``net`` is not a supported backbone.

        """
        super().__init__()

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.resize = resize
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ["vgg", "vgg16"]:
            net_type = Vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == "alex":
            net_type = Alexnet  # type: ignore[assignment]
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == "squeeze":
            net_type = SqueezeNet  # type: ignore[assignment]
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        else:
            # Fail early with a clear message instead of an unbound-local error further down.
            raise ValueError(f"Argument `net` must be one of 'alex', 'vgg', 'vgg16' or 'squeeze', but got {net}")
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        if self.pnet_type == "squeeze":  # 7 layers for squeezenet
            self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
            self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
            self.lins += [self.lin5, self.lin6]
        self.lins = nn.ModuleList(self.lins)  # type: ignore[assignment]

        if pretrained:
            if model_path is None:
                # Default to the weights stored in ``lpips_models/`` next to this file.
                model_path = os.path.abspath(
                    os.path.join(inspect.getfile(self.__init__), "..", f"lpips_models/{net}.pth")  # type: ignore[misc]
                )

            # NOTE(security): ``torch.load`` unpickles the file; only pass a ``model_path`` from a trusted source.
            # ``strict=False`` because the checkpoint is not required to match every key of this module.
            self.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)

        if eval_mode:
            self.eval()

        if not self.pnet_tune:
            for param in self.parameters():
                param.requires_grad = False

    def forward(
        self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False
    ) -> Union[Tensor, tuple[Tensor, List[Tensor]]]:
        """Compute the perceptual distance between two batches of images.

        Args:
            in0: first batch of images
            in1: second batch of images
            retperlayer: If ``True`` the list of per-layer distances is returned together with their sum
            normalize: Set to ``True`` if the input is in [0,1] so it can be rescaled to [-1,+1]

        Returns:
            The sum of the per-layer distances, and additionally the list of per-layer distances if
            ``retperlayer=True``.

        """
        if normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # normalize input
        in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)

        # resize input if needed
        if self.resize is not None:
            in0_input = _resize_tensor(in0_input, size=self.resize)
            in1_input = _resize_tensor(in1_input, size=self.resize)

        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

        # squared difference of the channel-normalized activations, one entry per backbone slice
        diffs = [(_normalize_tensor(outs0[kk]) - _normalize_tensor(outs1[kk])) ** 2 for kk in range(self.L)]

        if self.spatial:
            out_hw = tuple(in0.shape[2:])
            res = [_upsample(self.lins[kk](diffs[kk]), out_hw=out_hw) for kk in range(self.L)]
        else:
            res = [_spatial_average(self.lins[kk](diffs[kk]), keep_dim=True) for kk in range(self.L)]

        val: Tensor = sum(res)  # type: ignore[assignment]
        if retperlayer:
            return (val, res)
        return val


class _NoTrainLpips(_LPIPS):
    """Wrapper to make sure LPIPS never leaves evaluation mode."""

    def train(self, mode: bool = True) -> "_NoTrainLpips":
        """Force network to always be in evaluation mode.

        ``mode`` is accepted (with the same default as ``nn.Module.train``) but ignored.

        """
        return super().train(False)


def _valid_img(img: Tensor, normalize: bool) -> bool:
    """Check that input is a valid image to the network.

    A valid image has shape ``[N, 3, H, W]`` and values in ``[0, 1]`` if ``normalize=True``, else in ``[-1, 1]``.

    """
    # Cheap shape checks first, so the reductions below only run on correctly shaped input.
    if img.ndim != 4 or img.shape[1] != 3:
        return False
    lower_bound = 0.0 if normalize else -1.0
    # Both bounds are checked in either mode; the upper bound is 1 in both cases.
    return bool(img.min() >= lower_bound and img.max() <= 1.0)


def _lpips_update(img1: Tensor, img2: Tensor, net: nn.Module, normalize: bool) -> tuple[Tensor, Union[int, Tensor]]:
    """Validate the input and compute the per-sample LPIPS scores.

    Args:
        img1: first set of images
        img2: second set of images
        net: module that computes the scores, called as ``net(img1, img2, normalize=normalize)``
        normalize: if the input is expected to be in the ``[0,1]`` range instead of ``[-1,1]``

    Returns:
        The squeezed output of ``net`` and the batch size ``img1.shape[0]``.

    Raises:
        ValueError: If either input does not have shape ``[N, 3, H, W]`` or has values outside the expected range.

    """
    if not (_valid_img(img1, normalize) and _valid_img(img2, normalize)):
        raise ValueError(
            "Expected both input arguments to be normalized tensors with shape [N, 3, H, W]."
            f" Got input with shape {img1.shape} and {img2.shape} and values in range"
            f" {[img1.min(), img1.max()]} and {[img2.min(), img2.max()]} when all values are"
            f" expected to be in the {[0, 1] if normalize else [-1, 1]} range."
        )
    loss = net(img1, img2, normalize=normalize).squeeze()
    return loss, img1.shape[0]


def _lpips_compute(sum_scores: Tensor, total: Union[Tensor, int], reduction: Literal["sum", "mean"] = "mean") -> Tensor:
    """Reduce the summed scores: divide by ``total`` for ``'mean'``, otherwise return the sum unchanged."""
    return sum_scores / total if reduction == "mean" else sum_scores


def learned_perceptual_image_patch_similarity(
    img1: Tensor,
    img2: Tensor,
    net_type: Literal["alex", "vgg", "squeeze"] = "alex",
    reduction: Literal["sum", "mean"] = "mean",
    normalize: bool = False,
) -> Tensor:
    """The Learned Perceptual Image Patch Similarity (`LPIPS_`) calculates perceptual similarity between two images.

    LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined
    network. This measure has been shown to match human perception well. A low LPIPS score means that image patches
    are perceptually similar.

    Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the
    chosen backbone (see `net_type` arg).

    Args:
        img1: first set of images
        img2: second set of images
        net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'`
        reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'` or `'mean'`.
        normalize: by default this is ``False`` meaning that the input is expected to be in the [-1,1] range. If set
            to ``True`` will instead expect input to be in the ``[0,1]`` range.

    Raises:
        ValueError: If the inputs do not have shape ``(N, 3, H, W)`` or contain values outside the expected range.

    Example:
        >>> from torch import rand
        >>> from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
        >>> img1 = (rand(10, 3, 100, 100) * 2) - 1
        >>> img2 = (rand(10, 3, 100, 100) * 2) - 1
        >>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze')
        tensor(0.1005)

    """
    net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype)
    loss, total = _lpips_update(img1, img2, net, normalize)
    return _lpips_compute(loss.sum(), total, reduction)