# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Optional, Union

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import ModuleList

from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.abstract import WrapperMetric

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["MultioutputWrapper.plot"]


def _get_nan_indices(*tensors: Tensor) -> Tensor:
    """Get a boolean mask of the rows along dim 0 that contain a NaN in any of ``tensors``.

    Args:
        tensors: Tensors expected to share the same size along dim 0. Tensors with more than one dimension are
            reduced over all trailing dimensions; for 1-D tensors each element is its own row.

    Returns:
        Boolean tensor of length ``len(tensors[0])`` on the device of ``tensors[0]``. An entry is ``True`` when the
        corresponding row contains a NaN in at least one of the tensors.

    Raises:
        ValueError: If no tensors are passed.

    """
    if len(tensors) == 0:
        raise ValueError("Must pass at least one tensor as argument")
    sentinel = tensors[0]
    nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool, device=sentinel.device)
    for tensor in tensors:
        nan_mask = torch.isnan(tensor)
        # ``flatten(start_dim=1)`` is invalid for 1-D tensors, whose element-wise mask already is the row mask.
        if nan_mask.ndim > 1:
            nan_mask = nan_mask.flatten(start_dim=1).any(dim=1)
        # OR-accumulate: a row is flagged if *any* tensor has a NaN in it (union, not intersection).
        nan_idxs |= nan_mask
    return nan_idxs


class MultioutputWrapper(WrapperMetric):
    """Wrap a base metric to enable it to support multiple outputs.

    Several torchmetrics metrics, such as :class:`~torchmetrics.regression.spearman.SpearmanCorrCoef`, lack support
    for multioutput mode. This class wraps such metrics to support computing one metric per output.
    Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs.
    This means if you set ``num_outputs`` to 2, ``.compute()`` will return a Tensor of dimension
    ``(2, ...)`` where ``...`` represents the dimensions the metric returns when not wrapped.

    In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude
    fashion, dealing with missing labels (or other data). When ``remove_nans`` is ``True``, then on each update and
    separately for each output, every row (along dim 0) in which any of the inputs contains a ``NaN`` for that output
    is dropped before the inputs reach that output's metric. For example, suppose a user wraps
    :class:`torchmetrics.regression.r2.R2Score` with 2 outputs, one of which occasionally has missing labels encoded
    as ``NaN``. Those rows are removed only from the metric tracking the affected output, while the other output's
    metric still receives them.

    Args:
        base_metric: Metric being wrapped.
        num_outputs: Expected dimensionality of the output dimension.
            This parameter is used to determine the number of distinct metrics we need to track.
        output_dim:
            Dimension on which output is expected. Note that while this provides some flexibility, the output dimension
            must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels
            can have a different number of dimensions than the predictions. This can be worked around if the output
            dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.
        remove_nans:
            Whether to remove, for each output, every row that contains a NaN in any of the values passed through to
            that output's underlying metric (i.e. the union of the NaN-containing rows of all inputs).
            Proper operation requires all tensors passed to update to have dimension ``(N, ...)`` where N represents
            the length of the batch or dataset being passed in.
        squeeze_outputs:
            If ``True``, will squeeze the size-1 output dimension left after ``index_select`` is applied.
            This is unnecessary but harmless for metrics such as `R2Score`, and useful for certain classification
            metrics that can't handle additional 1-item dimensions.

    Example:
        >>> # Mimic R2Score in `multioutput`, `raw_values` mode:
        >>> import torch
        >>> from torchmetrics.wrappers import MultioutputWrapper
        >>> from torchmetrics.regression import R2Score
        >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
        >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
        >>> r2score = MultioutputWrapper(R2Score(), 2)
        >>> r2score(preds, target)
        tensor([0.9654, 0.9082])

    """

    is_differentiable = False

    def __init__(
        self,
        base_metric: Metric,
        num_outputs: int,
        output_dim: int = -1,
        remove_nans: bool = True,
        squeeze_outputs: bool = True,
    ) -> None:
        super().__init__()
        # One independent deep copy of the base metric per output, so their states never alias.
        self.metrics = ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
        self.output_dim = output_dim
        self.remove_nans = remove_nans
        self.squeeze_outputs = squeeze_outputs

    def _get_args_kwargs_by_output(
        self, *args: Tensor, **kwargs: Tensor
    ) -> list[tuple[Sequence[Tensor], dict[str, Tensor]]]:
        """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out.

        Returns:
            One ``(args, kwargs)`` pair per wrapped metric, in the order of ``self.metrics``. For output ``i``, every
            tensor has been reduced to index ``i`` along ``self.output_dim``; then rows along dim 0 containing a NaN
            are dropped if ``self.remove_nans`` is set, and the size-1 output dimension is squeezed away if
            ``self.squeeze_outputs`` is set.

        """
        args_kwargs_by_output = []
        for i in range(len(self.metrics)):
            # Built once per output and shared by the args and kwargs selections.
            index = torch.tensor(i, device=self.device)
            selected_args = apply_to_collection(args, Tensor, torch.index_select, dim=self.output_dim, index=index)
            selected_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=self.output_dim, index=index)
            if self.remove_nans:
                nan_idxs = _get_nan_indices(*selected_args, *selected_kwargs.values())
                keep = ~nan_idxs
                selected_args = [arg[keep] for arg in selected_args]
                selected_kwargs = {k: v[keep] for k, v in selected_kwargs.items()}
            if self.squeeze_outputs:
                # ``index_select`` keeps the output dimension with size 1; drop it for metrics that can't handle it.
                selected_args = [arg.squeeze(self.output_dim) for arg in selected_args]
                selected_kwargs = {k: v.squeeze(self.output_dim) for k, v in selected_kwargs.items()}
            args_kwargs_by_output.append((selected_args, selected_kwargs))
        return args_kwargs_by_output

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Update each underlying metric with the corresponding output."""
        reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
        for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
            metric.update(*selected_args, **selected_kwargs)

    def compute(self) -> Tensor:
        """Compute each underlying metric and stack the results along a new leading dimension."""
        return torch.stack([m.compute() for m in self.metrics], 0)

    @torch.jit.unused
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Call underlying forward methods and aggregate the results if they're non-null.

        We override this method to ensure that state variables get copied over on the underlying metrics.

        Returns:
            ``None`` if the first underlying metric's forward returns ``None``, otherwise the per-output batch values
            stacked along a new leading dimension.

        """
        reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
        results = [
            metric(*selected_args, **selected_kwargs)
            for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs)
        ]
        if results[0] is None:
            return None
        return torch.stack(results, 0)

    def reset(self) -> None:
        """Reset all underlying metrics, then the wrapper itself."""
        for metric in self.metrics:
            metric.reset()
        super().reset()

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.wrappers import MultioutputWrapper
            >>> from torchmetrics.regression import R2Score
            >>> metric = MultioutputWrapper(R2Score(), 2)
            >>> metric.update(torch.randn(20, 2), torch.randn(20, 2))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.wrappers import MultioutputWrapper
            >>> from torchmetrics.regression import R2Score
            >>> metric = MultioutputWrapper(R2Score(), 2)
            >>> values = [ ]
            >>> for _ in range(3):
            ...     values.append(metric(torch.randn(20, 2), torch.randn(20, 2)))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)