|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections.abc import Sequence |
|
from typing import Any, Callable, Optional, Union |
|
|
|
import torch |
|
from torch import Tensor |
|
from typing_extensions import Literal |
|
|
|
from torchmetrics.metric import Metric |
|
from torchmetrics.utilities import rank_zero_warn |
|
from torchmetrics.utilities.data import dim_zero_cat |
|
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE |
|
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE |
|
from torchmetrics.wrappers.running import Running |
|
|
|
if not _MATPLOTLIB_AVAILABLE: |
|
__doctest_skip__ = ["SumMetric.plot", "MeanMetric.plot", "MaxMetric.plot", "MinMetric.plot"] |
|
|
|
|
|
class BaseAggregator(Metric): |
|
"""Base class for aggregation metrics. |
|
|
|
Args: |
|
        fn: callable or string specifying the reduction function
|
default_value: default tensor value to use for the metric state |
|
nan_strategy: options: |
|
            - ``'error'``: a ``RuntimeError`` is raised if any `nan` values are encountered
            - ``'warn'``: a warning is issued for any `nan` values encountered, which are then removed
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: any `nan` values are imputed with this value
|
|
|
state_name: name of the metric state |
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
ValueError: |
|
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float |
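
    Example:
        A minimal illustrative sketch (using ``SumMetric``, one of the concrete subclasses
        defined later in this module) of how a float ``nan_strategy`` imputes `nan` values
        before they are aggregated:

        >>> from torch import tensor
        >>> from torchmetrics.aggregation import SumMetric
        >>> metric = SumMetric(nan_strategy=0.0)
        >>> metric.update(tensor([1.0, float("nan"), 3.0]))
        >>> metric.compute()
        tensor(4.)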
|
|
|
""" |
|
|
|
is_differentiable = None |
|
higher_is_better = None |
|
full_state_update: bool = False |
|
|
|
def __init__( |
|
self, |
|
fn: Union[Callable, str], |
|
default_value: Union[Tensor, list], |
|
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "error", |
|
state_name: str = "value", |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__(**kwargs) |
|
allowed_nan_strategy = ("error", "warn", "ignore", "disable") |
|
if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float): |
|
raise ValueError( |
|
f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}." |
|
) |
|
|
|
self.nan_strategy = nan_strategy |
|
self.add_state(state_name, default=default_value, dist_reduce_fx=fn) |
|
self.state_name = state_name |
|
|
|
def _cast_and_nan_check_input( |
|
self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None |
|
    ) -> tuple[Tensor, Tensor]:
        """Convert input ``x`` to a tensor and check for NaNs."""
|
if not isinstance(x, Tensor): |
|
x = torch.as_tensor(x, dtype=self.dtype, device=self.device) |
|
if weight is not None and not isinstance(weight, Tensor): |
|
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device) |
|
|
|
if self.nan_strategy != "disable": |
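            # Locate `nan` entries in the values and, when given, in the weights, so a single
            # mask can be used below to drop or impute the offending positions.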
|
nans = torch.isnan(x) |
|
if weight is not None: |
|
nans_weight = torch.isnan(weight) |
|
else: |
|
nans_weight = torch.zeros_like(nans).bool() |
|
weight = torch.ones_like(x) |
|
if nans.any() or nans_weight.any(): |
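                # Apply the configured strategy: raise, warn and drop, silently drop, or impute.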
|
if self.nan_strategy == "error": |
|
raise RuntimeError("Encountered `nan` values in tensor") |
|
if self.nan_strategy in ("ignore", "warn"): |
|
if self.nan_strategy == "warn": |
|
rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning) |
|
x = x[~(nans | nans_weight)] |
|
weight = weight[~(nans | nans_weight)] |
|
else: |
|
if not isinstance(self.nan_strategy, float): |
|
                        raise ValueError(f"`nan_strategy` should be a float but got {self.nan_strategy}")
|
x[nans | nans_weight] = self.nan_strategy |
|
weight[nans | nans_weight] = 1 |
|
        else:
            if weight is None:
                weight = torch.ones_like(x)
|
return x.to(self.dtype), weight.to(self.dtype) |
|
|
|
def update(self, value: Union[float, Tensor]) -> None: |
|
"""Overwrite in child class.""" |
|
|
|
def compute(self) -> Tensor: |
|
"""Compute the aggregated value.""" |
|
return getattr(self, self.state_name) |
|
|
|
|
|
class MaxMetric(BaseAggregator):
    """Aggregate a stream of values into their maximum value.
|
|
|
As input to ``forward`` and ``update`` the metric accepts the following input |
|
|
|
    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
|
arbitrary shape ``(...,)``. |
|
|
|
As output of `forward` and `compute` the metric returns the following output |
|
|
|
- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated maximum value over all inputs received |
|
|
|
Args: |
|
nan_strategy: options: |
|
            - ``'error'``: a ``RuntimeError`` is raised if any `nan` values are encountered
            - ``'warn'``: a warning is issued for any `nan` values encountered, which are then removed
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: any `nan` values are imputed with this value
|
|
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
ValueError: |
|
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float |
|
|
|
Example: |
|
>>> from torch import tensor |
|
>>> from torchmetrics.aggregation import MaxMetric |
|
>>> metric = MaxMetric() |
|
>>> metric.update(1) |
|
>>> metric.update(tensor([2, 3])) |
|
>>> metric.compute() |
|
tensor(3.) |
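
        An additional illustrative sketch of the ``'ignore'`` strategy, which silently drops
        `nan` values before aggregation:

        >>> metric = MaxMetric(nan_strategy="ignore")
        >>> metric.update(tensor([4.0, float("nan")]))
        >>> metric.compute()
        tensor(4.)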
|
|
|
""" |
|
|
|
full_state_update: bool = True |
|
max_value: Tensor |
|
|
|
def __init__( |
|
self, |
|
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn", |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__( |
|
"max", |
|
-torch.tensor(float("inf"), dtype=torch.get_default_dtype()), |
|
nan_strategy, |
|
state_name="max_value", |
|
**kwargs, |
|
) |
|
|
|
def update(self, value: Union[float, Tensor]) -> None: |
|
"""Update state with data. |
|
|
|
Args: |
|
value: Either a float or tensor containing data. Additional tensor |
|
dimensions will be flattened |
|
|
|
""" |
|
value, _ = self._cast_and_nan_check_input(value) |
|
if value.numel(): |
|
self.max_value = torch.max(self.max_value, torch.max(value)) |
|
|
|
def plot( |
|
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None |
|
) -> _PLOT_OUT_TYPE: |
|
"""Plot a single or multiple values from the metric. |
|
|
|
Args: |
|
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. |
|
If no value is provided, will automatically call `metric.compute` and plot that result. |
|
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
|
|
|
Returns: |
|
Figure and Axes object |
|
|
|
Raises: |
|
ModuleNotFoundError: |
|
If `matplotlib` is not installed |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting a single value |
|
>>> from torchmetrics.aggregation import MaxMetric |
|
>>> metric = MaxMetric() |
|
>>> metric.update([1, 2, 3]) |
|
>>> fig_, ax_ = metric.plot() |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting multiple values |
|
>>> from torchmetrics.aggregation import MaxMetric |
|
>>> metric = MaxMetric() |
|
            >>> values = []
|
>>> for i in range(10): |
|
... values.append(metric(i)) |
|
>>> fig_, ax_ = metric.plot(values) |
|
|
|
""" |
|
return self._plot(val, ax) |
|
|
|
|
|
class MinMetric(BaseAggregator):
    """Aggregate a stream of values into their minimum value.
|
|
|
As input to ``forward`` and ``update`` the metric accepts the following input |
|
|
|
    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
|
arbitrary shape ``(...,)``. |
|
|
|
As output of `forward` and `compute` the metric returns the following output |
|
|
|
- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated minimum value over all inputs received |
|
|
|
Args: |
|
nan_strategy: options: |
|
            - ``'error'``: a ``RuntimeError`` is raised if any `nan` values are encountered
            - ``'warn'``: a warning is issued for any `nan` values encountered, which are then removed
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: any `nan` values are imputed with this value
|
|
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
ValueError: |
|
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float |
|
|
|
Example: |
|
>>> from torch import tensor |
|
>>> from torchmetrics.aggregation import MinMetric |
|
>>> metric = MinMetric() |
|
>>> metric.update(1) |
|
>>> metric.update(tensor([2, 3])) |
|
>>> metric.compute() |
|
tensor(1.) |
|
|
|
""" |
|
|
|
full_state_update: bool = True |
|
min_value: Tensor |
|
|
|
def __init__( |
|
self, |
|
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn", |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__( |
|
"min", |
|
torch.tensor(float("inf"), dtype=torch.get_default_dtype()), |
|
nan_strategy, |
|
state_name="min_value", |
|
**kwargs, |
|
) |
|
|
|
def update(self, value: Union[float, Tensor]) -> None: |
|
"""Update state with data. |
|
|
|
Args: |
|
value: Either a float or tensor containing data. Additional tensor |
|
dimensions will be flattened |
|
|
|
""" |
|
value, _ = self._cast_and_nan_check_input(value) |
|
if value.numel(): |
|
self.min_value = torch.min(self.min_value, torch.min(value)) |
|
|
|
def plot( |
|
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None |
|
) -> _PLOT_OUT_TYPE: |
|
"""Plot a single or multiple values from the metric. |
|
|
|
Args: |
|
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. |
|
If no value is provided, will automatically call `metric.compute` and plot that result. |
|
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
|
|
|
Returns: |
|
Figure and Axes object |
|
|
|
Raises: |
|
ModuleNotFoundError: |
|
If `matplotlib` is not installed |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting a single value |
|
>>> from torchmetrics.aggregation import MinMetric |
|
>>> metric = MinMetric() |
|
>>> metric.update([1, 2, 3]) |
|
>>> fig_, ax_ = metric.plot() |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting multiple values |
|
>>> from torchmetrics.aggregation import MinMetric |
|
>>> metric = MinMetric() |
|
            >>> values = []
|
>>> for i in range(10): |
|
... values.append(metric(i)) |
|
>>> fig_, ax_ = metric.plot(values) |
|
|
|
""" |
|
return self._plot(val, ax) |
|
|
|
|
|
class SumMetric(BaseAggregator):
    """Aggregate a stream of values into their sum.
|
|
|
As input to ``forward`` and ``update`` the metric accepts the following input |
|
|
|
    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
|
arbitrary shape ``(...,)``. |
|
|
|
As output of `forward` and `compute` the metric returns the following output |
|
|
|
- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated sum over all inputs received |
|
|
|
Args: |
|
nan_strategy: options: |
|
            - ``'error'``: a ``RuntimeError`` is raised if any `nan` values are encountered
            - ``'warn'``: a warning is issued for any `nan` values encountered, which are then removed
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: any `nan` values are imputed with this value
|
|
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
ValueError: |
|
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float |
|
|
|
Example: |
|
>>> from torch import tensor |
|
>>> from torchmetrics.aggregation import SumMetric |
|
>>> metric = SumMetric() |
|
>>> metric.update(1) |
|
>>> metric.update(tensor([2, 3])) |
|
>>> metric.compute() |
|
tensor(6.) |
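
        An additional illustrative sketch showing that extra tensor dimensions are simply
        summed over as well:

        >>> metric = SumMetric()
        >>> metric.update(tensor([[1.0, 2.0], [3.0, 4.0]]))
        >>> metric.compute()
        tensor(10.)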
|
|
|
""" |
|
|
|
sum_value: Tensor |
|
|
|
def __init__( |
|
self, |
|
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn", |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__( |
|
"sum", |
|
torch.tensor(0.0, dtype=torch.get_default_dtype()), |
|
nan_strategy, |
|
state_name="sum_value", |
|
**kwargs, |
|
) |
|
|
|
def update(self, value: Union[float, Tensor]) -> None: |
|
"""Update state with data. |
|
|
|
Args: |
|
value: Either a float or tensor containing data. Additional tensor |
|
dimensions will be flattened |
|
|
|
""" |
|
value, _ = self._cast_and_nan_check_input(value) |
|
if value.numel(): |
|
self.sum_value += value.sum() |
|
|
|
def plot( |
|
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None |
|
) -> _PLOT_OUT_TYPE: |
|
"""Plot a single or multiple values from the metric. |
|
|
|
Args: |
|
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. |
|
If no value is provided, will automatically call `metric.compute` and plot that result. |
|
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
|
|
|
Returns: |
|
Figure and Axes object |
|
|
|
Raises: |
|
ModuleNotFoundError: |
|
If `matplotlib` is not installed |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting a single value |
|
>>> from torchmetrics.aggregation import SumMetric |
|
>>> metric = SumMetric() |
|
>>> metric.update([1, 2, 3]) |
|
>>> fig_, ax_ = metric.plot() |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting multiple values |
|
|
>>> from torchmetrics.aggregation import SumMetric |
|
>>> metric = SumMetric() |
|
            >>> values = []
|
>>> for i in range(10): |
|
... values.append(metric([i, i+1])) |
|
>>> fig_, ax_ = metric.plot(values) |
|
|
|
""" |
|
return self._plot(val, ax) |
|
|
|
|
|
class CatMetric(BaseAggregator): |
|
"""Concatenate a stream of values. |
|
|
|
As input to ``forward`` and ``update`` the metric accepts the following input |
|
|
|
    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
|
arbitrary shape ``(...,)``. |
|
|
|
As output of `forward` and `compute` the metric returns the following output |
|
|
|
    - ``agg`` (:class:`~torch.Tensor`): float tensor with the concatenated values over all inputs received
|
|
|
Args: |
|
nan_strategy: options: |
|
            - ``'error'``: a ``RuntimeError`` is raised if any `nan` values are encountered
            - ``'warn'``: a warning is issued for any `nan` values encountered, which are then removed
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: any `nan` values are imputed with this value
|
|
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
ValueError: |
|
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float |
|
|
|
Example: |
|
>>> from torch import tensor |
|
>>> from torchmetrics.aggregation import CatMetric |
|
>>> metric = CatMetric() |
|
>>> metric.update(1) |
|
>>> metric.update(tensor([2, 3])) |
|
>>> metric.compute() |
|
tensor([1., 2., 3.]) |
|
|
|
""" |
|
|
|
value: Tensor |
|
|
|
def __init__( |
|
self, |
|
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn", |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__("cat", [], nan_strategy, **kwargs) |
|
|
|
def update(self, value: Union[float, Tensor]) -> None: |
|
"""Update state with data. |
|
|
|
Args: |
|
value: Either a float or tensor containing data. Additional tensor |
|
dimensions will be flattened |
|
|
|
""" |
|
value, _ = self._cast_and_nan_check_input(value) |
|
if value.numel(): |
|
self.value.append(value) |
|
|
|
def compute(self) -> Tensor: |
|
"""Compute the aggregated value.""" |
|
if isinstance(self.value, list) and self.value: |
|
return dim_zero_cat(self.value) |
|
return self.value |
|
|
|
|
|
class MeanMetric(BaseAggregator):
    """Aggregate a stream of values into their mean value.
|
|
|
As input to ``forward`` and ``update`` the metric accepts the following input |
|
|
|
    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
|
arbitrary shape ``(...,)``. |
|
    - ``weight`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
|
arbitrary shape ``(...,)``. Needs to be broadcastable with the shape of ``value`` tensor. |
|
|
|
As output of `forward` and `compute` the metric returns the following output |
|
|
|
- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated (weighted) mean over all inputs received |
|
|
|
Args: |
|
nan_strategy: options: |
|
            - ``'error'``: a ``RuntimeError`` is raised if any `nan` values are encountered
            - ``'warn'``: a warning is issued for any `nan` values encountered, which are then removed
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: any `nan` values are imputed with this value
|
|
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
ValueError: |
|
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float |
|
|
|
Example: |
|
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import MeanMetric
        >>> metric = MeanMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
|
>>> metric.compute() |
|
tensor(2.) |
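
        An additional illustrative sketch of a weighted mean, where each value contributes
        in proportion to its weight:

        >>> metric = MeanMetric()
        >>> metric.update(tensor([1.0, 2.0, 3.0]), weight=tensor([1.0, 1.0, 2.0]))
        >>> metric.compute()
        tensor(2.2500)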
|
|
|
""" |
|
|
|
mean_value: Tensor |
|
weight: Tensor |
|
|
|
def __init__( |
|
self, |
|
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn", |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__( |
|
"sum", |
|
torch.tensor(0.0, dtype=torch.get_default_dtype()), |
|
nan_strategy, |
|
state_name="mean_value", |
|
**kwargs, |
|
) |
|
self.add_state("weight", default=torch.tensor(0.0, dtype=torch.get_default_dtype()), dist_reduce_fx="sum") |
|
|
|
def update(self, value: Union[float, Tensor], weight: Union[float, Tensor, None] = None) -> None: |
|
"""Update state with data. |
|
|
|
Args: |
|
value: Either a float or tensor containing data. Additional tensor |
|
dimensions will be flattened |
|
            weight: Either a float or tensor containing weights for calculating
                the average. The shape of `weight` should be broadcastable with
                the shape of `value`. Defaults to ``None``, which corresponds to
                a simple (unweighted) average.
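
        An illustrative sketch of weight broadcasting, where a single scalar weight is
        applied to every element of ``value``:

        >>> from torch import tensor
        >>> from torchmetrics.aggregation import MeanMetric
        >>> metric = MeanMetric()
        >>> metric.update(tensor([2.0, 4.0]), weight=0.5)
        >>> metric.compute()
        tensor(3.)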
|
|
|
""" |
|
|
|
if not isinstance(value, Tensor): |
|
value = torch.as_tensor(value, dtype=self.dtype, device=self.device) |
|
if weight is None: |
|
weight = torch.ones_like(value) |
|
elif not isinstance(weight, Tensor): |
|
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device) |
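        # Broadcast the weights to the shape of ``value`` so that the `nan` masking and the
        # weighted sum below line up elementwise.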
|
weight = torch.broadcast_to(weight, value.shape) |
|
value, weight = self._cast_and_nan_check_input(value, weight) |
|
|
|
if value.numel() == 0: |
|
return |
|
self.mean_value += (value * weight).sum() |
|
self.weight += weight.sum() |
|
|
|
def compute(self) -> Tensor: |
|
"""Compute the aggregated value.""" |
|
return self.mean_value / self.weight |
|
|
|
def plot( |
|
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None |
|
) -> _PLOT_OUT_TYPE: |
|
"""Plot a single or multiple values from the metric. |
|
|
|
Args: |
|
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. |
|
If no value is provided, will automatically call `metric.compute` and plot that result. |
|
            ax: A matplotlib axis object. If provided, the plot will be added to that axis.
|
|
|
Returns: |
|
Figure and Axes object |
|
|
|
Raises: |
|
ModuleNotFoundError: |
|
If `matplotlib` is not installed |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting a single value |
|
>>> from torchmetrics.aggregation import MeanMetric |
|
>>> metric = MeanMetric() |
|
>>> metric.update([1, 2, 3]) |
|
>>> fig_, ax_ = metric.plot() |
|
|
|
.. plot:: |
|
:scale: 75 |
|
|
|
>>> # Example plotting multiple values |
|
>>> from torchmetrics.aggregation import MeanMetric |
|
>>> metric = MeanMetric() |
|
            >>> values = []
|
>>> for i in range(10): |
|
... values.append(metric([i, i+1])) |
|
>>> fig_, ax_ = metric.plot(values) |
|
|
|
""" |
|
return self._plot(val, ax) |
|
|
|
|
|
class RunningMean(Running):
    """Aggregate a stream of values into their mean over a running window.
|
|
|
Using this metric compared to `MeanMetric` allows for calculating metrics over a running window of values, instead |
|
of the whole history of values. This is beneficial when you want to get a better estimate of the metric during |
|
training and don't want to wait for the whole training to finish to get epoch level estimates. |
|
|
|
As input to ``forward`` and ``update`` the metric accepts the following input |
|
|
|
    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
|
arbitrary shape ``(...,)``. |
|
|
|
As output of `forward` and `compute` the metric returns the following output |
|
|
|
    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with the aggregated mean over the last ``window`` inputs received
|
|
|
Args: |
|
        window: The size of the running window.
        nan_strategy: options:
            - ``'error'``: a ``RuntimeError`` is raised if any `nan` values are encountered
            - ``'warn'``: a warning is issued for any `nan` values encountered, which are then removed
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: any `nan` values are imputed with this value
|
|
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
ValueError: |
|
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float |
|
|
|
Example: |
|
>>> from torch import tensor |
|
>>> from torchmetrics.aggregation import RunningMean |
|
>>> metric = RunningMean(window=3) |
|
>>> for i in range(6): |
|
... current_val = metric(tensor([i])) |
|
... running_val = metric.compute() |
|
... total_val = tensor(sum(list(range(i+1)))) / (i+1) # total mean over all samples |
|
... print(f"{current_val=}, {running_val=}, {total_val=}") |
|
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0.) |
|
current_val=tensor(1.), running_val=tensor(0.5000), total_val=tensor(0.5000) |
|
current_val=tensor(2.), running_val=tensor(1.), total_val=tensor(1.) |
|
current_val=tensor(3.), running_val=tensor(2.), total_val=tensor(1.5000) |
|
current_val=tensor(4.), running_val=tensor(3.), total_val=tensor(2.) |
|
current_val=tensor(5.), running_val=tensor(4.), total_val=tensor(2.5000) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
window: int = 5, |
|
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn", |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__(base_metric=MeanMetric(nan_strategy=nan_strategy, **kwargs), window=window) |
|
|
|
|
|
class RunningSum(Running):
    """Aggregate a stream of values into their sum over a running window.
|
|
|
Using this metric compared to `SumMetric` allows for calculating metrics over a running window of values, instead |
|
of the whole history of values. This is beneficial when you want to get a better estimate of the metric during |
|
training and don't want to wait for the whole training to finish to get epoch level estimates. |
|
|
|
As input to ``forward`` and ``update`` the metric accepts the following input |
|
|
|
    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
|
arbitrary shape ``(...,)``. |
|
|
|
As output of `forward` and `compute` the metric returns the following output |
|
|
|
    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with the aggregated sum over the last ``window`` inputs received
|
|
|
Args: |
|
window: The size of the running window. |
|
nan_strategy: options: |
|
            - ``'error'``: a ``RuntimeError`` is raised if any `nan` values are encountered
            - ``'warn'``: a warning is issued for any `nan` values encountered, which are then removed
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: any `nan` values are imputed with this value
|
|
|
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. |
|
|
|
Raises: |
|
ValueError: |
|
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float |
|
|
|
Example: |
|
>>> from torch import tensor |
|
>>> from torchmetrics.aggregation import RunningSum |
|
>>> metric = RunningSum(window=3) |
|
>>> for i in range(6): |
|
... current_val = metric(tensor([i])) |
|
... running_val = metric.compute() |
|
... total_val = tensor(sum(list(range(i+1)))) # total sum over all samples |
|
... print(f"{current_val=}, {running_val=}, {total_val=}") |
|
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0) |
|
current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1) |
|
current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3) |
|
current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6) |
|
current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10) |
|
current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15) |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
window: int = 5, |
|
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn", |
|
**kwargs: Any, |
|
) -> None: |
|
super().__init__(base_metric=SumMetric(nan_strategy=nan_strategy, **kwargs), window=window) |
|
|