File size: 5,745 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from torch import Tensor
from typing_extensions import Literal
from torchmetrics.utilities.prints import rank_zero_warn
def _nominal_input_validation(nan_strategy: str, nan_replace_value: Optional[float]) -> None:
    """Validate the ``nan_strategy`` / ``nan_replace_value`` argument pair.

    Args:
        nan_strategy: Name of the strategy; must be ``'replace'`` or ``'drop'``.
        nan_replace_value: Replacement value; must be an ``int`` or ``float`` when
            ``nan_strategy = 'replace'``. It is not inspected for ``'drop'``.

    Raises:
        ValueError: If ``nan_strategy`` is not one of ``'replace'`` or ``'drop'``.
        ValueError: If ``nan_strategy = 'replace'`` and ``nan_replace_value`` is not an ``int`` or ``float``.
    """
    if nan_strategy not in ("replace", "drop"):
        raise ValueError(
            f"Argument `nan_strategy` is expected to be one of `['replace', 'drop']`, but got {nan_strategy}"
        )
    if nan_strategy == "replace" and not isinstance(nan_replace_value, (float, int)):
        # Name the actual parameter (`nan_replace_value`) so the message points at the argument to fix.
        raise ValueError(
            "Argument `nan_replace_value` is expected to be of a type `int` or `float` "
            "when `nan_strategy = 'replace'`, "
            f"but got {nan_replace_value}"
        )
def _compute_expected_freqs(confmat: Tensor) -> Tensor:
    """Compute the expected frequencies from the provided confusion matrix.

    Each entry is the product of the corresponding row total and column total, divided by the sum of all
    entries of ``confmat``.
    """
    row_totals = confmat.sum(1)
    col_totals = confmat.sum(0)
    grand_total = confmat.sum()
    # Outer product of the two marginal vectors
    margins_outer = torch.einsum("r, c -> rc", row_totals, col_totals)
    return margins_outer / grand_total
def _compute_chi_squared(confmat: Tensor, bias_correction: bool) -> Tensor:
    """Chi-square test of independence of variables in a confusion matrix table.

    Adapted from: https://github.com/scipy/scipy/blob/v1.9.2/scipy/stats/contingency.py.

    Args:
        confmat: Confusion (contingency) matrix holding the observed frequencies. It is not modified.
        bias_correction: If ``True`` and the table has exactly one degree of freedom, every observed value
            is moved towards its expected value before the statistic is computed (the continuity
            correction applied by the referenced SciPy implementation).

    Returns:
        Scalar tensor with the chi-squared statistic, or ``0.0`` when the table has zero degrees of freedom.
    """
    expected_freqs = _compute_expected_freqs(confmat)
    # Get degrees of freedom
    df = expected_freqs.numel() - sum(expected_freqs.shape) + expected_freqs.ndim - 1
    if df == 0:
        return torch.tensor(0.0, device=confmat.device)

    observed = confmat
    if df == 1 and bias_correction:
        diff = expected_freqs - confmat
        # Shift by at most 0.5, but never by more than the difference itself, so that a corrected value
        # cannot overshoot its expected frequency.
        magnitude = diff.abs().clamp(max=0.5)
        # Out-of-place on purpose: the caller's tensor stays untouched and integer-valued matrices are
        # promoted to floating point instead of failing on an in-place float update.
        observed = confmat + diff.sign() * magnitude

    return torch.sum((observed - expected_freqs) ** 2 / expected_freqs)
def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor:
    """Drop all rows and columns containing only zeros.

    Rows whose sum is zero are removed first; columns whose sum is zero in the remaining matrix are removed
    afterwards.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.functional.nominal.utils import _drop_empty_rows_and_cols
        >>> matrix = tensor([[2, 0, 6], [0, 0, 0], [0, 0, 0], [3, 0, 4]])
        >>> _drop_empty_rows_and_cols(matrix)
        tensor([[2, 6],
                [3, 4]])

    """
    non_empty_rows = confmat.sum(1) != 0
    without_empty_rows = confmat[non_empty_rows]
    # The column mask is derived from the already row-filtered matrix
    non_empty_cols = without_empty_rows.sum(0) != 0
    return without_empty_rows[:, non_empty_cols]
def _compute_phi_squared_corrected(
    phi_squared: Tensor,
    num_rows: int,
    num_cols: int,
    confmat_sum: Tensor,
) -> Tensor:
    """Compute bias-corrected Phi Squared.

    The term ``(num_rows - 1) * (num_cols - 1) / (confmat_sum - 1)`` is subtracted from ``phi_squared`` and
    the result is floored at zero.
    """
    bias_term = ((num_rows - 1) * (num_cols - 1)) / (confmat_sum - 1)
    lower_bound = torch.tensor(0.0, device=phi_squared.device)
    return torch.max(lower_bound, phi_squared - bias_term)
def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: Tensor) -> tuple[Tensor, Tensor]:
    """Compute bias-corrected number of rows and columns.

    Each count ``k`` is replaced by ``k - (k - 1) ** 2 / (confmat_sum - 1)``.

    Returns:
        Tuple ``(rows_corrected, cols_corrected)``.
    """
    denominator = confmat_sum - 1

    def _shrink(count: int) -> Tensor:
        # Same correction formula for both dimensions
        return count - (count - 1) ** 2 / denominator

    return _shrink(num_rows), _shrink(num_cols)
def _compute_bias_corrected_values(
    phi_squared: Tensor, num_rows: int, num_cols: int, confmat_sum: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
    """Compute bias-corrected Phi Squared and number of rows and columns.

    Returns:
        Tuple ``(phi_squared_corrected, rows_corrected, cols_corrected)``.
    """
    corrected_phi = _compute_phi_squared_corrected(phi_squared, num_rows, num_cols, confmat_sum)
    corrected_dims = _compute_rows_and_cols_corrected(num_rows, num_cols, confmat_sum)
    return (corrected_phi, *corrected_dims)
def _handle_nan_in_data(
    preds: Tensor,
    target: Tensor,
    nan_strategy: Literal["replace", "drop"] = "replace",
    nan_replace_value: Optional[float] = 0.0,
) -> tuple[Tensor, Tensor]:
    """Handle ``NaN`` values in input data.

    If ``nan_strategy = 'replace'``, all ``NaN`` values are replaced with ``nan_replace_value``.
    Otherwise, all positions where either of the two vectors holds a ``NaN`` are dropped from both.

    This function performs no argument validation itself: any ``nan_strategy`` other than ``'replace'`` is
    handled as ``'drop'``.

    Args:
        preds: 1D tensor of categorical (nominal) data
        target: 1D tensor of categorical (nominal) data
        nan_strategy: Indication of whether to replace or drop ``NaN`` values
        nan_replace_value: Value to replace ``NaN`` with when ``nan_strategy = 'replace'``

    Returns:
        Updated ``preds`` and ``target`` tensors which contain no ``NaN``

    """
    if nan_strategy != "replace":
        # Keep only the positions that are free of NaN in both tensors
        keep = ~(preds.isnan() | target.isnan())
        return preds[keep], target[keep]

    # NOTE(review): per the PyTorch docs, ``nan_to_num`` also maps +/-inf to finite values - confirm this is intended.
    filled_preds = preds.nan_to_num(nan_replace_value)
    filled_target = target.nan_to_num(nan_replace_value)
    return filled_preds, filled_target
def _unable_to_use_bias_correction_warning(metric_name: str) -> None:
    """Warn, via ``rank_zero_warn``, that ``metric_name`` could not be computed with bias correction.

    Args:
        metric_name: Name of the metric, interpolated into the warning message.
    """
    message = f"Unable to compute {metric_name} using bias correction. Please consider to set `bias_correction=False`."
    rank_zero_warn(message)
|