# jamtur01's picture
# Upload folder using huggingface_hub
# 9c6594c verified
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from torch import Tensor
from typing_extensions import Literal
from torchmetrics.utilities.prints import rank_zero_warn
def _nominal_input_validation(nan_strategy: str, nan_replace_value: Optional[float]) -> None:
if nan_strategy not in ["replace", "drop"]:
raise ValueError(
f"Argument `nan_strategy` is expected to be one of `['replace', 'drop']`, but got {nan_strategy}"
)
if nan_strategy == "replace" and not isinstance(nan_replace_value, (float, int)):
raise ValueError(
"Argument `nan_replace` is expected to be of a type `int` or `float` when `nan_strategy = 'replace`, "
f"but got {nan_replace_value}"
)
def _compute_expected_freqs(confmat: Tensor) -> Tensor:
"""Compute the expected frequenceis from the provided confusion matrix."""
margin_sum_rows, margin_sum_cols = confmat.sum(1), confmat.sum(0)
return torch.einsum("r, c -> rc", margin_sum_rows, margin_sum_cols) / confmat.sum()
def _compute_chi_squared(confmat: Tensor, bias_correction: bool) -> Tensor:
    """Chi-square test of independence of variables in a confusion matrix table.

    Adapted from: https://github.com/scipy/scipy/blob/v1.9.2/scipy/stats/contingency.py.

    Args:
        confmat: Confusion matrix (contingency table) of observed frequencies.
        bias_correction: If ``True``, apply Yates' correction for continuity when the
            table has exactly one degree of freedom.

    Returns:
        Scalar tensor with the chi-squared statistic; ``0.0`` when there are no degrees of freedom.

    """
    expected_freqs = _compute_expected_freqs(confmat)
    # Degrees of freedom for an n-dimensional table, as in scipy's implementation.
    df = expected_freqs.numel() - sum(expected_freqs.shape) + expected_freqs.ndim - 1
    if df == 0:
        return torch.tensor(0.0, device=confmat.device)

    if df == 1 and bias_correction:
        diff = expected_freqs - confmat
        direction = diff.sign()
        # Yates' correction: move each observed count up to 0.5 towards its expected value.
        # Out-of-place addition so the caller's `confmat` is NOT mutated (the original
        # `confmat += ...` modified the input tensor in place as a side effect).
        confmat = confmat + direction * torch.minimum(0.5 * torch.ones_like(direction), direction.abs())

    return torch.sum((confmat - expected_freqs) ** 2 / expected_freqs)
def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor:
"""Drop all rows and columns containing only zeros.
Example:
>>> from torch import randint
>>> from torchmetrics.functional.nominal.utils import _drop_empty_rows_and_cols
>>> matrix = randint(10, size=(4, 3))
>>> matrix[1, :] = matrix[:, 1] = 0
>>> matrix
tensor([[2, 0, 6],
[0, 0, 0],
[0, 0, 0],
[3, 0, 4]])
>>> _drop_empty_rows_and_cols(matrix)
tensor([[2, 6],
[3, 4]])
"""
confmat = confmat[confmat.sum(1) != 0]
return confmat[:, confmat.sum(0) != 0]
def _compute_phi_squared_corrected(
phi_squared: Tensor,
num_rows: int,
num_cols: int,
confmat_sum: Tensor,
) -> Tensor:
"""Compute bias-corrected Phi Squared."""
return torch.max(
torch.tensor(0.0, device=phi_squared.device),
phi_squared - ((num_rows - 1) * (num_cols - 1)) / (confmat_sum - 1),
)
def _compute_rows_and_cols_corrected(num_rows: int, num_cols: int, confmat_sum: Tensor) -> tuple[Tensor, Tensor]:
"""Compute bias-corrected number of rows and columns."""
rows_corrected = num_rows - (num_rows - 1) ** 2 / (confmat_sum - 1)
cols_corrected = num_cols - (num_cols - 1) ** 2 / (confmat_sum - 1)
return rows_corrected, cols_corrected
def _compute_bias_corrected_values(
    phi_squared: Tensor, num_rows: int, num_cols: int, confmat_sum: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
    """Compute bias-corrected Phi Squared and number of rows and columns."""
    corrected_rows, corrected_cols = _compute_rows_and_cols_corrected(num_rows, num_cols, confmat_sum)
    corrected_phi_squared = _compute_phi_squared_corrected(phi_squared, num_rows, num_cols, confmat_sum)
    return corrected_phi_squared, corrected_rows, corrected_cols
def _handle_nan_in_data(
preds: Tensor,
target: Tensor,
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[float] = 0.0,
) -> tuple[Tensor, Tensor]:
"""Handle ``NaN`` values in input data.
If ``nan_strategy = 'replace'``, all ``NaN`` values are replaced with ``nan_replace_value``.
If ``nan_strategy = 'drop'``, all rows containing ``NaN`` in any of two vectors are dropped.
Args:
preds: 1D tensor of categorical (nominal) data
target: 1D tensor of categorical (nominal) data
nan_strategy: Indication of whether to replace or drop ``NaN`` values
nan_replace_value: Value to replace ``NaN`s when ``nan_strategy = 'replace```
Returns:
Updated ``preds`` and ``target`` tensors which contain no ``Nan``
Raises:
ValueError: If ``nan_strategy`` is not from ``['replace', 'drop']``.
ValueError: If ``nan_strategy = replace`` and ``nan_replace_value`` is not of a type ``int`` or ``float``.
"""
if nan_strategy == "replace":
return preds.nan_to_num(nan_replace_value), target.nan_to_num(nan_replace_value)
rows_contain_nan = torch.logical_or(preds.isnan(), target.isnan())
return preds[~rows_contain_nan], target[~rows_contain_nan]
def _unable_to_use_bias_correction_warning(metric_name: str) -> None:
    """Emit a rank-zero warning that bias correction cannot be applied for the given metric."""
    message = f"Unable to compute {metric_name} using bias correction. Please consider to set `bias_correction=False`."
    rank_zero_warn(message)