# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.functional.classification.hinge import (
    _binary_confusion_matrix_format,
    _binary_hinge_loss_arg_validation,
    _binary_hinge_loss_tensor_validation,
    _binary_hinge_loss_update,
    _hinge_loss_compute,
    _multiclass_confusion_matrix_format,
    _multiclass_hinge_loss_arg_validation,
    _multiclass_hinge_loss_tensor_validation,
    _multiclass_hinge_loss_update,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["BinaryHingeLoss.plot", "MulticlassHingeLoss.plot"]


class BinaryHingeLoss(Metric):
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks.

    .. math::
        \text{Hinge loss} = \max(0, 1 - y \times \hat{y})

    Where :math:`y \in \{-1, 1\}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)``. Preds should be a tensor containing
      probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the
      input to be logits and will auto apply sigmoid per element.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
      ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value
      1 always encodes the positive class.

    .. tip::
        Additional dimension ``...`` will be flattened into the batch dimension.

    As output to ``forward`` and ``compute`` the metric returns the following output:

    - ``bhl`` (:class:`~torch.Tensor`): A tensor containing the hinge loss.

    Args:
        squared: If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
        ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torchmetrics.classification import BinaryHingeLoss
        >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
        >>> target = torch.tensor([0, 0, 1, 1, 1])
        >>> bhl = BinaryHingeLoss()
        >>> bhl(preds, target)
        tensor(0.6900)
        >>> bhl = BinaryHingeLoss(squared=True)
        >>> bhl(preds, target)
        tensor(0.6905)

    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    measures: Tensor
    total: Tensor

    def __init__(
        self,
        squared: bool = False,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if validate_args:
            _binary_hinge_loss_arg_validation(squared, ignore_index)
        self.validate_args = validate_args
        self.squared = squared
        self.ignore_index = ignore_index

        self.add_state("measures", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update metric state."""
        if self.validate_args:
            _binary_hinge_loss_tensor_validation(preds, target, self.ignore_index)
        preds, target = _binary_confusion_matrix_format(
            preds, target, threshold=0.0, ignore_index=self.ignore_index, convert_to_labels=False
        )
        measures, total = _binary_hinge_loss_update(preds, target, self.squared)
        self.measures += measures
        self.total += total

    def compute(self) -> Tensor:
        """Compute metric."""
        return _hinge_loss_compute(self.measures, self.total)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> from torch import rand, randint
            >>> from torchmetrics.classification import BinaryHingeLoss
            >>> metric = BinaryHingeLoss()
            >>> metric.update(rand(10), randint(2, (10,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> from torch import rand, randint
            >>> from torchmetrics.classification import BinaryHingeLoss
            >>> metric = BinaryHingeLoss()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(rand(10), randint(2, (10,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class MulticlassHingeLoss(Metric):
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks.

    The metric can be computed in two ways. Either, the definition by Crammer and Singer is used:

    .. math::
        \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)

    Where :math:`y \in \{0, ..., \mathrm{C} - 1\}` is the target class (where :math:`\mathrm{C}` is the number of
    classes), and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the
    metric can also be computed in the one-vs-all approach, where each class is valued against all other classes in
    a binary fashion.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``.
      Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside
      the [0,1] range, we consider the input to be logits and will auto apply softmax per sample.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. Target should be a tensor containing
      ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index`
      is specified).

    .. tip::
        Additional dimension ``...`` will be flattened into the batch dimension.

    As output to ``forward`` and ``compute`` the metric returns the following output:

    - ``mchl`` (:class:`~torch.Tensor`): A tensor containing the multi-class hinge loss.

    Args:
        num_classes: Integer specifying the number of classes
        squared: If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
        multiclass_mode: Determines how to compute the metric
        ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torchmetrics.classification import MulticlassHingeLoss
        >>> preds = torch.tensor([[0.25, 0.20, 0.55],
        ...                       [0.55, 0.05, 0.40],
        ...                       [0.10, 0.30, 0.60],
        ...                       [0.90, 0.05, 0.05]])
        >>> target = torch.tensor([0, 1, 2, 0])
        >>> mchl = MulticlassHingeLoss(num_classes=3)
        >>> mchl(preds, target)
        tensor(0.9125)
        >>> mchl = MulticlassHingeLoss(num_classes=3, squared=True)
        >>> mchl(preds, target)
        tensor(1.1131)
        >>> mchl = MulticlassHingeLoss(num_classes=3, multiclass_mode='one-vs-all')
        >>> mchl(preds, target)
        tensor([0.8750, 1.1250, 1.1000])

    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    plot_legend_name: str = "Class"

    measures: Tensor
    total: Tensor

    def __init__(
        self,
        num_classes: int,
        squared: bool = False,
        multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if validate_args:
            _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index)
        self.validate_args = validate_args
        self.num_classes = num_classes
        self.squared = squared
        self.multiclass_mode = multiclass_mode
        self.ignore_index = ignore_index

        self.add_state(
            "measures",
            default=torch.tensor(0.0) if self.multiclass_mode == "crammer-singer" else torch.zeros(num_classes),
            dist_reduce_fx="sum",
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update metric state."""
        if self.validate_args:
            _multiclass_hinge_loss_tensor_validation(preds, target, self.num_classes, self.ignore_index)
        preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index, convert_to_labels=False)
        measures, total = _multiclass_hinge_loss_update(preds, target, self.squared, self.multiclass_mode)
        self.measures += measures
        self.total += total

    def compute(self) -> Tensor:
        """Compute metric."""
        return _hinge_loss_compute(self.measures, self.total)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure object and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value per class
            >>> from torch import randint, randn
            >>> from torchmetrics.classification import MulticlassHingeLoss
            >>> metric = MulticlassHingeLoss(num_classes=3)
            >>> metric.update(randn(20, 3), randint(3, (20,)))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values per class
            >>> from torch import randint, randn
            >>> from torchmetrics.classification import MulticlassHingeLoss
            >>> metric = MulticlassHingeLoss(num_classes=3)
            >>> values = []
            >>> for _ in range(20):
            ...     values.append(metric(randn(20, 3), randint(3, (20,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)


class HingeLoss(_ClassificationTaskWrapper):
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs).

    This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of
    :class:`~torchmetrics.classification.BinaryHingeLoss` and
    :class:`~torchmetrics.classification.MulticlassHingeLoss` for the specific details of how each argument
    influences the metric and for examples.

    Legacy Example:
        >>> from torch import tensor
        >>> target = tensor([0, 1, 1])
        >>> preds = tensor([0.5, 0.7, 0.1])
        >>> hinge = HingeLoss(task="binary")
        >>> hinge(preds, target)
        tensor(0.9000)

        >>> target = tensor([0, 1, 2])
        >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
        >>> hinge = HingeLoss(task="multiclass", num_classes=3)
        >>> hinge(preds, target)
        tensor(1.5551)

        >>> target = tensor([0, 1, 2])
        >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
        >>> hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
        >>> hinge(preds, target)
        tensor([1.3743, 1.1945, 1.2359])

    """

    def __new__(  # type: ignore[misc]
        cls: type["HingeLoss"],
        task: Literal["binary", "multiclass"],
        num_classes: Optional[int] = None,
        squared: bool = False,
        multiclass_mode: Optional[Literal["crammer-singer", "one-vs-all"]] = "crammer-singer",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Initialize task metric."""
        task = ClassificationTaskNoMultilabel.from_str(task)
        kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
        if task == ClassificationTaskNoMultilabel.BINARY:
            return BinaryHingeLoss(squared, **kwargs)
        if task == ClassificationTaskNoMultilabel.MULTICLASS:
            if not isinstance(num_classes, int):
                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
            if multiclass_mode not in ("crammer-singer", "one-vs-all"):
                raise ValueError(
                    "`multiclass_mode` is expected to be one of 'crammer-singer' or 'one-vs-all' but "
                    f"`{multiclass_mode}` was passed."
                )
            return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
        raise ValueError(f"Unsupported task `{task}`")