# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Optional, Union

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import ModuleList

from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.abstract import WrapperMetric

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["BootStrapper.plot"]


def _bootstrap_sampler(
    size: int,
    sampling_strategy: str = "poisson",
) -> torch.Tensor:
    """Draw indices for resampling ``size`` elements with replacement.

    The returned indices are meant to be used for selecting along the first dimension of a tensor.

    Args:
        size: number of samples
        sampling_strategy: the strategy to use for sampling, either ``'poisson'`` or ``'multinomial'``

    Returns:
        1D tensor of indices in ``[0, size)``. For ``'multinomial'`` it always holds ``size`` entries; for
        ``'poisson'`` the number of entries is random (and may be zero), since every index is repeated an
        independently drawn number of times.

    Raises:
        ValueError:
            If ``sampling_strategy`` is neither ``'poisson'`` nor ``'multinomial'``.

    """
    if sampling_strategy == "poisson":
        p = torch.distributions.Poisson(1)
        # One Poisson(1) draw per index decides how many times that index is repeated.
        n = p.sample((size,))
        return torch.arange(size).repeat_interleave(n.long(), dim=0)
    if sampling_strategy == "multinomial":
        # Uniform weights: every index is equally likely on each of the ``size`` draws.
        return torch.multinomial(torch.ones(size), num_samples=size, replacement=True)
    raise ValueError("Unknown sampling strategy")


class BootStrapper(WrapperMetric):
    r"""Using `Turn a Metric into a Bootstrapped`_.

    That can automate the process of getting confidence intervals for metric values.

    This wrapper class basically keeps multiple copies of the same base metric in memory and whenever ``update`` or
    ``forward`` is called, all input tensors are resampled (with replacement) along the first dimension.

    Args:
        base_metric: base metric class to wrap
        num_bootstraps: number of copies to make of the base metric for bootstrapping
        mean: if ``True`` return the mean of the bootstraps
        std: if ``True`` return the standard deviation of the bootstraps
        quantile: if given, returns the quantile of the bootstraps, computed across the bootstrap copies.
            Can only be used with pytorch version 1.6 or higher
        raw: if ``True``, return all bootstrapped values
        sampling_strategy:
            Determines how to produce bootstrapped samplings. Either ``'poisson'`` or ``'multinomial'``.
            If ``'poisson'`` is chosen, the number of times each sample will be included in the bootstrap
            will be given by :math:`n\sim Poisson(\lambda=1)`, which approximates the true bootstrap distribution
            when the number of samples is large. If ``'multinomial'`` is chosen, we will apply true bootstrapping
            at the batch level to approximate bootstrapping over the whole dataset.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``base_metric`` is not an instance of ``torchmetrics.Metric``.
        ValueError:
            If ``sampling_strategy`` is not one of ``'poisson'`` or ``'multinomial'``.

    Example::
        >>> from pprint import pprint
        >>> from torch import randint
        >>> from torchmetrics.wrappers import BootStrapper
        >>> from torchmetrics.classification import MulticlassAccuracy
        >>> base_metric = MulticlassAccuracy(num_classes=5, average='micro')
        >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
        >>> bootstrap.update(randint(5, (20,)), randint(5, (20,)))
        >>> output = bootstrap.compute()
        >>> pprint(output)
        {'mean': tensor(0.2089), 'std': tensor(0.0772)}

    """

    full_state_update: Optional[bool] = True

    def __init__(
        self,
        base_metric: Metric,
        num_bootstraps: int = 10,
        mean: bool = True,
        std: bool = True,
        quantile: Optional[Union[float, Tensor]] = None,
        raw: bool = False,
        sampling_strategy: str = "poisson",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not isinstance(base_metric, Metric):
            raise ValueError(
                f"Expected base metric to be an instance of torchmetrics.Metric but received {base_metric}"
            )

        # Every bootstrap gets its own independent deep copy of the base metric.
        self.metrics = ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)])
        self.num_bootstraps = num_bootstraps

        self.mean = mean
        self.std = std
        self.quantile = quantile
        self.raw = raw

        allowed_sampling = ("poisson", "multinomial")
        if sampling_strategy not in allowed_sampling:
            raise ValueError(
                f"Expected argument ``sampling_strategy`` to be one of {allowed_sampling}"
                f" but received {sampling_strategy}"
            )
        self.sampling_strategy = sampling_strategy

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Update the state of the base metric.

        Any tensor passed in will be bootstrapped along dimension 0. The sampling size is taken from the first
        positional argument if there is one, otherwise from the first keyword argument.

        Raises:
            ValueError:
                If neither positional nor keyword arguments were provided.

        """
        # Replace every tensor in the inputs by its length along dimension 0.
        args_sizes = apply_to_collection(args, torch.Tensor, len)
        kwargs_sizes = apply_to_collection(kwargs, torch.Tensor, len)
        if len(args_sizes) > 0:
            size = args_sizes[0]
        elif len(kwargs_sizes) > 0:
            size = next(iter(kwargs_sizes.values()))
        else:
            raise ValueError("None of the input contained tensors, so could not determine the sampling size")

        for idx in range(self.num_bootstraps):
            sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy).to(self.device)
            # Poisson sampling can select nothing at all; skip the update for this copy in that case.
            if sample_idx.numel() == 0:
                continue
            new_args = apply_to_collection(args, torch.Tensor, torch.index_select, dim=0, index=sample_idx)
            new_kwargs = apply_to_collection(kwargs, torch.Tensor, torch.index_select, dim=0, index=sample_idx)
            self.metrics[idx].update(*new_args, **new_kwargs)  # type: ignore[operator]  # needed for mypy

    def compute(self) -> dict[str, Tensor]:
        """Compute the bootstrapped metric values.

        Always returns a dict of tensors, which can contain the following keys: ``mean``, ``std``, ``quantile`` and
        ``raw`` depending on how the class was initialized. ``mean``, ``std`` and ``quantile`` are all reduced over
        the bootstrap copies (dimension 0 of the stacked results).

        """
        # Stack the per-copy results so that dimension 0 indexes the bootstrap copies.
        computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
        output_dict: dict[str, Tensor] = {}
        if self.mean:
            output_dict["mean"] = computed_vals.mean(dim=0)
        if self.std:
            output_dict["std"] = computed_vals.std(dim=0)
        if self.quantile is not None:
            q = self.quantile
            if isinstance(q, Tensor):
                # ``quantile`` is a plain attribute, so it does not follow the module between devices;
                # ``torch.quantile`` needs ``q`` to match the input's device and dtype.
                q = q.to(computed_vals)
            # Reduce over the bootstrap dimension only (like ``mean``/``std``) instead of flattening, so
            # non-scalar metric outputs are not mixed together.
            output_dict["quantile"] = torch.quantile(computed_vals, q, dim=0)
        if self.raw:
            output_dict["raw"] = computed_vals
        return output_dict

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Use the original forward method of the base metric class."""
        # ``super(WrapperMetric, self)`` starts the lookup after ``WrapperMetric`` in the MRO, so the
        # ``forward`` of ``WrapperMetric`` itself is bypassed.
        return super(WrapperMetric, self).forward(*args, **kwargs)

    def reset(self) -> None:
        """Reset the state of the base metric."""
        for m in self.metrics:
            m.reset()
        super().reset()

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.wrappers import BootStrapper
            >>> from torchmetrics.regression import MeanSquaredError
            >>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
            >>> metric.update(torch.randn(100,), torch.randn(100,))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.wrappers import BootStrapper
            >>> from torchmetrics.regression import MeanSquaredError
            >>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
            >>> values = [ ]
            >>> for _ in range(3):
            ...     values.append(metric(torch.randn(100,), torch.randn(100,)))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)