jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import torch
from torch import Tensor
from typing_extensions import Literal
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.running import Running
if not _MATPLOTLIB_AVAILABLE:
    # The ``plot`` doctests of these classes are skipped when matplotlib is not available.
    __doctest_skip__ = [f"{cls_name}.plot" for cls_name in ("SumMetric", "MeanMetric", "MaxMetric", "MinMetric")]
class BaseAggregator(Metric):
    """Base class for aggregation metrics.

    Args:
        fn: string specifying the reduction function
        default_value: default tensor value to use for the metric state
        nan_strategy: options:

            - ``'error'``: if any `nan` values are encountered will give a RuntimeError
            - ``'warn'``: if any `nan` values are encountered will give a warning and remove them
            - ``'ignore'``: all `nan` values are silently removed
            - ``'disable'``: disable all `nan` checks
            - a float: if a float is provided will impute any `nan` values with this value

        state_name: name of the metric state
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    """

    is_differentiable = None
    higher_is_better = None
    full_state_update: bool = False

    def __init__(
        self,
        fn: Union[Callable, str],
        default_value: Union[Tensor, list],
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "error",
        state_name: str = "value",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        allowed_nan_strategy = ("error", "warn", "ignore", "disable")
        if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
            raise ValueError(
                f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}."
            )

        self.nan_strategy = nan_strategy
        self.add_state(state_name, default=default_value, dist_reduce_fx=fn)
        self.state_name = state_name

    def _cast_and_nan_check_input(
        self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None
    ) -> tuple[Tensor, Tensor]:
        """Convert ``x`` (and ``weight``) to tensors and apply the configured ``nan_strategy``.

        Args:
            x: value(s) to aggregate; non-tensor input is converted to a tensor
            weight: optional weight(s) for ``x``; when ``None`` a tensor of ones shaped like ``x`` is used

        Returns:
            Tuple ``(x, weight)`` with both tensors cast to the dtype of the metric. With the ``'ignore'`` and
            ``'warn'`` strategies, positions where either tensor is `nan` are removed from both. With a float
            strategy, those positions are set to the float in ``x`` and to ``1`` in ``weight``.

        Raises:
            RuntimeError:
                If ``nan_strategy`` is ``'error'`` and a `nan` value is found in ``x`` or ``weight``
            ValueError:
                If `nan` values must be imputed but ``nan_strategy`` is not a float

        """
        if not isinstance(x, Tensor):
            x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
        if weight is None:
            # Only fall back to unit weights when the caller gave none, so that explicitly
            # provided weights are kept for every strategy, including ``'disable'``.
            weight = torch.ones_like(x)
        elif not isinstance(weight, Tensor):
            weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)

        if self.nan_strategy != "disable":
            nans = torch.isnan(x)
            nans_weight = torch.isnan(weight)
            if nans.any() or nans_weight.any():
                if self.nan_strategy == "error":
                    raise RuntimeError("Encountered `nan` values in tensor")

                invalid = nans | nans_weight
                if self.nan_strategy in ("ignore", "warn"):
                    if self.nan_strategy == "warn":
                        rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning)
                    x = x[~invalid]
                    weight = weight[~invalid]
                else:
                    if not isinstance(self.nan_strategy, float):
                        raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
                    # Impute on copies: the inputs may be the caller's own tensors, or an expanded
                    # (broadcast) view that cannot be safely written to in place.
                    x = x.clone()
                    weight = weight.clone()
                    x[invalid] = self.nan_strategy
                    weight[invalid] = 1

        return x.to(self.dtype), weight.to(self.dtype)

    def update(self, value: Union[float, Tensor]) -> None:
        """Overwrite in child class."""

    def compute(self) -> Tensor:
        """Compute the aggregated value by returning the state stored under ``state_name``."""
        return getattr(self, self.state_name)
class MaxMetric(BaseAggregator):
    """Track the maximum over a stream of values.

    Inputs accepted by ``forward`` and ``update``:

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    Output returned by ``forward`` and ``compute``:

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor holding the maximum over all inputs received

    Args:
        nan_strategy: how `nan` values are treated:

            - ``'error'``: raise a RuntimeError when a `nan` value is encountered
            - ``'warn'``: emit a warning when a `nan` value is encountered and remove it
            - ``'ignore'``: silently remove all `nan` values
            - ``'disable'``: skip all `nan` checks
            - a float: replace every `nan` value with this float

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import MaxMetric
        >>> metric = MaxMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor(3.)

    """

    full_state_update: bool = True
    max_value: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        # The state starts at -inf so that the first observed value always replaces it.
        initial_state = torch.tensor(float("-inf"), dtype=torch.get_default_dtype())
        super().__init__(
            fn="max",
            default_value=initial_state,
            nan_strategy=nan_strategy,
            state_name="max_value",
            **kwargs,
        )

    def update(self, value: Union[float, Tensor]) -> None:
        """Fold new data into the running maximum.

        Args:
            value: Either a float or a tensor containing data; the maximum is taken over all of its elements

        """
        cleaned, _ = self._cast_and_nan_check_input(value)
        if not cleaned.numel():
            # Nothing to compare against, e.g. the input was empty or every entry was removed.
            return
        self.max_value = torch.max(self.max_value, cleaned.max())

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torchmetrics.aggregation import MaxMetric
            >>> metric = MaxMetric()
            >>> metric.update([1, 2, 3])
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torchmetrics.aggregation import MaxMetric
            >>> metric = MaxMetric()
            >>> values = [ ]
            >>> for i in range(10):
            ...     values.append(metric(i))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
class MinMetric(BaseAggregator):
    """Track the minimum over a stream of values.

    Inputs accepted by ``forward`` and ``update``:

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    Output returned by ``forward`` and ``compute``:

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor holding the minimum over all inputs received

    Args:
        nan_strategy: how `nan` values are treated:

            - ``'error'``: raise a RuntimeError when a `nan` value is encountered
            - ``'warn'``: emit a warning when a `nan` value is encountered and remove it
            - ``'ignore'``: silently remove all `nan` values
            - ``'disable'``: skip all `nan` checks
            - a float: replace every `nan` value with this float

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import MinMetric
        >>> metric = MinMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor(1.)

    """

    full_state_update: bool = True
    min_value: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        # The state starts at +inf so that the first observed value always replaces it.
        initial_state = torch.tensor(float("inf"), dtype=torch.get_default_dtype())
        super().__init__(
            fn="min",
            default_value=initial_state,
            nan_strategy=nan_strategy,
            state_name="min_value",
            **kwargs,
        )

    def update(self, value: Union[float, Tensor]) -> None:
        """Fold new data into the running minimum.

        Args:
            value: Either a float or a tensor containing data; the minimum is taken over all of its elements

        """
        cleaned, _ = self._cast_and_nan_check_input(value)
        if not cleaned.numel():
            # Nothing to compare against, e.g. the input was empty or every entry was removed.
            return
        self.min_value = torch.min(self.min_value, cleaned.min())

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torchmetrics.aggregation import MinMetric
            >>> metric = MinMetric()
            >>> metric.update([1, 2, 3])
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torchmetrics.aggregation import MinMetric
            >>> metric = MinMetric()
            >>> values = [ ]
            >>> for i in range(10):
            ...     values.append(metric(i))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
class SumMetric(BaseAggregator):
    """Accumulate the sum of a stream of values.

    Inputs accepted by ``forward`` and ``update``:

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    Output returned by ``forward`` and ``compute``:

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor holding the sum over all inputs received

    Args:
        nan_strategy: how `nan` values are treated:

            - ``'error'``: raise a RuntimeError when a `nan` value is encountered
            - ``'warn'``: emit a warning when a `nan` value is encountered and remove it
            - ``'ignore'``: silently remove all `nan` values
            - ``'disable'``: skip all `nan` checks
            - a float: replace every `nan` value with this float

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import SumMetric
        >>> metric = SumMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor(6.)

    """

    sum_value: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        initial_state = torch.tensor(0.0, dtype=torch.get_default_dtype())
        super().__init__(
            fn="sum",
            default_value=initial_state,
            nan_strategy=nan_strategy,
            state_name="sum_value",
            **kwargs,
        )

    def update(self, value: Union[float, Tensor]) -> None:
        """Add new data to the running sum.

        Args:
            value: Either a float or a tensor containing data; all of its elements are summed

        """
        cleaned, _ = self._cast_and_nan_check_input(value)
        if not cleaned.numel():
            # Empty input (or everything removed): leave the state untouched.
            return
        self.sum_value += cleaned.sum()

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torchmetrics.aggregation import SumMetric
            >>> metric = SumMetric()
            >>> metric.update([1, 2, 3])
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torchmetrics.aggregation import SumMetric
            >>> metric = SumMetric()
            >>> values = [ ]
            >>> for i in range(10):
            ...     values.append(metric([i, i+1]))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
class CatMetric(BaseAggregator):
    """Collect a stream of values and concatenate them.

    Inputs accepted by ``forward`` and ``update``:

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    Output returned by ``forward`` and ``compute``:

    - ``agg`` (:class:`~torch.Tensor`): float tensor with the concatenated values over all inputs received

    Args:
        nan_strategy: how `nan` values are treated:

            - ``'error'``: raise a RuntimeError when a `nan` value is encountered
            - ``'warn'``: emit a warning when a `nan` value is encountered and remove it
            - ``'ignore'``: silently remove all `nan` values
            - ``'disable'``: skip all `nan` checks
            - a float: replace every `nan` value with this float

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import CatMetric
        >>> metric = CatMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor([1., 2., 3.])

    """

    value: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        # The state is a list (default state name ``value``) that grows with every non-empty update.
        super().__init__(fn="cat", default_value=[], nan_strategy=nan_strategy, **kwargs)

    def update(self, value: Union[float, Tensor]) -> None:
        """Store new data for later concatenation.

        Args:
            value: Either a float or a tensor containing data

        """
        cleaned, _ = self._cast_and_nan_check_input(value)
        if not cleaned.numel():
            # Do not store empty chunks.
            return
        self.value.append(cleaned)

    def compute(self) -> Tensor:
        """Compute the aggregated value.

        A non-empty list state is concatenated with ``dim_zero_cat``; otherwise the state is returned as is.
        """
        stored = self.value
        has_chunks = isinstance(stored, list) and len(stored) > 0
        return dim_zero_cat(stored) if has_chunks else stored
class MeanMetric(BaseAggregator):
    """Accumulate the (weighted) mean of a stream of values.

    Inputs accepted by ``forward`` and ``update``:

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.
    - ``weight`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``. Needs to be broadcastable to the shape of the ``value`` tensor.

    Output returned by ``forward`` and ``compute``:

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with the (weighted) mean over all inputs received

    Args:
        nan_strategy: how `nan` values are treated:

            - ``'error'``: raise a RuntimeError when a `nan` value is encountered
            - ``'warn'``: emit a warning when a `nan` value is encountered and remove it
            - ``'ignore'``: silently remove all `nan` values
            - ``'disable'``: skip all `nan` checks
            - a float: replace every `nan` value with this float

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import MeanMetric
        >>> metric = MeanMetric()
        >>> metric.update(1)
        >>> metric.update(tensor([2, 3]))
        >>> metric.compute()
        tensor(2.)

    """

    mean_value: Tensor
    weight: Tensor

    def __init__(
        self,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        # ``mean_value`` accumulates the weighted sum of values; ``weight`` accumulates the sum of weights.
        super().__init__(
            fn="sum",
            default_value=torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy=nan_strategy,
            state_name="mean_value",
            **kwargs,
        )
        weight_default = torch.tensor(0.0, dtype=torch.get_default_dtype())
        self.add_state("weight", default=weight_default, dist_reduce_fx="sum")

    def update(self, value: Union[float, Tensor], weight: Union[float, Tensor, None] = None) -> None:
        """Add new data, optionally weighted, to the running mean.

        Args:
            value: Either a float or a tensor containing data
            weight: Either a float or a tensor containing weights for calculating the average. The shape of
                ``weight`` must be broadcastable to the shape of ``value``. Defaults to ``None``, which gives
                every element a weight of one, i.e. a plain arithmetic mean.

        """
        if isinstance(value, Tensor):
            vals = value
        else:
            vals = torch.as_tensor(value, dtype=self.dtype, device=self.device)

        if isinstance(weight, Tensor):
            wts = weight
        elif weight is None:
            wts = torch.ones_like(vals)
        else:
            wts = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
        # Expand the weights so that every element of ``vals`` has a matching weight.
        wts = torch.broadcast_to(wts, vals.shape)

        vals, wts = self._cast_and_nan_check_input(vals, wts)
        if vals.numel() > 0:
            self.mean_value += (vals * wts).sum()
            self.weight += wts.sum()

    def compute(self) -> Tensor:
        """Compute the aggregated value as the accumulated weighted sum divided by the accumulated weight."""
        return self.mean_value / self.weight

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torchmetrics.aggregation import MeanMetric
            >>> metric = MeanMetric()
            >>> metric.update([1, 2, 3])
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torchmetrics.aggregation import MeanMetric
            >>> metric = MeanMetric()
            >>> values = [ ]
            >>> for i in range(10):
            ...     values.append(metric([i, i+1]))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
class RunningMean(Running):
    """Aggregate a stream of values into their mean over a running window.

    Compared to `MeanMetric`, this metric calculates the mean over a running window of values instead of the whole
    history of values. This is beneficial when you want to get a better estimate of the metric during training and
    don't want to wait for the whole training to finish to get epoch level estimates.

    Inputs accepted by ``forward`` and ``update``:

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    Output returned by ``forward`` and ``compute``:

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with the aggregated mean over the inputs received

    Args:
        window: The size of the running window; forwarded to the ``Running`` wrapper.
        nan_strategy: how `nan` values are treated by the wrapped ``MeanMetric``:

            - ``'error'``: raise a RuntimeError when a `nan` value is encountered
            - ``'warn'``: emit a warning when a `nan` value is encountered and remove it
            - ``'ignore'``: silently remove all `nan` values
            - ``'disable'``: skip all `nan` checks
            - a float: replace every `nan` value with this float

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import RunningMean
        >>> metric = RunningMean(window=3)
        >>> for i in range(6):
        ...     current_val = metric(tensor([i]))
        ...     running_val = metric.compute()
        ...     total_val = tensor(sum(list(range(i+1)))) / (i+1)  # total mean over all samples
        ...     print(f"{current_val=}, {running_val=}, {total_val=}")
        current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0.)
        current_val=tensor(1.), running_val=tensor(0.5000), total_val=tensor(0.5000)
        current_val=tensor(2.), running_val=tensor(1.), total_val=tensor(1.)
        current_val=tensor(3.), running_val=tensor(2.), total_val=tensor(1.5000)
        current_val=tensor(4.), running_val=tensor(3.), total_val=tensor(2.)
        current_val=tensor(5.), running_val=tensor(4.), total_val=tensor(2.5000)

    """

    def __init__(
        self,
        window: int = 5,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        # All metric-level options go to the wrapped MeanMetric; the wrapper only receives the window.
        wrapped = MeanMetric(nan_strategy=nan_strategy, **kwargs)
        super().__init__(base_metric=wrapped, window=window)
class RunningSum(Running):
    """Aggregate a stream of values into their sum over a running window.

    Compared to `SumMetric`, this metric calculates the sum over a running window of values instead of the whole
    history of values. This is beneficial when you want to get a better estimate of the metric during training and
    don't want to wait for the whole training to finish to get epoch level estimates.

    Inputs accepted by ``forward`` and ``update``:

    - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
      arbitrary shape ``(...,)``.

    Output returned by ``forward`` and ``compute``:

    - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with the aggregated sum over the inputs received

    Args:
        window: The size of the running window.
        nan_strategy: how `nan` values are treated by the wrapped ``SumMetric``:

            - ``'error'``: raise a RuntimeError when a `nan` value is encountered
            - ``'warn'``: emit a warning when a `nan` value is encountered and remove it
            - ``'ignore'``: silently remove all `nan` values
            - ``'disable'``: skip all `nan` checks
            - a float: replace every `nan` value with this float

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.aggregation import RunningSum
        >>> metric = RunningSum(window=3)
        >>> for i in range(6):
        ...     current_val = metric(tensor([i]))
        ...     running_val = metric.compute()
        ...     total_val = tensor(sum(list(range(i+1))))  # total sum over all samples
        ...     print(f"{current_val=}, {running_val=}, {total_val=}")
        current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
        current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
        current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
        current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
        current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
        current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)

    """

    def __init__(
        self,
        window: int = 5,
        nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
        **kwargs: Any,
    ) -> None:
        # All metric-level options go to the wrapped SumMetric; the wrapper only receives the window.
        wrapped = SumMetric(nan_strategy=nan_strategy, **kwargs)
        super().__init__(base_metric=wrapped, window=window)