from typing import Optional

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import (
    _binary_confusion_matrix_format,
    _binary_confusion_matrix_tensor_validation,
    _multiclass_confusion_matrix_format,
    _multiclass_confusion_matrix_tensor_validation,
)
from torchmetrics.utilities.compute import normalize_logits_if_needed
from torchmetrics.utilities.data import to_onehot
from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
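
# Each ``*_update`` helper below returns the summed per-sample hinge measures
# together with the number of samples seen; ``_hinge_loss_compute`` divides the
# accumulated measures by the sample count to produce the mean loss.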


def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor:
    return measure / total


def _binary_hinge_loss_arg_validation(squared: bool, ignore_index: Optional[int] = None) -> None:
    if not isinstance(squared, bool):
        raise ValueError(f"Expected argument `squared` to be a bool but got {squared}")
    if ignore_index is not None and not isinstance(ignore_index, int):
        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")


def _binary_hinge_loss_tensor_validation(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> None:
    _binary_confusion_matrix_tensor_validation(preds, target, ignore_index)
    if not preds.is_floating_point():
        raise ValueError(
            "Expected argument `preds` to be a floating tensor with probabilities/logits"
            f" but got tensor with dtype {preds.dtype}"
        )


def _binary_hinge_loss_update(
    preds: Tensor,
    target: Tensor,
    squared: bool,
) -> tuple[Tensor, Tensor]:
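    # Map the {0, 1} targets onto a signed margin: keep the prediction where the
    # target is 1 and negate it where the target is 0, so the hinge measure
    # max(0, 1 - margin) can be applied elementwise below.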
    target = target.bool()
    margin = torch.zeros_like(preds)
    margin[target] = preds[target]
    margin[~target] = -preds[~target]

    measures = 1 - margin
    measures = torch.clamp(measures, 0)

    if squared:
        measures = measures.pow(2)

    total = tensor(target.shape[0], device=target.device)
    return measures.sum(dim=0), total


def binary_hinge_loss(
    preds: Tensor,
    target: Tensor,
    squared: bool = False,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks.

    .. math::
        \text{Hinge loss} = \max(0, 1 - y \times \hat{y})

    Where :math:`y \in \{-1, 1\}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside the [0,1] range, we consider the input to be logits and will
      automatically apply sigmoid per element.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive
      class.

    Additional dimension ``...`` will be flattened into the batch dimension.

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        squared:
            If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.functional.classification import binary_hinge_loss
        >>> preds = tensor([0.25, 0.25, 0.55, 0.75, 0.75])
        >>> target = tensor([0, 0, 1, 1, 1])
        >>> binary_hinge_loss(preds, target)
        tensor(0.6900)
        >>> binary_hinge_loss(preds, target, squared=True)
        tensor(0.6905)

    """
    if validate_args:
        _binary_hinge_loss_arg_validation(squared, ignore_index)
        _binary_hinge_loss_tensor_validation(preds, target, ignore_index)
    preds, target = _binary_confusion_matrix_format(
        preds, target, threshold=0.0, ignore_index=ignore_index, convert_to_labels=False
    )
    measures, total = _binary_hinge_loss_update(preds, target, squared)
    return _hinge_loss_compute(measures, total)


def _multiclass_hinge_loss_arg_validation(
    num_classes: int,
    squared: bool = False,
    multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
    ignore_index: Optional[int] = None,
) -> None:
    _binary_hinge_loss_arg_validation(squared, ignore_index)
    if not isinstance(num_classes, int) or num_classes < 2:
        raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
    allowed_mm = ("crammer-singer", "one-vs-all")
    if multiclass_mode not in allowed_mm:
        raise ValueError(f"Expected argument `multiclass_mode` to be one of {allowed_mm}, but got {multiclass_mode}.")


def _multiclass_hinge_loss_tensor_validation(
    preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None
) -> None:
    _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
    if not preds.is_floating_point():
        raise ValueError(
            "Expected argument `preds` to be a floating tensor with probabilities/logits"
            f" but got tensor with dtype {preds.dtype}"
        )


def _multiclass_hinge_loss_update(
    preds: Tensor,
    target: Tensor,
    squared: bool,
    multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
) -> tuple[Tensor, Tensor]:
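    # Two formulations are supported: "crammer-singer" measures the true-class
    # score against the largest competing class score, while "one-vs-all" treats
    # every class as an independent binary hinge problem.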
    preds = normalize_logits_if_needed(preds, "softmax")
    target = to_onehot(target, max(2, preds.shape[1])).bool()
    if multiclass_mode == "crammer-singer":
        margin = preds[target]
        margin -= torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0]
    else:
        margin = torch.zeros_like(preds)
        margin[target] = preds[target]
        margin[~target] = -preds[~target]

    measures = 1 - margin
    measures = torch.clamp(measures, 0)

    if squared:
        measures = measures.pow(2)

    total = tensor(target.shape[0], device=target.device)
    return measures.sum(dim=0), total


def multiclass_hinge_loss(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    squared: bool = False,
    multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks.

    The metric can be computed in two ways. Either, the definition by Crammer and Singer is used:

    .. math::
        \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)

    Where :math:`y \in \{0, ..., \mathrm{C}-1\}` is the target class (where :math:`\mathrm{C}` is the number of
    classes), and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the
    metric can also be computed in the one-vs-all approach, where each class is valued against all other classes in
    a binary fashion.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside the [0,1] range, we consider the input to be logits and will
      automatically apply softmax per sample.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).

    Additional dimension ``...`` will be flattened into the batch dimension.

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_classes: Integer specifying the number of classes
        squared:
            If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.
        multiclass_mode:
            Determines how the metric is computed, either ``"crammer-singer"`` or ``"one-vs-all"`` (see above).
        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.functional.classification import multiclass_hinge_loss
        >>> preds = tensor([[0.25, 0.20, 0.55],
        ...                 [0.55, 0.05, 0.40],
        ...                 [0.10, 0.30, 0.60],
        ...                 [0.90, 0.05, 0.05]])
        >>> target = tensor([0, 1, 2, 0])
        >>> multiclass_hinge_loss(preds, target, num_classes=3)
        tensor(0.9125)
        >>> multiclass_hinge_loss(preds, target, num_classes=3, squared=True)
        tensor(1.1131)
        >>> multiclass_hinge_loss(preds, target, num_classes=3, multiclass_mode='one-vs-all')
        tensor([0.8750, 1.1250, 1.1000])

    """
    if validate_args:
        _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index)
        _multiclass_hinge_loss_tensor_validation(preds, target, num_classes, ignore_index)
    preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels=False)
    measures, total = _multiclass_hinge_loss_update(preds, target, squared, multiclass_mode)
    return _hinge_loss_compute(measures, total)


def hinge_loss(
    preds: Tensor,
    target: Tensor,
    task: Literal["binary", "multiclass"],
    num_classes: Optional[int] = None,
    squared: bool = False,
    multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs).

    This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of
    :func:`~torchmetrics.functional.classification.binary_hinge_loss` and
    :func:`~torchmetrics.functional.classification.multiclass_hinge_loss` for the specific details of
    each argument's influence and examples.

    Legacy Example:
        >>> from torch import tensor
        >>> target = tensor([0, 1, 1])
        >>> preds = tensor([0.5, 0.7, 0.1])
        >>> hinge_loss(preds, target, task="binary")
        tensor(0.9000)

        >>> target = tensor([0, 1, 2])
        >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
        >>> hinge_loss(preds, target, task="multiclass", num_classes=3)
        tensor(1.5551)

        >>> target = tensor([0, 1, 2])
        >>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
        >>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
        tensor([1.3743, 1.1945, 1.2359])

    """
    task = ClassificationTaskNoMultilabel.from_str(task)
    if task == ClassificationTaskNoMultilabel.BINARY:
        return binary_hinge_loss(preds, target, squared, ignore_index, validate_args)
    if task == ClassificationTaskNoMultilabel.MULTICLASS:
        if not isinstance(num_classes, int):
            raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
        return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args)
    raise ValueError(f"Unsupported task: {task}")