from typing import List, Optional, Union |
|
|
|
import torch |
|
from torch import Tensor, tensor |
|
from typing_extensions import Literal |
|
|
|
from torchmetrics.functional.classification.precision_recall_curve import ( |
|
_binary_precision_recall_curve_arg_validation, |
|
_binary_precision_recall_curve_format, |
|
_binary_precision_recall_curve_tensor_validation, |
|
_binary_precision_recall_curve_update, |
|
_multiclass_precision_recall_curve_arg_validation, |
|
_multiclass_precision_recall_curve_format, |
|
_multiclass_precision_recall_curve_tensor_validation, |
|
_multiclass_precision_recall_curve_update, |
|
_multilabel_precision_recall_curve_arg_validation, |
|
_multilabel_precision_recall_curve_format, |
|
_multilabel_precision_recall_curve_tensor_validation, |
|
_multilabel_precision_recall_curve_update, |
|
) |
|
from torchmetrics.functional.classification.roc import ( |
|
_binary_roc_compute, |
|
_multiclass_roc_compute, |
|
_multilabel_roc_compute, |
|
) |
|
from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide |
|
from torchmetrics.utilities.data import _bincount |
|
from torchmetrics.utilities.enums import ClassificationTask |
|
from torchmetrics.utilities.prints import rank_zero_warn |
|
|
|
|
|
def _reduce_auroc( |
|
fpr: Union[Tensor, List[Tensor]], |
|
tpr: Union[Tensor, List[Tensor]], |
|
average: Optional[Literal["macro", "weighted", "none"]] = "macro", |
|
weights: Optional[Tensor] = None, |
|
direction: float = 1.0, |
|
) -> Tensor:

    """Reduce multiple AUROC scores into one number."""
|
if isinstance(fpr, Tensor) and isinstance(tpr, Tensor): |
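        # binned case: `fpr`/`tpr` are stacked tensors with one curve per class/label, so integrate along dim 1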
|
res = _auc_compute_without_check(fpr, tpr, direction=direction, axis=1) |
|
else: |
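        # non-binned case: `fpr`/`tpr` are lists of per-class curves with varying numbers of points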
|
res = torch.stack([_auc_compute_without_check(x, y, direction=direction) for x, y in zip(fpr, tpr)]) |
|
if average is None or average == "none": |
|
return res |
|
if torch.isnan(res).any(): |
|
        rank_zero_warn(

            f"AUROC score for one or more classes was `nan`. Ignoring these classes in {average}-average",
|
UserWarning, |
|
) |
|
idx = ~torch.isnan(res) |
|
if average == "macro": |
|
return res[idx].mean() |
|
if average == "weighted" and weights is not None: |
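        # renormalize the support weights over the non-nan classes; `_safe_divide` returns 0 instead of dividing by zero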
|
weights = _safe_divide(weights[idx], weights[idx].sum()) |
|
return (res[idx] * weights).sum() |
|
    raise ValueError("Received an incompatible combination of inputs for reduction.")
|
|
|
|
|
def _binary_auroc_arg_validation( |
|
max_fpr: Optional[float] = None, |
|
thresholds: Optional[Union[int, list[float], Tensor]] = None, |
|
ignore_index: Optional[int] = None, |
|
) -> None: |
|
_binary_precision_recall_curve_arg_validation(thresholds, ignore_index) |
|
    if max_fpr is not None and (not isinstance(max_fpr, float) or not 0 < max_fpr <= 1):

        raise ValueError(f"Argument `max_fpr` should be a float in range (0, 1], but got: {max_fpr}")
|
|
|
|
|
def _binary_auroc_compute( |
|
state: Union[Tensor, tuple[Tensor, Tensor]], |
|
thresholds: Optional[Tensor], |
|
max_fpr: Optional[float] = None, |
|
pos_label: int = 1, |
|
) -> Tensor: |
|
fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label) |
|
if max_fpr is None or max_fpr == 1 or fpr.sum() == 0 or tpr.sum() == 0: |
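        # no partial range requested (or a degenerate curve): return the full trapezoidal area under the ROC curve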
|
return _auc_compute_without_check(fpr, tpr, 1.0) |
|
|
|
_device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device |
|
max_area: Tensor = tensor(max_fpr, device=_device) |
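    # locate where `max_fpr` falls on the fpr grid and linearly interpolate the ROC curve
    # at fpr == max_fpr so that the truncated curve ends exactly at the requested cut-off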
|
|
|
stop = torch.bucketize(max_area, fpr, out_int32=True, right=True) |
|
weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) |
|
interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight) |
|
tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) |
|
fpr = torch.cat([fpr[:stop], max_area.view(1)]) |
|
|
|
|
|
partial_auc = _auc_compute_without_check(fpr, tpr, 1.0) |
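    # McClish correction: standardize the partial AUC so that 0.5 corresponds to random guessing
    # and 1.0 to a perfect classifier on the restricted [0, max_fpr] range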
|
|
|
|
|
min_area: Tensor = 0.5 * max_area**2 |
|
return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) |
|
|
|
|
|
def binary_auroc( |
|
preds: Tensor, |
|
target: Tensor, |
|
max_fpr: Optional[float] = None, |
|
thresholds: Optional[Union[int, list[float], Tensor]] = None, |
|
ignore_index: Optional[int] = None, |
|
validate_args: bool = True, |
|
) -> Tensor: |
|
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks. |
|
|
|
    The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for
|
multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 |
|
corresponds to random guessing. |
|
|
|
Accepts the following input tensors: |
|
|
|
- ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each |
|
      observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply
|
sigmoid per element. |
|
- ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore |
|
only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. |
|
|
|
Additional dimension ``...`` will be flattened into the batch dimension. |
|
|
|
    The implementation supports calculating the metric in both a non-binned, exact version and a binned version
|
that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the |
|
non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` |
|
argument to either an integer, list or a 1d tensor will use a binned version that uses memory of |
|
size :math:`\mathcal{O}(n_{thresholds})` (constant memory). |
|
|
|
Args: |
|
preds: Tensor with predictions |
|
target: Tensor with true labels |
|
max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``. |
|
thresholds: |
|
Can be one of: |
|
|
|
- If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from |
|
all the data. Most accurate but also most memory consuming approach. |
|
- If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from |
|
0 to 1 as bins for the calculation. |
|
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
|
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
|
bins for the calculation. |
|
|
|
ignore_index: |
|
Specifies a target value that is ignored and does not contribute to the metric calculation |
|
validate_args: bool indicating if input arguments and tensors should be validated for correctness. |
|
Set to ``False`` for faster computations. |
|
|
|
Returns: |
|
        A single scalar with the AUROC score
|
|
|
Example: |
|
>>> from torchmetrics.functional.classification import binary_auroc |
|
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) |
|
>>> target = torch.tensor([0, 1, 1, 0]) |
|
>>> binary_auroc(preds, target, thresholds=None) |
|
tensor(0.5000) |
|
>>> binary_auroc(preds, target, thresholds=5) |
|
tensor(0.5000) |
|
|
|
""" |
|
if validate_args: |
|
_binary_auroc_arg_validation(max_fpr, thresholds, ignore_index) |
|
_binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) |
|
preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) |
|
state = _binary_precision_recall_curve_update(preds, target, thresholds) |
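    # `state` is the raw (preds, target) pair when `thresholds` is None, otherwise a binned
    # confusion-matrix tensor of shape (n_thresholds, 2, 2)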
|
return _binary_auroc_compute(state, thresholds, max_fpr) |
|
|
|
|
|
def _multiclass_auroc_arg_validation( |
|
num_classes: int, |
|
average: Optional[Literal["macro", "weighted", "none"]] = "macro", |
|
thresholds: Optional[Union[int, list[float], Tensor]] = None, |
|
ignore_index: Optional[int] = None, |
|
) -> None: |
|
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) |
|
allowed_average = ("macro", "weighted", "none", None) |
|
if average not in allowed_average: |
|
raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") |
|
|
|
|
|
def _multiclass_auroc_compute( |
|
state: Union[Tensor, tuple[Tensor, Tensor]], |
|
num_classes: int, |
|
average: Optional[Literal["macro", "weighted", "none"]] = "macro", |
|
thresholds: Optional[Tensor] = None, |
|
) -> Tensor: |
|
fpr, tpr, _ = _multiclass_roc_compute(state, num_classes, thresholds) |
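    # for the weighted average, each class is weighted by its support (number of positive targets):
    # a bincount of the targets in the non-binned case, TP + FN from the binned state otherwise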
|
return _reduce_auroc( |
|
fpr, |
|
tpr, |
|
average, |
|
weights=_bincount(state[1], minlength=num_classes).float() if thresholds is None else state[0][:, 1, :].sum(-1), |
|
) |
|
|
|
|
|
def multiclass_auroc( |
|
preds: Tensor, |
|
target: Tensor, |
|
num_classes: int, |
|
average: Optional[Literal["macro", "weighted", "none"]] = "macro", |
|
thresholds: Optional[Union[int, list[float], Tensor]] = None, |
|
ignore_index: Optional[int] = None, |
|
validate_args: bool = True, |
|
) -> Tensor: |
|
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. |
|
|
|
    The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for
|
multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 |
|
corresponds to random guessing. |
|
|
|
Accepts the following input tensors: |
|
|
|
- ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each |
|
      observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply
|
softmax per sample. |
|
- ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore |
|
only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). |
|
|
|
Additional dimension ``...`` will be flattened into the batch dimension. |
|
|
|
    The implementation supports calculating the metric in both a non-binned, exact version and a binned version
|
that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the |
|
non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` |
|
argument to either an integer, list or a 1d tensor will use a binned version that uses memory of |
|
size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). |
|
|
|
Args: |
|
preds: Tensor with predictions |
|
target: Tensor with true labels |
|
num_classes: Integer specifying the number of classes |
|
average: |
|
Defines the reduction that is applied over classes. Should be one of the following: |
|
|
|
- ``macro``: Calculate score for each class and average them |
|
            - ``weighted``: Calculate score for each class and compute a weighted average using their support

            - ``"none"`` or ``None``: Calculate score for each class and apply no reduction
|
thresholds: |
|
Can be one of: |
|
|
|
- If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from |
|
all the data. Most accurate but also most memory consuming approach. |
|
- If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from |
|
0 to 1 as bins for the calculation. |
|
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
|
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
|
bins for the calculation. |
|
|
|
ignore_index: |
|
Specifies a target value that is ignored and does not contribute to the metric calculation |
|
validate_args: bool indicating if input arguments and tensors should be validated for correctness. |
|
Set to ``False`` for faster computations. |
|
|
|
Returns: |
|
        If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AUROC score per class.
|
If `average="macro"|"weighted"` then a single scalar is returned. |
|
|
|
Example: |
|
>>> from torchmetrics.functional.classification import multiclass_auroc |
|
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], |
|
... [0.05, 0.75, 0.05, 0.05, 0.05], |
|
... [0.05, 0.05, 0.75, 0.05, 0.05], |
|
... [0.05, 0.05, 0.05, 0.75, 0.05]]) |
|
>>> target = torch.tensor([0, 1, 3, 2]) |
|
>>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=None) |
|
tensor(0.5333) |
|
>>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=None) |
|
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) |
|
>>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=5) |
|
tensor(0.5333) |
|
>>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=5) |
|
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000]) |
|
|
|
""" |
|
if validate_args: |
|
_multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) |
|
_multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) |
|
preds, target, thresholds = _multiclass_precision_recall_curve_format( |
|
preds, target, num_classes, thresholds, ignore_index |
|
) |
|
state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) |
|
return _multiclass_auroc_compute(state, num_classes, average, thresholds) |
|
|
|
|
|
def _multilabel_auroc_arg_validation( |
|
num_labels: int, |
|
average: Optional[Literal["micro", "macro", "weighted", "none"]], |
|
thresholds: Optional[Union[int, list[float], Tensor]] = None, |
|
ignore_index: Optional[int] = None, |
|
) -> None: |
|
_multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) |
|
allowed_average = ("micro", "macro", "weighted", "none", None) |
|
if average not in allowed_average: |
|
raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") |
|
|
|
|
|
def _multilabel_auroc_compute( |
|
state: Union[Tensor, tuple[Tensor, Tensor]], |
|
num_labels: int, |
|
average: Optional[Literal["micro", "macro", "weighted", "none"]], |
|
thresholds: Optional[Tensor], |
|
ignore_index: Optional[int] = None, |
|
) -> Tensor: |
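    # micro averaging treats every (sample, label) element as a single binary prediction
    # and computes one AUROC over the flattened inputs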
|
if average == "micro": |
|
if isinstance(state, Tensor) and thresholds is not None: |
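            # the binned state has shape (n_thresholds, n_labels, 2, 2); summing over the label
            # dimension collapses it to the binary state expected by `_binary_auroc_compute`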
|
return _binary_auroc_compute(state.sum(1), thresholds, max_fpr=None) |
|
|
|
preds = state[0].flatten() |
|
target = state[1].flatten() |
|
if ignore_index is not None: |
|
idx = target == ignore_index |
|
preds = preds[~idx] |
|
target = target[~idx] |
|
return _binary_auroc_compute((preds, target), thresholds, max_fpr=None) |
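    # macro/weighted/none: compute one ROC curve per label and reduce the per-label AUROCs,
    # weighting by label support when `average="weighted"`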
|
|
|
fpr, tpr, _ = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) |
|
return _reduce_auroc( |
|
fpr, |
|
tpr, |
|
average, |
|
weights=(state[1] == 1).sum(dim=0).float() if thresholds is None else state[0][:, 1, :].sum(-1), |
|
) |
|
|
|
|
|
def multilabel_auroc( |
|
preds: Tensor, |
|
target: Tensor, |
|
num_labels: int, |
|
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", |
|
thresholds: Optional[Union[int, list[float], Tensor]] = None, |
|
ignore_index: Optional[int] = None, |
|
validate_args: bool = True, |
|
) -> Tensor: |
|
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks. |
|
|
|
    The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for
|
multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 |
|
corresponds to random guessing. |
|
|
|
Accepts the following input tensors: |
|
|
|
- ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each |
|
      observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply
|
sigmoid per element. |
|
- ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore |
|
only contain {0,1} values (except if `ignore_index` is specified). |
|
|
|
Additional dimension ``...`` will be flattened into the batch dimension. |
|
|
|
    The implementation supports calculating the metric in both a non-binned, exact version and a binned version
|
that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the |
|
non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` |
|
argument to either an integer, list or a 1d tensor will use a binned version that uses memory of |
|
size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). |
|
|
|
Args: |
|
preds: Tensor with predictions |
|
target: Tensor with true labels |
|
num_labels: Integer specifying the number of labels |
|
average: |
|
Defines the reduction that is applied over labels. Should be one of the following: |
|
|
|
            - ``micro``: Compute the score globally by flattening predictions and targets over all labels

            - ``macro``: Calculate score for each label and average them

            - ``weighted``: Calculate score for each label and compute a weighted average using their support

            - ``"none"`` or ``None``: Calculate score for each label and apply no reduction
|
thresholds: |
|
Can be one of: |
|
|
|
- If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from |
|
all the data. Most accurate but also most memory consuming approach. |
|
- If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from |
|
0 to 1 as bins for the calculation. |
|
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
|
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
|
bins for the calculation. |
|
|
|
ignore_index: |
|
Specifies a target value that is ignored and does not contribute to the metric calculation |
|
validate_args: bool indicating if input arguments and tensors should be validated for correctness. |
|
Set to ``False`` for faster computations. |
|
|
|
Returns: |
|
        If `average=None|"none"` then a 1d tensor of shape (n_labels, ) will be returned with AUROC score per label.

        If `average="micro"|"macro"|"weighted"` then a single scalar is returned.
|
|
|
Example: |
|
>>> from torchmetrics.functional.classification import multilabel_auroc |
|
>>> preds = torch.tensor([[0.75, 0.05, 0.35], |
|
... [0.45, 0.75, 0.05], |
|
... [0.05, 0.55, 0.75], |
|
... [0.05, 0.65, 0.05]]) |
|
>>> target = torch.tensor([[1, 0, 1], |
|
... [0, 0, 0], |
|
... [0, 1, 1], |
|
... [1, 1, 1]]) |
|
>>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=None) |
|
tensor(0.6528) |
|
>>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=None) |
|
tensor([0.6250, 0.5000, 0.8333]) |
|
>>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=5) |
|
tensor(0.6528) |
|
>>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=5) |
|
tensor([0.6250, 0.5000, 0.8333]) |
|
|
|
""" |
|
if validate_args: |
|
_multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index) |
|
_multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) |
|
preds, target, thresholds = _multilabel_precision_recall_curve_format( |
|
preds, target, num_labels, thresholds, ignore_index |
|
) |
|
state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) |
|
return _multilabel_auroc_compute(state, num_labels, average, thresholds, ignore_index) |
|
|
|
|
|
def auroc( |
|
preds: Tensor, |
|
target: Tensor, |
|
task: Literal["binary", "multiclass", "multilabel"], |
|
thresholds: Optional[Union[int, list[float], Tensor]] = None, |
|
num_classes: Optional[int] = None, |
|
num_labels: Optional[int] = None, |
|
    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
|
max_fpr: Optional[float] = None, |
|
ignore_index: Optional[int] = None, |
|
validate_args: bool = True, |
|
) -> Optional[Tensor]: |
|
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). |
|
|
|
    The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for
|
multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 |
|
corresponds to random guessing. |
|
|
|
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the |
|
    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
|
:func:`~torchmetrics.functional.classification.binary_auroc`, |
|
:func:`~torchmetrics.functional.classification.multiclass_auroc` and |
|
:func:`~torchmetrics.functional.classification.multilabel_auroc` for the specific details of |
|
    how each argument influences the metric and for examples.
|
|
|
Legacy Example: |
|
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) |
|
>>> target = torch.tensor([0, 0, 1, 1, 1]) |
|
>>> auroc(preds, target, task='binary') |
|
tensor(0.5000) |
|
|
|
>>> preds = torch.tensor([[0.90, 0.05, 0.05], |
|
... [0.05, 0.90, 0.05], |
|
... [0.05, 0.05, 0.90], |
|
... [0.85, 0.05, 0.10], |
|
... [0.10, 0.10, 0.80]]) |
|
>>> target = torch.tensor([0, 1, 1, 2, 2]) |
|
>>> auroc(preds, target, task='multiclass', num_classes=3) |
|
tensor(0.7778) |
|
|
|
""" |
|
task = ClassificationTask.from_str(task) |
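    # dispatch to the task-specific implementation; each validates its own arguments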
|
if task == ClassificationTask.BINARY: |
|
return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) |
|
if task == ClassificationTask.MULTICLASS: |
|
if not isinstance(num_classes, int): |
|
            raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
|
return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) |
|
if task == ClassificationTask.MULTILABEL: |
|
if not isinstance(num_labels, int): |
|
            raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
|
return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) |
|
return None |
|
|