File size: 4,792 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional

import torch
from torch import Tensor

from torchmetrics.functional.regression.csi import _critical_success_index_compute, _critical_success_index_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import dim_zero_cat


class CriticalSuccessIndex(Metric):
    r"""Calculate critical success index (CSI).

    Critical success index (also known as the threat score) is a statistic used in weather forecasting that measures
    forecast performance over inputs binarized at a specified threshold. It is defined as:

    .. math:: \text{CSI} = \frac{\text{TP}}{\text{TP}+\text{FN}+\text{FP}}

    Where :math:`\text{TP}`, :math:`\text{FN}` and :math:`\text{FP}` represent the number of true positives, false
    negatives and false positives respectively after binarizing the input tensors.

    Args:
        threshold: Values above or equal to threshold are replaced with 1, below by 0
        keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
            the score will be calculated separately for each image in the sequence. If ``None``, the score will be
            calculated across all dimensions.
        kwargs: Additional keyword arguments, forwarded unchanged to the ``Metric`` base class.

    Raises:
        ValueError:
            If ``keep_sequence_dim`` is not ``None`` and is not a non-negative integer.

    Example:
        >>> import torch
        >>> from torchmetrics.regression import CriticalSuccessIndex
        >>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
        >>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
        >>> csi = CriticalSuccessIndex(0.5)
        >>> csi(x, y)
        tensor(0.3333)

    Example:
        >>> import torch
        >>> from torchmetrics.regression import CriticalSuccessIndex
        >>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
        >>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
        >>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0)
        >>> csi(x, y)
        tensor([0.3333, 0.3333])

    """

    is_differentiable: bool = False
    higher_is_better: bool = True

    # Scalar running counts; only registered when ``keep_sequence_dim`` is None.
    hits: Tensor
    misses: Tensor
    false_alarms: Tensor
    # Per-update counts; only registered when ``keep_sequence_dim`` is set.
    hits_list: List[Tensor]
    misses_list: List[Tensor]
    false_alarms_list: List[Tensor]

    def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.threshold = float(threshold)

        # Compare against None explicitly: a truthiness test would let falsy non-integer values
        # (e.g. ``0.0`` or ``""``) bypass validation and silently switch the metric to list states.
        if keep_sequence_dim is not None and (not isinstance(keep_sequence_dim, int) or keep_sequence_dim < 0):
            raise ValueError(f"Expected keep_sequence_dim to be a non-negative integer but got {keep_sequence_dim}")
        self.keep_sequence_dim = keep_sequence_dim

        if keep_sequence_dim is None:
            # Global mode: counts are accumulated in place and summed across processes.
            self.add_state("hits", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("misses", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("false_alarms", default=torch.tensor(0), dist_reduce_fx="sum")
        else:
            # Per-sequence mode: counts from each update are stored separately and concatenated.
            self.add_state("hits_list", default=[], dist_reduce_fx="cat")
            self.add_state("misses_list", default=[], dist_reduce_fx="cat")
            self.add_state("false_alarms_list", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Args:
            preds: Predicted values; binarized with ``self.threshold`` by the functional helper.
            target: Ground-truth values; binarized the same way as ``preds``.

        """
        hits, misses, false_alarms = _critical_success_index_update(
            preds, target, self.threshold, self.keep_sequence_dim
        )
        if self.keep_sequence_dim is None:
            self.hits += hits
            self.misses += misses
            self.false_alarms += false_alarms
        else:
            # Counts are appended, not summed, so every ``update`` call contributes its own entries.
            self.hits_list.append(hits)
            self.misses_list.append(misses)
            self.false_alarms_list.append(false_alarms)

    def compute(self) -> Tensor:
        """Compute critical success index over state.

        Returns:
            The result of ``_critical_success_index_compute`` on the accumulated counts. When
            ``keep_sequence_dim`` is set, the stored per-update counts are first concatenated along
            dimension 0.

        """
        if self.keep_sequence_dim is None:
            hits = self.hits
            misses = self.misses
            false_alarms = self.false_alarms
        else:
            hits = dim_zero_cat(self.hits_list)
            misses = dim_zero_cat(self.misses_list)
            false_alarms = dim_zero_cat(self.false_alarms_list)
        return _critical_success_index_compute(hits, misses, false_alarms)