|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, List, Optional |
|
|
|
import torch |
|
from torch import Tensor |
|
|
|
from torchmetrics.functional.regression.csi import _critical_success_index_compute, _critical_success_index_update |
|
from torchmetrics.metric import Metric |
|
from torchmetrics.utilities import dim_zero_cat |
|
|
|
|
|
class CriticalSuccessIndex(Metric):
    r"""Calculate critical success index (CSI).

    Critical success index (also known as the threat score) is a statistic used in weather forecasting that measures
    forecast performance over inputs binarized at a specified threshold. It is defined as:

    .. math:: \text{CSI} = \frac{\text{TP}}{\text{TP}+\text{FN}+\text{FP}}

    Where :math:`\text{TP}`, :math:`\text{FN}` and :math:`\text{FP}` represent the number of true positives, false
    negatives and false positives respectively after binarizing the input tensors.

    Args:
        threshold: Values above or equal to threshold are replaced with 1, below by 0
        keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
            the score will be calculated separately for each image in the sequence. If ``None``, the score will be
            calculated across all dimensions.
        kwargs: Additional keyword arguments, forwarded unchanged to the ``Metric`` base class constructor.

    Raises:
        ValueError:
            If ``keep_sequence_dim`` is not ``None`` and is not a non-negative integer.

    Example:
        >>> import torch
        >>> from torchmetrics.regression import CriticalSuccessIndex
        >>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
        >>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
        >>> csi = CriticalSuccessIndex(0.5)
        >>> csi(x, y)
        tensor(0.3333)

    Example:
        >>> import torch
        >>> from torchmetrics.regression import CriticalSuccessIndex
        >>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
        >>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
        >>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0)
        >>> csi(x, y)
        tensor([0.3333, 0.3333])

    """

    is_differentiable: bool = False
    higher_is_better: bool = True

    # States registered when ``keep_sequence_dim`` is ``None``: running totals.
    hits: Tensor
    misses: Tensor
    false_alarms: Tensor
    # States registered when ``keep_sequence_dim`` is set: one entry appended per ``update`` call.
    hits_list: List[Tensor]
    misses_list: List[Tensor]
    false_alarms_list: List[Tensor]

    def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.threshold = float(threshold)

        # Compare against ``None`` explicitly: a plain truthiness test would let invalid falsy values
        # (e.g. ``0.0`` or ``""``) skip validation and only fail later inside ``update``.
        if keep_sequence_dim is not None and (not isinstance(keep_sequence_dim, int) or keep_sequence_dim < 0):
            raise ValueError(f"Expected keep_sequence_dim to be a non-negative integer but got {keep_sequence_dim}")
        self.keep_sequence_dim = keep_sequence_dim

        if keep_sequence_dim is None:
            # Counts over all dimensions are accumulated in place and summed across processes.
            self.add_state("hits", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("misses", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("false_alarms", default=torch.tensor(0), dist_reduce_fx="sum")
        else:
            # Counts are kept per update and concatenated at compute time.
            self.add_state("hits_list", default=[], dist_reduce_fx="cat")
            self.add_state("misses_list", default=[], dist_reduce_fx="cat")
            self.add_state("false_alarms_list", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Args:
            preds: Predictions, binarized at ``threshold`` by the functional update helper.
            target: Ground-truth values, binarized at ``threshold`` by the functional update helper.

        """
        hits, misses, false_alarms = _critical_success_index_update(
            preds, target, self.threshold, self.keep_sequence_dim
        )
        if self.keep_sequence_dim is None:
            self.hits += hits
            self.misses += misses
            self.false_alarms += false_alarms
        else:
            self.hits_list.append(hits)
            self.misses_list.append(misses)
            self.false_alarms_list.append(false_alarms)

    def compute(self) -> Tensor:
        """Compute critical success index over state.

        Returns:
            The result of ``_critical_success_index_compute`` applied to the accumulated hits, misses and
            false alarms (list states are concatenated along dimension 0 first).

        """
        if self.keep_sequence_dim is None:
            hits = self.hits
            misses = self.misses
            false_alarms = self.false_alarms
        else:
            hits = dim_zero_cat(self.hits_list)
            misses = dim_zero_cat(self.misses_list)
            false_alarms = dim_zero_cat(self.false_alarms_list)
        return _critical_success_index_compute(hits, misses, false_alarms)
|
|