# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.running import Running

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["SumMetric.plot", "MeanMetric.plot", "MaxMetric.plot", "MinMetric.plot"]


class BaseAggregator(Metric):
    """Base class for aggregation metrics.

    Args:
        fn: string specifying the reduction function
        default_value: default tensor value to use for the metric state
        nan_strategy: options:
            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and continue
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        state_name: name of the metric state
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    """

    is_differentiable = None
    higher_is_better = None
    full_state_update: bool = False

    def __init__(
        self,
        fn: Union[Callable, str],
        default_value: Union[Tensor, list],
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "error",
        state_name: str = "value",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        allowed_nan_strategy = ("error", "warn", "ignore", "disable")
        if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
            raise ValueError(
                f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}."
            )

        self.nan_strategy = nan_strategy
        self.add_state(state_name, default=default_value, dist_reduce_fx=fn)
        self.state_name = state_name

    def _cast_and_nan_check_input(
        self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None
    ) -> tuple[Tensor, Tensor]:
        """Convert input ``x`` (and optional ``weight``) to tensors and apply the configured ``nan_strategy``.

        A position counts as invalid when either ``x`` or ``weight`` is `nan` there. Invalid positions either raise,
        are dropped (``'ignore'``/``'warn'``), or are imputed (value replaced by the float strategy, weight set to 1).
        With ``'disable'`` no check is made and a caller-supplied ``weight`` is passed through unchanged.

        The caller's tensors are never modified in place; imputation works on copies.

        Args:
            x: value(s) to aggregate
            weight: optional weight(s) for ``x``; all ones when ``None``

        Returns:
            Tuple ``(x, weight)``, both cast to ``self.dtype``.

        Raises:
            RuntimeError: If `nan` values are encountered and ``nan_strategy='error'``.

        """
        if not isinstance(x, Tensor):
            x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
        if weight is None:
            # unit weights: cannot contain nan, so they never mark a position as invalid
            weight = torch.ones_like(x)
        elif not isinstance(weight, Tensor):
            weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)

        if self.nan_strategy == "disable":
            # keep any user-supplied weights (previously they were silently replaced by ones)
            return x.to(self.dtype), weight.to(self.dtype)

        nans = torch.isnan(x)
        nans_weight = torch.isnan(weight)
        if nans.any() or nans_weight.any():
            invalid = nans | nans_weight
            if self.nan_strategy == "error":
                raise RuntimeError("Encountered `nan` values in tensor")
            if self.nan_strategy in ("ignore", "warn"):
                if self.nan_strategy == "warn":
                    rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning)
                x = x[~invalid]
                weight = weight[~invalid]
            else:
                if not isinstance(self.nan_strategy, float):
                    raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
                # Clone before writing: avoids mutating the caller's tensors and materialises broadcast
                # (expanded) weight views, which do not support in-place indexed assignment.
                x = x.clone()
                weight = weight.clone()
                x[invalid] = self.nan_strategy
                weight[invalid] = 1
        return x.to(self.dtype), weight.to(self.dtype)

    def update(self, value: Union[float, Tensor]) -> None:
        """Overwrite in child class."""

    def compute(self) -> Tensor:
        """Compute the aggregated value."""
        return getattr(self, self.state_name)


class MaxMetric(BaseAggregator):
    """Aggregate a stream of values into their maximum value.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    As output of `forward` and `compute` the metric returns the following output

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated maximum value over all inputs received

    Args:
        nan_strategy: options:
            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and continue
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import MaxMetric
        >>> metric = MaxMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor(3.)

    """

    full_state_update: bool = True
    max_value: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            "max",
            -torch.tensor(float("inf"), dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="max_value",
            **kwargs,
        )

    def update(self, value: Union[float, Tensor]) -> None:
        """Update state with data.

        Args:
            value: Either a float or tensor containing data. Additional tensor
                dimensions will be flattened

        """
        value, _ = self._cast_and_nan_check_input(value)
        if value.numel():  # make sure tensor not empty
            self.max_value = torch.max(self.max_value, torch.max(value))

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torchmetrics.aggregation import MaxMetric
            >>> metric = MaxMetric()
            >>> metric.update([1, 2, 3])
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torchmetrics.aggregation import MaxMetric
            >>> metric = MaxMetric()
            >>> values = [ ]
            >>> for i in range(10):
            ...     values.append(metric(i))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class MinMetric(BaseAggregator):
    """Aggregate a stream of values into their minimum value.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    As output of `forward` and `compute` the metric returns the following output

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated minimum value over all inputs received

    Args:
        nan_strategy: options:
            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and continue
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import MinMetric
        >>> metric = MinMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor(1.)

    """

    full_state_update: bool = True
    min_value: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            "min",
            torch.tensor(float("inf"), dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="min_value",
            **kwargs,
        )

    def update(self, value: Union[float, Tensor]) -> None:
        """Update state with data.

        Args:
            value: Either a float or tensor containing data. Additional tensor
                dimensions will be flattened

        """
        value, _ = self._cast_and_nan_check_input(value)
        if value.numel():  # make sure tensor not empty
            self.min_value = torch.min(self.min_value, torch.min(value))

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torchmetrics.aggregation import MinMetric
            >>> metric = MinMetric()
            >>> metric.update([1, 2, 3])
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torchmetrics.aggregation import MinMetric
            >>> metric = MinMetric()
            >>> values = [ ]
            >>> for i in range(10):
            ...     values.append(metric(i))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class SumMetric(BaseAggregator):
    """Aggregate a stream of values into their sum.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    As output of `forward` and `compute` the metric returns the following output

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated sum over all inputs received

    Args:
        nan_strategy: options:
            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and continue
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import SumMetric
        >>> metric = SumMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor(6.)

    """

    sum_value: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum_value",
            **kwargs,
        )

    def update(self, value: Union[float, Tensor]) -> None:
        """Update state with data.

        Args:
            value: Either a float or tensor containing data. Additional tensor
                dimensions will be flattened

        """
        value, _ = self._cast_and_nan_check_input(value)
        if value.numel():
            self.sum_value += value.sum()

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torchmetrics.aggregation import SumMetric
            >>> metric = SumMetric()
            >>> metric.update([1, 2, 3])
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torch import rand, randint
            >>> from torchmetrics.aggregation import SumMetric
            >>> metric = SumMetric()
            >>> values = [ ]
            >>> for i in range(10):
            ...     values.append(metric([i, i+1]))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class CatMetric(BaseAggregator):
    """Concatenate a stream of values.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    As output of `forward` and `compute` the metric returns the following output

    - ``agg`` (:class:`~torch.Tensor`): float tensor with the concatenated values of all inputs received

    Args:
        nan_strategy: options:
            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and continue
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import CatMetric
        >>> metric = CatMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor([1., 2., 3.])

    """

    value: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        super().__init__("cat", [], nan_strategy, **kwargs)

    def update(self, value: Union[float, Tensor]) -> None:
        """Update state with data.

        Args:
            value: Either a float or tensor containing data. Additional tensor
                dimensions will be flattened

        """
        value, _ = self._cast_and_nan_check_input(value)
        if value.numel():
            self.value.append(value)

    def compute(self) -> Tensor:
        """Compute the aggregated value.

        If the state is a non-empty list of tensors it is concatenated with ``dim_zero_cat``; otherwise the state is
        returned as-is (e.g. the empty list when nothing has been received yet).

        """
        if isinstance(self.value, list) and self.value:
            return dim_zero_cat(self.value)
        return self.value


class MeanMetric(BaseAggregator):
    """Aggregate a stream of values into their mean value.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.
    - ``weight`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``. Needs to be broadcastable with the shape of the ``value`` tensor.

    As output of `forward` and `compute` the metric returns the following output

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated (weighted) mean over all inputs received

    Args:
        nan_strategy: options:
            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and continue
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torchmetrics.aggregation import MeanMetric
        >>> metric = MeanMetric()
        >>> metric.update(1)
        >>> metric.update(torch.tensor([2, 3]))
        >>> metric.compute()
        tensor(2.)

    """

    mean_value: Tensor
    weight: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="mean_value",
            **kwargs,
        )
        self.add_state("weight", default=torch.tensor(0.0, dtype=torch.get_default_dtype()), dist_reduce_fx="sum")

    def update(self, value: Union[float, Tensor], weight: Union[float, Tensor, None] = None) -> None:
        """Update state with data.

        Args:
            value: Either a float or tensor containing data. Additional tensor
                dimensions will be flattened
            weight: Either a float or tensor containing weights for calculating
                the average. Shape of weight should be able to broadcast with
                the shape of `value`. Defaults to ``None``, which uses uniform
                weights, i.e. a simple arithmetic mean.

        """
        # broadcast weight to value shape
        if not isinstance(value, Tensor):
            value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
        if weight is None:
            weight = torch.ones_like(value)
        elif not isinstance(weight, Tensor):
            weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
        # NOTE: broadcast_to returns a view; the nan check works on copies before any in-place imputation
        weight = torch.broadcast_to(weight, value.shape)
        value, weight = self._cast_and_nan_check_input(value, weight)

        if value.numel() == 0:
            return
        self.mean_value += (value * weight).sum()
        self.weight += weight.sum()

    def compute(self) -> Tensor:
        """Compute the aggregated value.

        Returns the weighted sum divided by the total weight; with no accumulated weight this is ``0 / 0``.

        """
        return self.mean_value / self.weight

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torchmetrics.aggregation import MeanMetric
            >>> metric = MeanMetric()
            >>> metric.update([1, 2, 3])
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torchmetrics.aggregation import MeanMetric
            >>> metric = MeanMetric()
            >>> values = [ ]
            >>> for i in range(10):
            ...     values.append(metric([i, i+1]))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class RunningMean(Running):
    """Aggregate a stream of values into their mean over a running window.

    Using this metric compared to `MeanMetric` allows for calculating metrics over a running window of values, instead
    of the whole history of values. This is beneficial when you want to get a better estimate of the metric during
    training and don't want to wait for the whole training to finish to get epoch level estimates.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    As output of `forward` and `compute` the metric returns the following output

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated mean over the inputs in the running window

    Args:
        window: The size of the running window.
        nan_strategy: options:
            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and continue
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import RunningMean
        >>> metric = RunningMean(window=3)
        >>> for i in range(6):
        ...     current_val = metric(tensor([i]))
        ...     running_val = metric.compute()
        ...     total_val = tensor(sum(list(range(i+1)))) / (i+1)  # total mean over all samples
        ...     print(f"{current_val=}, {running_val=}, {total_val=}")
        current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0.)
        current_val=tensor(1.), running_val=tensor(0.5000), total_val=tensor(0.5000)
        current_val=tensor(2.), running_val=tensor(1.), total_val=tensor(1.)
        current_val=tensor(3.), running_val=tensor(2.), total_val=tensor(1.5000)
        current_val=tensor(4.), running_val=tensor(3.), total_val=tensor(2.)
        current_val=tensor(5.), running_val=tensor(4.), total_val=tensor(2.5000)

    """

    def __init__(
        self,
        window: int = 5,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        super().__init__(base_metric=MeanMetric(nan_strategy=nan_strategy, **kwargs), window=window)


class RunningSum(Running):
    """Aggregate a stream of values into their sum over a running window.

    Using this metric compared to `SumMetric` allows for calculating metrics over a running window of values, instead
    of the whole history of values. This is beneficial when you want to get a better estimate of the metric during
    training and don't want to wait for the whole training to finish to get epoch level estimates.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    As output of `forward` and `compute` the metric returns the following output

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated sum over the inputs in the running window

    Args:
        window: The size of the running window.
        nan_strategy: options:
            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and continue
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import RunningSum
        >>> metric = RunningSum(window=3)
        >>> for i in range(6):
        ...     current_val = metric(tensor([i]))
        ...     running_val = metric.compute()
        ...     total_val = tensor(sum(list(range(i+1))))  # total sum over all samples
        ...     print(f"{current_val=}, {running_val=}, {total_val=}")
        current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
        current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
        current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
        current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
        current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
        current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)

    """

    def __init__(
        self,
        window: int = 5,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        super().__init__(base_metric=SumMetric(nan_strategy=nan_strategy, **kwargs), window=window)