from collections.abc import Sequence |
|
from copy import deepcopy |
|
from typing import Any, Optional, Union |
|
|
|
import torch |
|
from torch import Tensor |
|
from torch.nn import Module |
|
from torch.nn.functional import adaptive_avg_pool2d |
|
|
|
from torchmetrics.metric import Metric |
|
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE |
|
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE |
|
|
|
if not _MATPLOTLIB_AVAILABLE: |
|
__doctest_skip__ = ["FrechetInceptionDistance.plot"] |
|
|
|
if _TORCH_FIDELITY_AVAILABLE: |
|
from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3 |
|
from torch_fidelity.helpers import vassert |
|
from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x |
|
else: |
|
|
|
class _FeatureExtractorInceptionV3(Module): |
|
pass |
|
|
|
vassert = None |
|
interpolate_bilinear_2d_like_tensorflow1x = None |
|
|
|
__doctest_skip__ = ["FrechetInceptionDistance", "FrechetInceptionDistance.plot"] |
|
|
|
|
|
class NoTrainInceptionV3(_FeatureExtractorInceptionV3): |
|
"""Module that never leaves evaluation mode.""" |
|
|
|
def __init__( |
|
self, |
|
name: str, |
|
features_list: list[str], |
|
feature_extractor_weights_path: Optional[str] = None, |
|
) -> None: |
|
if not _TORCH_FIDELITY_AVAILABLE: |
|
raise ModuleNotFoundError( |
|
"NoTrainInceptionV3 module requires that `Torch-fidelity` is installed." |
|
" Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`." |
|
) |
|
|
|
super().__init__(name, features_list, feature_extractor_weights_path) |
|
|
|
self.eval() |
|
|
|
def train(self, mode: bool) -> "NoTrainInceptionV3": |
|
"""Force network to always be in evaluation mode.""" |
|
return super().train(False) |
|
|
|
def _torch_fidelity_forward(self, x: Tensor) -> tuple[Tensor, ...]: |
|
"""Forward method of inception net. |
|
|
|
Copy of the forward method from this file: |
|
https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/feature_extractor_inceptionv3.py |
|
with a single line change regarding the casting of `x` in the beginning. |
|
|
|
Corresponding license file (Apache License, Version 2.0): |
|
https://github.com/toshas/torch-fidelity/blob/master/LICENSE.md |
|
|
|
""" |
|
vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8") |
|
features = {} |
|
remaining_features = self.features_list.copy() |
|
|
|
x = x.to(self._dtype) if hasattr(self, "_dtype") else x.to(torch.float) |
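        # resize to the 299x299 input resolution expected by InceptionV3, replicating TF1.x bilinear interpolation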
|
x = interpolate_bilinear_2d_like_tensorflow1x( |
|
x, |
|
size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE), |
|
align_corners=False, |
|
) |
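        # map uint8 pixel values from [0, 255] to approximately [-1, 1]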
|
x = (x - 128) / 128 |
|
|
|
x = self.Conv2d_1a_3x3(x) |
|
x = self.Conv2d_2a_3x3(x) |
|
x = self.Conv2d_2b_3x3(x) |
|
x = self.MaxPool_1(x) |
|
|
|
if "64" in remaining_features: |
|
features["64"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1) |
|
remaining_features.remove("64") |
|
if len(remaining_features) == 0: |
|
return tuple(features[a] for a in self.features_list) |
|
|
|
x = self.Conv2d_3b_1x1(x) |
|
x = self.Conv2d_4a_3x3(x) |
|
x = self.MaxPool_2(x) |
|
|
|
if "192" in remaining_features: |
|
features["192"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1) |
|
remaining_features.remove("192") |
|
if len(remaining_features) == 0: |
|
return tuple(features[a] for a in self.features_list) |
|
|
|
x = self.Mixed_5b(x) |
|
x = self.Mixed_5c(x) |
|
x = self.Mixed_5d(x) |
|
x = self.Mixed_6a(x) |
|
x = self.Mixed_6b(x) |
|
x = self.Mixed_6c(x) |
|
x = self.Mixed_6d(x) |
|
x = self.Mixed_6e(x) |
|
|
|
if "768" in remaining_features: |
|
features["768"] = adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1) |
|
remaining_features.remove("768") |
|
if len(remaining_features) == 0: |
|
return tuple(features[a] for a in self.features_list) |
|
|
|
x = self.Mixed_7a(x) |
|
x = self.Mixed_7b(x) |
|
x = self.Mixed_7c(x) |
|
x = self.AvgPool(x) |
|
x = torch.flatten(x, 1) |
|
|
|
if "2048" in remaining_features: |
|
features["2048"] = x |
|
remaining_features.remove("2048") |
|
if len(remaining_features) == 0: |
|
return tuple(features[a] for a in self.features_list) |
|
|
|
if "logits_unbiased" in remaining_features: |
|
x = x.mm(self.fc.weight.T) |
|
|
|
features["logits_unbiased"] = x |
|
remaining_features.remove("logits_unbiased") |
|
if len(remaining_features) == 0: |
|
return tuple(features[a] for a in self.features_list) |
|
|
|
x = x + self.fc.bias.unsqueeze(0) |
|
else: |
|
x = self.fc(x) |
|
|
|
features["logits"] = x |
|
return tuple(features[a] for a in self.features_list) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
"""Forward pass of neural network with reshaping of output.""" |
|
out = self._torch_fidelity_forward(x) |
|
return out[0].reshape(x.shape[0], -1) |
|
|
|
|
|
def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Tensor: |
|
r"""Compute adjusted version of `Fid Score`_. |
|
|
|
    The Frechet Inception Distance between two multivariate Gaussians X ~ N(mu_1, sigma_1)
    and Y ~ N(mu_2, sigma_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(sigma_1 + sigma_2 - 2*sqrt(sigma_1*sigma_2)).
|
|
|
Args: |
|
mu1: mean of activations calculated on predicted (x) samples |
|
sigma1: covariance matrix over activations calculated on predicted (x) samples |
|
mu2: mean of activations calculated on target (y) samples |
|
sigma2: covariance matrix over activations calculated on target (y) samples |
|
|
|
Returns: |
|
Scalar value of the distance between sets. |
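
    Example:
        A minimal sanity check (illustrative): for two Gaussians with identity covariance whose
        means differ by one in every coordinate, d^2 = 2 + (2 + 2) - 2 * 2 = 2.

        >>> import torch
        >>> mu1, sigma1 = torch.zeros(2), torch.eye(2)
        >>> mu2, sigma2 = torch.ones(2), torch.eye(2)
        >>> _compute_fid(mu1, sigma1, mu2, sigma2)
        tensor(2.)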
|
|
|
""" |
|
a = (mu1 - mu2).square().sum(dim=-1) |
|
b = sigma1.trace() + sigma2.trace() |
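    # Tr(sqrt(sigma1 @ sigma2)) computed from the eigenvalues of the product, avoiding an explicit matrix square root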
|
c = torch.linalg.eigvals(sigma1 @ sigma2).sqrt().real.sum(dim=-1) |
|
|
|
return a + b - 2 * c |
|
|
|
|
|
class FrechetInceptionDistance(Metric): |
|
r"""Calculate FrΓ©chet inception distance (FID_) which is used to assess the quality of generated images. |
|
|
|
.. math:: |
|
FID = \|\mu - \mu_w\|^2 + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}}) |
|
|
|
where :math:`\mathcal{N}(\mu, \Sigma)` is the multivariate normal distribution estimated from Inception v3 |
|
(`fid ref1`_) features calculated on real life images and :math:`\mathcal{N}(\mu_w, \Sigma_w)` is the |
|
multivariate normal distribution estimated from Inception v3 features calculated on generated (fake) images. |
|
The metric was originally proposed in `fid ref1`_. |
|
|
|
    Using the default feature extraction (Inception v3 using the original weights from `fid ref2`_), the input is
    expected to be mini-batches of 3-channel RGB images of shape ``(3xHxW)``. If argument ``normalize``
    is ``True`` images are expected to be dtype ``float`` and have values in the ``[0,1]`` range, else if
    ``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
    range. All images will be resized to 299 x 299, which is the size of the original training data. The boolean
    flag ``real`` determines whether the images should update the statistics of the real distribution or the
    fake distribution.
|
|
|
    Using a custom feature extractor is also possible. One can pass a ``torch.nn.Module`` as the ``feature``
    argument. This custom feature extractor is expected to have an output shape of ``(1, num_features)``, and it
    replaces the default (Inception v3) network. If the network does not have a ``num_features`` attribute, a
    random tensor will be passed through it to infer the feature dimensionality. The size of this tensor can be
    controlled by the ``input_img_size`` argument and its dtype by the ``normalize`` argument (``True`` uses
    ``float32`` tensors and ``False`` uses ``uint8`` tensors). In this case, the ``update`` method expects the
    tensor passed to the ``imgs`` argument to have a shape and dtype compatible with the custom feature extractor.
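
    A minimal sketch of a custom feature extractor (illustrative only; any module whose forward pass
    returns an ``(N, d)`` feature matrix will work)::

        from torch import nn

        extractor = nn.Sequential(nn.Flatten(), nn.LazyLinear(64))  # (N, 3, 299, 299) -> (N, 64)
        fid = FrechetInceptionDistance(feature=extractor, normalize=True)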
|
|
|
    This metric is known to be numerically unstable, and for the best results we recommend computing it in
    ``torch.float64`` (the default is ``torch.float32``), which can be set using the ``.set_dtype`` method
    of the metric.
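
    For example::

        import torch
        from torchmetrics.image.fid import FrechetInceptionDistance

        fid = FrechetInceptionDistance(feature=64)
        fid.set_dtype(torch.float64)  # accumulate feature statistics in double precision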
|
|
|
.. hint:: |
|
Using this metric with the default feature extractor requires that ``torch-fidelity`` |
|
is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` |
|
|
|
As input to ``forward`` and ``update`` the metric accepts the following input |
|
|
|
    - ``imgs`` (:class:`~torch.Tensor`): tensor with images fed to the feature extractor
|
- ``real`` (:class:`~bool`): bool indicating if ``imgs`` belong to the real or the fake distribution |
|
|
|
As output of `forward` and `compute` the metric returns the following output |
|
|
|
- ``fid`` (:class:`~torch.Tensor`): float scalar tensor with mean FID value over samples |
|
|
|
Args: |
|
feature: |
|
Either an integer or ``nn.Module``: |
|
|
|
            - an integer indicates the Inception v3 feature layer to use. Must be one of the following:
              64, 192, 768, 2048
|
- an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns |
|
an ``(N,d)`` matrix where ``N`` is the batch size and ``d`` is the feature size. |
|
|
|
        reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
            change, the features can be cached to avoid recomputing them, which is costly. Set this to ``False`` if
            your dataset does not change.
|
normalize: |
|
Argument for controlling the input image dtype normalization: |
|
|
|
            - If the default feature extractor is used, controls whether the input imgs have values in the
              [0, 1] range:

              - True: input imgs are expected to have values in [0, 1]. They are scaled to [0, 255] and cast
                to uint8 tensors.
              - False: input imgs are expected to have dtype uint8 and values in [0, 255]. No casting is done.

            - If a custom feature extractor module is used, controls the dtype of the input img tensors:

              - True: input imgs are expected to be of dtype torch.float32.
              - False: input imgs are expected to be of dtype torch.uint8.
|
        input_img_size: Tuple of integers indicating the input image size for the custom feature extractor
            network, if one is provided.
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
        ModuleNotFoundError:
            If ``feature`` is set to an ``int`` (default settings) and ``torch-fidelity`` is not installed
        ValueError:
            If ``feature`` is set to an ``int`` not in ``[64, 192, 768, 2048]``
        TypeError:
            If ``feature`` is not an ``int`` or ``torch.nn.Module``
        ValueError:
            If ``reset_real_features`` is not a ``bool``
|
|
|
Example: |
|
        >>> import torch
|
>>> from torchmetrics.image.fid import FrechetInceptionDistance |
|
>>> fid = FrechetInceptionDistance(feature=64) |
|
>>> # generate two slightly overlapping image intensity distributions |
|
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) |
|
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) |
|
>>> fid.update(imgs_dist1, real=True) |
|
>>> fid.update(imgs_dist2, real=False) |
|
>>> fid.compute() |
|
tensor(12.6388) |
|
|
|
""" |
|
|
|
higher_is_better: bool = False |
|
is_differentiable: bool = False |
|
full_state_update: bool = False |
|
plot_lower_bound: float = 0.0 |
|
|
|
real_features_sum: Tensor |
|
real_features_cov_sum: Tensor |
|
real_features_num_samples: Tensor |
|
|
|
fake_features_sum: Tensor |
|
fake_features_cov_sum: Tensor |
|
fake_features_num_samples: Tensor |
|
|
|
inception: Module |
|
feature_network: str = "inception" |
|
|
|
def __init__( |
|
self, |
|
feature: Union[int, Module] = 2048, |
|
reset_real_features: bool = True, |
|
normalize: bool = False, |
|
input_img_size: tuple[int, int, int] = (3, 299, 299), |
|
feature_extractor_weights_path: Optional[str] = None, |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__(**kwargs) |
|
|
|
if not isinstance(normalize, bool): |
|
raise ValueError("Argument `normalize` expected to be a bool") |
|
self.normalize = normalize |
|
self.used_custom_model = False |
|
|
|
if isinstance(feature, int): |
|
num_features = feature |
|
if not _TORCH_FIDELITY_AVAILABLE: |
|
raise ModuleNotFoundError( |
|
"FrechetInceptionDistance metric requires that `Torch-fidelity` is installed." |
|
" Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`." |
|
) |
|
valid_int_input = (64, 192, 768, 2048) |
|
if feature not in valid_int_input: |
|
raise ValueError( |
|
f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}." |
|
) |
|
|
|
self.inception = NoTrainInceptionV3( |
|
name="inception-v3-compat", |
|
features_list=[str(feature)], |
|
feature_extractor_weights_path=feature_extractor_weights_path, |
|
) |
|
|
|
elif isinstance(feature, Module): |
|
self.inception = feature |
|
self.used_custom_model = True |
|
if hasattr(self.inception, "num_features"): |
|
if isinstance(self.inception.num_features, int): |
|
num_features = self.inception.num_features |
|
elif isinstance(self.inception.num_features, Tensor): |
|
num_features = int(self.inception.num_features.item()) |
|
else: |
|
raise TypeError("Expected `self.inception.num_features` to be of type int or Tensor.") |
|
else: |
|
if self.normalize: |
|
dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32) |
|
else: |
|
dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8) |
|
num_features = self.inception(dummy_image).shape[-1] |
|
else: |
|
raise TypeError("Got unknown input to argument `feature`") |
|
|
|
if not isinstance(reset_real_features, bool): |
|
raise ValueError("Argument `reset_real_features` expected to be a bool") |
|
self.reset_real_features = reset_real_features |
|
|
|
mx_num_feats = (num_features, num_features) |
|
self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum") |
|
self.add_state("real_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum") |
|
self.add_state("real_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum") |
|
|
|
self.add_state("fake_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum") |
|
self.add_state("fake_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum") |
|
self.add_state("fake_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum") |
|
|
|
def update(self, imgs: Tensor, real: bool) -> None: |
|
"""Update the state with extracted features. |
|
|
|
Args: |
|
            imgs: Input image tensors to evaluate. If a custom feature extractor is used, please
                make sure the dtype and size are correct for the model.
            real: Whether the given images are real or fake.
|
|
|
""" |
|
imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs |
|
features = self.inception(imgs) |
|
self.orig_dtype = features.dtype |
|
features = features.double() |
|
|
|
if features.dim() == 1: |
|
features = features.unsqueeze(0) |
|
if real: |
|
self.real_features_sum += features.sum(dim=0) |
|
self.real_features_cov_sum += features.t().mm(features) |
|
self.real_features_num_samples += imgs.shape[0] |
|
else: |
|
self.fake_features_sum += features.sum(dim=0) |
|
self.fake_features_cov_sum += features.t().mm(features) |
|
self.fake_features_num_samples += imgs.shape[0] |
|
|
|
def compute(self) -> Tensor: |
|
"""Calculate FID score based on accumulated extracted features from the two distributions.""" |
|
if self.real_features_num_samples < 2 or self.fake_features_num_samples < 2: |
|
            raise RuntimeError(
                "More than one sample is required for both the real and fake distributions to compute FID"
            )
|
mean_real = (self.real_features_sum / self.real_features_num_samples).unsqueeze(0) |
|
mean_fake = (self.fake_features_sum / self.fake_features_num_samples).unsqueeze(0) |
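
        # unbiased sample covariance from the running sums: cov = (sum(x x^T) - n * mu mu^T) / (n - 1)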
|
|
|
cov_real_num = self.real_features_cov_sum - self.real_features_num_samples * mean_real.t().mm(mean_real) |
|
cov_real = cov_real_num / (self.real_features_num_samples - 1) |
|
cov_fake_num = self.fake_features_cov_sum - self.fake_features_num_samples * mean_fake.t().mm(mean_fake) |
|
cov_fake = cov_fake_num / (self.fake_features_num_samples - 1) |
|
return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(self.orig_dtype) |
|
|
|
def reset(self) -> None: |
|
"""Reset metric states.""" |
|
if not self.reset_real_features: |
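            # stash the real-distribution statistics so they survive the state reset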
|
real_features_sum = deepcopy(self.real_features_sum) |
|
real_features_cov_sum = deepcopy(self.real_features_cov_sum) |
|
real_features_num_samples = deepcopy(self.real_features_num_samples) |
|
super().reset() |
|
self.real_features_sum = real_features_sum |
|
self.real_features_cov_sum = real_features_cov_sum |
|
self.real_features_num_samples = real_features_num_samples |
|
else: |
|
super().reset() |
|
|
|
def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric": |
|
"""Transfer all metric state to specific dtype. Special version of standard `type` method. |
|
|
|
Arguments: |
|
dst_type: the desired type as ``torch.dtype`` or string |
|
|
|
""" |
|
out = super().set_dtype(dst_type) |
|
if isinstance(out.inception, NoTrainInceptionV3): |
|
out.inception._dtype = dst_type |
|
return out |
|
|
|
def plot( |
|
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None |
|
) -> _PLOT_OUT_TYPE: |
|
"""Plot a single or multiple values from the metric. |
|
|
|
Args: |
|
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. |
|
If no value is provided, will automatically call `metric.compute` and plot that result. |
|
            ax: A matplotlib axis object. If provided, will add the plot to that axis.
|
|
|
Returns: |
|
Figure and Axes object |
|
|
|
Raises: |
|
ModuleNotFoundError: |
|
If `matplotlib` is not installed |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting a single value |
|
>>> import torch |
|
>>> from torchmetrics.image.fid import FrechetInceptionDistance |
|
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) |
|
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) |
|
>>> metric = FrechetInceptionDistance(feature=64) |
|
>>> metric.update(imgs_dist1, real=True) |
|
>>> metric.update(imgs_dist2, real=False) |
|
>>> fig_, ax_ = metric.plot() |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting multiple values |
|
>>> import torch |
|
>>> from torchmetrics.image.fid import FrechetInceptionDistance |
|
>>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) |
|
>>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) |
|
>>> metric = FrechetInceptionDistance(feature=64) |
|
            >>> values = []
|
>>> for _ in range(3): |
|
... metric.update(imgs_dist1(), real=True) |
|
... metric.update(imgs_dist2(), real=False) |
|
... values.append(metric.compute()) |
|
... metric.reset() |
|
>>> fig_, ax_ = metric.plot(values) |
|
|
|
""" |
|
return self._plot(val, ax) |
|
|