File size: 17,861 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide
from torchmetrics.utilities.data import interp
from torchmetrics.utilities.enums import ClassificationTask


def _validate_fpr_range(fpr_range: Tuple[float, float]) -> None:
    """Validate the `fpr_range` argument for the logauc metric.

    Args:
        fpr_range: lower and upper bound of the false positive rate range. A 2-element ``tuple`` is expected; a
            2-element ``list`` is also accepted.

    Raises:
        ValueError:
            If ``fpr_range`` is not a 2-element tuple/list, or if its bounds do not satisfy
            ``0 < lower < upper <= 1``. The lower bound must be strictly positive because the metric integrates over
            ``log10(fpr)`` and ``log10(0)`` is ``-inf`` (which would turn the score into ``nan``).

    """
    # ``or`` (not ``and``): reject anything that is not a tuple/list, and also any tuple/list whose length is not 2.
    # Thanks to short-circuiting, ``len`` is only evaluated on tuples/lists.
    if not isinstance(fpr_range, (tuple, list)) or len(fpr_range) != 2:
        raise ValueError(f"The `fpr_range` should be a tuple of two floats, but got {type(fpr_range)}.")
    if not (0 < fpr_range[0] < fpr_range[1] <= 1):
        raise ValueError(f"The `fpr_range` should be a tuple of two floats in the range (0, 1], but got {fpr_range}.")


def _binary_logauc_compute(
    fpr: Tensor,
    tpr: Tensor,
    fpr_range: Tuple[float, float] = (0.001, 0.1),
) -> Tensor:
    """Compute the logauc score for binary classification tasks.

    The two bounds of ``fpr_range`` are inserted into the ROC curve (their TPR values obtained via ``interp``), the
    curve is restricted to that range, and the area under it is computed with ``log10(fpr)`` on the x-axis. The area
    is finally divided by the width of the range in log space.

    Args:
        fpr: false positive rates of the ROC curve (presumably non-decreasing, as returned by the ROC functions).
        tpr: true positive rates paired with ``fpr``. Both are sorted independently below, which keeps the pairing
            only under the assumption that the curve is monotone in both coordinates (true for a ROC curve).
        fpr_range: lower and upper bound of the false positive rate range to integrate over.

    Returns:
        Scalar tensor with the log AUC score. If ``fpr`` or ``tpr`` has fewer than two elements, a warning is emitted
        and ``0`` is returned.

    """
    if fpr.numel() < 2 or tpr.numel() < 2:
        rank_zero_warn(
            "At least two values on for the fpr and tpr are required to compute the log AUC. Returns 0 score."
        )
        return torch.tensor(0.0, device=fpr.device)

    # Build the bounds once, on ``fpr``'s device and in the dtype that ``torch.cat`` below promotes ``fpr`` to. The
    # bound values are then copied bit-for-bit into the concatenated curve, so the exact-equality lookup further down
    # always finds them (also for float64 curves). Constructing directly from the tuple avoids re-wrapping a tensor
    # with ``torch.tensor``, which emits a copy-construct warning.
    dtype = torch.promote_types(fpr.dtype, torch.get_default_dtype())
    bounds = torch.tensor(fpr_range, dtype=dtype, device=fpr.device)

    # Insert the bounds into the curve; sorting both coordinates independently relies on monotonicity (see above).
    tpr = torch.cat([tpr, interp(bounds, fpr, tpr)]).sort().values
    fpr = torch.cat([fpr, bounds]).sort().values

    # Locate the inserted bounds by exact comparison on the linear values (copied verbatim into ``fpr``) instead of
    # their logarithms, so the lookup does not depend on ``log10`` being bit-identical across kernels. ``[-1]`` takes
    # the last match when a bound coincides with existing curve points.
    lower_bound_idx = torch.where(fpr == bounds[0])[0][-1]
    upper_bound_idx = torch.where(fpr == bounds[1])[0][-1]

    log_fpr = torch.log10(fpr)
    log_bounds = torch.log10(bounds)

    trimmed_log_fpr = log_fpr[lower_bound_idx : upper_bound_idx + 1]
    trimmed_tpr = tpr[lower_bound_idx : upper_bound_idx + 1]

    # compute area and rescale it to the range of fpr
    return _auc_compute_without_check(trimmed_log_fpr, trimmed_tpr, 1.0) / (log_bounds[1] - log_bounds[0])


def _reduce_logauc(
    fpr: Union[Tensor, List[Tensor]],
    tpr: Union[Tensor, List[Tensor]],
    fpr_range: Tuple[float, float] = (0.001, 0.1),
    average: Optional[Literal["macro", "weighted", "none"]] = "macro",
    weights: Optional[Tensor] = None,
) -> Tensor:
    """Reduce the logauc score to a single value for multiclass and multilabel classification tasks.

    Args:
        fpr: per-class/label false positive rates; either a list of 1d tensors or a tensor whose first dimension is
            iterated over (one entry per class/label).
        tpr: per-class/label true positive rates, paired element-wise with ``fpr``.
        fpr_range: lower and upper bound of the false positive rate range, forwarded to ``_binary_logauc_compute``.
        average: reduction over classes/labels. ``"none"``/``None`` returns the per-class/label scores unchanged
            (``nan`` included); ``"macro"`` averages the non-``nan`` scores; ``"weighted"`` averages the non-``nan``
            scores weighted by ``weights``.
        weights: per-class/label weights; required when ``average="weighted"``.

    Returns:
        A 1d tensor with one score per class/label for ``"none"``/``None``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``average="weighted"`` is requested without ``weights``, or if ``average`` is unknown.

    """
    scores = torch.stack([_binary_logauc_compute(fpr_i, tpr_i, fpr_range) for fpr_i, tpr_i in zip(fpr, tpr)])
    if average is None or average == "none":
        return scores

    # ``nan`` scores are excluded from every reduction below, so tell the user which reduction drops them.
    if torch.isnan(scores).any():
        rank_zero_warn(
            f"LogAUC score for one or more classes/labels was `nan`. Ignoring these classes in {average}-average."
        )
    idx = ~torch.isnan(scores)
    if average == "macro":
        return scores[idx].mean()
    if average == "weighted":
        if weights is None:
            raise ValueError("Argument `weights` must be provided when `average='weighted'`.")
        weights = _safe_divide(weights[idx], weights[idx].sum())
        return (scores[idx] * weights).sum()
    raise ValueError(f"Got unknown average parameter: {average}. Please choose one of ['macro', 'weighted', 'none'].")


def binary_logauc(
    preds: Tensor,
    target: Tensor,
    fpr_range: Tuple[float, float] = (0.001, 0.1),
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the `Log AUC`_ score for binary classification tasks.

    The ROC curve is computed first and interpolated at the bounds of the requested false positive rate (FPR) range.
    The area under the curve (AUC) is then computed with the logarithm of the FPR on the x-axis and normalized by the
    width of the range in log space. The score is commonly used in applications where the positive and negative
    classes are imbalanced and a low false positive rate is of high importance.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      sigmoid per element.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class.

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation supports both a non-binned, accurate version and a binned version that is less accurate but
    more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which uses
    memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting `thresholds` to an integer, a list or a 1d tensor
    activates the binned version, which uses memory of size :math:`\mathcal{O}(n_{thresholds})` (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with ground truth labels
        fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range over which the log
            AUC score is computed. It is always validated, independently of ``validate_args``.
        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        A single scalar with the log auc score

    Raises:
        ValueError: If ``fpr_range`` is rejected by ``_validate_fpr_range``.

    Example:
        >>> from torchmetrics.functional.classification import binary_logauc
        >>> from torch import tensor
        >>> preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05])
        >>> target = tensor([1, 0, 0, 0, 0])
        >>> binary_logauc(preds, target)
        tensor(1.)

    """
    _validate_fpr_range(fpr_range)
    roc_fpr, roc_tpr, _roc_thresholds = binary_roc(
        preds, target, thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args
    )
    return _binary_logauc_compute(roc_fpr, roc_tpr, fpr_range=fpr_range)


def multiclass_logauc(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    fpr_range: Tuple[float, float] = (0.001, 0.1),
    average: Optional[Literal["macro", "none"]] = "macro",
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the `Log AUC`_ score for multiclass classification tasks.

    For every class, the one-vs-rest ROC curve is computed first and interpolated at the bounds of the requested false
    positive rate (FPR) range. The area under the curve (AUC) is then computed with the logarithm of the FPR on the
    x-axis. The score is commonly used in applications where the positive and negative classes are imbalanced and a
    low false positive rate is of high importance.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      softmax per sample.
    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation supports both a non-binned, accurate version and a binned version that is less accurate but
    more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which uses
    memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting `thresholds` to an integer, a list or a 1d tensor
    activates the binned version, which uses memory of size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})`
    (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_classes: Integer specifying the number of classes
        fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range over which the log
            AUC score is computed. It is only validated when ``validate_args=True``.
        average:
            Defines the reduction that is applied over classes. Should be one of the following:

            - ``"macro"``: Calculate score for each class and average them (classes with a ``nan`` score are ignored)
            - ``"none"`` or ``None``: calculates score for each class and applies no reduction

        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        A scalar tensor for ``average="macro"``, otherwise a tensor with one score per class.

    Example:
        >>> from torchmetrics.functional.classification import multiclass_logauc
        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> multiclass_logauc(preds, target, num_classes=5, average="macro", thresholds=None)
        tensor(0.4000)
        >>> multiclass_logauc(preds, target, num_classes=5, average=None, thresholds=None)
        tensor([1., 1., 0., 0., 0.])

    """
    if validate_args:
        _validate_fpr_range(fpr_range)
    # Request the unreduced per-class curves; the reduction over classes happens in ``_reduce_logauc``.
    per_class_fpr, per_class_tpr, _roc_thresholds = multiclass_roc(
        preds,
        target,
        num_classes,
        thresholds=thresholds,
        average=None,
        ignore_index=ignore_index,
        validate_args=validate_args,
    )
    return _reduce_logauc(per_class_fpr, per_class_tpr, fpr_range=fpr_range, average=average)


def multilabel_logauc(
    preds: Tensor,
    target: Tensor,
    num_labels: int,
    fpr_range: Tuple[float, float] = (0.001, 0.1),
    average: Optional[Literal["macro", "none"]] = "macro",
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Compute the `Log AUC`_ score for multilabel classification tasks.

    For every label, the ROC curve is computed first and interpolated at the bounds of the requested false positive
    rate (FPR) range. The area under the curve (AUC) is then computed with the logarithm of the FPR on the x-axis. The
    score is commonly used in applications where the positive and negative classes are imbalanced and a low false
    positive rate is of high importance.

    Accepts the following input tensors:

    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
      sigmoid per element.
    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
      only contain {0,1} values (except if `ignore_index` is specified).

    Additional dimension ``...`` will be flattened into the batch dimension.

    The implementation supports both a non-binned, accurate version and a binned version that is less accurate but
    more memory efficient. Setting the `thresholds` argument to `None` activates the non-binned version, which uses
    memory of size :math:`\mathcal{O}(n_{samples})`, whereas setting `thresholds` to an integer, a list or a 1d tensor
    activates the binned version, which uses memory of size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})`
    (constant memory).

    Args:
        preds: Tensor with predictions
        target: Tensor with true labels
        num_labels: Integer specifying the number of labels
        fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range over which the log
            AUC score is computed. It is validated when ``validate_args=True``.
        average:
            Defines the reduction that is applied over labels. Should be one of the following:

            - ``"macro"``: Calculate score for each label and average them (labels with a ``nan`` score are ignored)
            - ``"none"`` or ``None``: calculates score for each label and applies no reduction

        thresholds:
            Can be one of:

            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
              all the data. Most accurate but also most memory consuming approach.
            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
              0 to 1 as bins for the calculation.
            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
              bins for the calculation.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        A scalar tensor for ``average="macro"``, otherwise a tensor with one score per label.

    Raises:
        ValueError: If ``validate_args=True`` and ``fpr_range`` is rejected by ``_validate_fpr_range``.

    Example:
        >>> from torchmetrics.functional.classification import multilabel_logauc
        >>> preds = torch.tensor([[0.75, 0.05, 0.35],
        ...                       [0.45, 0.75, 0.05],
        ...                       [0.05, 0.55, 0.75],
        ...                       [0.05, 0.65, 0.05]])
        >>> target = torch.tensor([[1, 0, 1],
        ...                        [0, 0, 0],
        ...                        [0, 1, 1],
        ...                        [1, 1, 1]])
        >>> multilabel_logauc(preds, target, num_labels=3, average="macro", thresholds=None)
        tensor(0.3945)
        >>> multilabel_logauc(preds, target, num_labels=3, average=None, thresholds=None)
        tensor([0.5000, 0.0000, 0.6835])

    """
    # Same gating as in ``multiclass_logauc``: an invalid range would otherwise only fail (or yield ``nan``) deep
    # inside the per-label computation.
    if validate_args:
        _validate_fpr_range(fpr_range)
    fpr, tpr, _ = multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args)
    return _reduce_logauc(fpr, tpr, fpr_range, average=average)


def logauc(
    preds: Tensor,
    target: Tensor,
    task: Literal["binary", "multiclass", "multilabel"],
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    fpr_range: Tuple[float, float] = (0.001, 0.1),
    average: Optional[Literal["macro", "none"]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Optional[Tensor]:
    r"""Compute the `Log AUC`_ score for classification tasks.

    The ROC curve is computed first and interpolated at the bounds of the requested false positive rate (FPR) range.
    The area under the curve (AUC) is then computed with the logarithm of the FPR on the x-axis. The score is commonly
    used in applications where the positive and negative classes are imbalanced and a low false positive rate is of
    high importance.

    This is a dispatcher: depending on ``task`` it forwards to ``binary_logauc``, ``multiclass_logauc`` or
    ``multilabel_logauc``; see those functions for details on the expected inputs.

    Args:
        preds: Tensor with predictions
        target: Tensor with ground truth labels
        task: One of ``"binary"``, ``"multiclass"`` or ``"multilabel"``.
        thresholds: Thresholds for the ROC computation; ``None`` selects the non-binned version, an int, list or 1d
            tensor selects the binned version.
        num_classes: Number of classes; must be an ``int`` when ``task="multiclass"``.
        num_labels: Number of labels; must be an ``int`` when ``task="multilabel"``.
        fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range over which the log
            AUC score is computed.
        average: Reduction over classes/labels (ignored for ``task="binary"``). Note that the default here is ``None``,
            so per-class/label scores are returned unless ``"macro"`` is requested explicitly.
        ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        The log AUC score(s) computed by the task-specific function, or ``None`` if the parsed task matches none of
        the three supported tasks.

    Raises:
        ValueError: If ``num_classes`` (multiclass) or ``num_labels`` (multilabel) is not an ``int``.

    """
    task_enum = ClassificationTask.from_str(task)
    if task_enum == ClassificationTask.BINARY:
        return binary_logauc(
            preds,
            target,
            fpr_range=fpr_range,
            thresholds=thresholds,
            ignore_index=ignore_index,
            validate_args=validate_args,
        )
    if task_enum == ClassificationTask.MULTICLASS:
        if not isinstance(num_classes, int):
            raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
        return multiclass_logauc(
            preds,
            target,
            num_classes,
            fpr_range=fpr_range,
            average=average,
            thresholds=thresholds,
            ignore_index=ignore_index,
            validate_args=validate_args,
        )
    if task_enum == ClassificationTask.MULTILABEL:
        if not isinstance(num_labels, int):
            raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
        return multilabel_logauc(
            preds,
            target,
            num_labels,
            fpr_range=fpr_range,
            average=average,
            thresholds=thresholds,
            ignore_index=ignore_index,
            validate_args=validate_args,
        )
    return None