# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from torch import Tensor
from typing_extensions import Literal
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.data import _bincount, _cumsum, dim_zero_cat
from torchmetrics.utilities.enums import EnumStr
class _MetricVariant(EnumStr):
"""Enumerate for metric variants."""
A = "a"
B = "b"
C = "c"
@staticmethod
def _name() -> str:
return "variant"
class _TestAlternative(EnumStr):
"""Enumerate for test alternative options."""
TWO_SIDED = "two-sided"
LESS = "less"
GREATER = "greater"
@staticmethod
def _name() -> str:
return "alternative"
def _sort_on_first_sequence(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
"""Sort sequences in an ascent order according to the sequence ``x``."""
    # Clone ``y`` so that the caller's tensor is not modified in place
y = torch.clone(y)
x, y = x.T, y.T
x, perm = x.sort()
for i in range(x.shape[0]):
y[i] = y[i][perm[i]]
return x.T, y.T
def _concordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor:
"""Count a total number of concordant pairs in a single sequence."""
return torch.logical_and(x[i] < x[(i + 1) :], y[i] < y[(i + 1) :]).sum(0).unsqueeze(0)
def _count_concordant_pairs(preds: Tensor, target: Tensor) -> Tensor:
"""Count a total number of concordant pairs in given sequences."""
return torch.cat([_concordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0)
def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor:
"""Count a total number of discordant pairs in a single sequences."""
return (
torch.logical_or(
torch.logical_and(x[i] > x[(i + 1) :], y[i] < y[(i + 1) :]),
torch.logical_and(x[i] < x[(i + 1) :], y[i] > y[(i + 1) :]),
)
.sum(0)
.unsqueeze(0)
)
def _count_discordant_pairs(preds: Tensor, target: Tensor) -> Tensor:
"""Count a total number of discordant pairs in given sequences."""
return torch.cat([_discordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0)
def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor:
"""Convert a sequence to the rank tensor."""
# Sort if a sequence has not been sorted before
if sort:
x = x.sort(dim=0).values
    _zeros = torch.zeros(1, x.shape[1], dtype=torch.int32, device=x.device)
    return _cumsum(torch.cat([_zeros, (x[1:] != x[:-1]).int()], dim=0), dim=0)
def _get_ties(x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
"""Get a total number of ties and staistics for p-value calculation for a given sequence."""
ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device)
ties_p1 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device)
ties_p2 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device)
for dim in range(x.shape[1]):
n_ties = _bincount(x[:, dim])
n_ties = n_ties[n_ties > 1]
ties[dim] = (n_ties * (n_ties - 1) // 2).sum()
ties_p1[dim] = (n_ties * (n_ties - 1.0) * (n_ties - 2)).sum()
ties_p2[dim] = (n_ties * (n_ties - 1.0) * (2 * n_ties + 5)).sum()
return ties, ties_p1, ties_p2
def _get_metric_metadata(
preds: Tensor, target: Tensor, variant: _MetricVariant
) -> tuple[
Tensor,
Tensor,
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Optional[Tensor],
Tensor,
]:
"""Obtain statistics to calculate metric value."""
preds, target = _sort_on_first_sequence(preds, target)
concordant_pairs = _count_concordant_pairs(preds, target)
discordant_pairs = _count_discordant_pairs(preds, target)
n_total = torch.tensor(preds.shape[0], device=preds.device)
preds_ties = target_ties = None
preds_ties_p1 = preds_ties_p2 = target_ties_p1 = target_ties_p2 = None
if variant != _MetricVariant.A:
preds = _convert_sequence_to_dense_rank(preds)
target = _convert_sequence_to_dense_rank(target, sort=True)
preds_ties, preds_ties_p1, preds_ties_p2 = _get_ties(preds)
target_ties, target_ties_p1, target_ties_p2 = _get_ties(target)
return (
concordant_pairs,
discordant_pairs,
preds_ties,
preds_ties_p1,
preds_ties_p2,
target_ties,
target_ties_p1,
target_ties_p2,
n_total,
)
def _calculate_tau(
preds: Tensor,
target: Tensor,
concordant_pairs: Tensor,
discordant_pairs: Tensor,
con_min_dis_pairs: Tensor,
n_total: Tensor,
preds_ties: Optional[Tensor],
target_ties: Optional[Tensor],
variant: _MetricVariant,
) -> Tensor:
"""Calculate Kendall's tau from metric metadata."""
if variant == _MetricVariant.A:
return con_min_dis_pairs / (concordant_pairs + discordant_pairs)
if variant == _MetricVariant.B:
total_combinations: Tensor = n_total * (n_total - 1) // 2
if preds_ties is None:
preds_ties = torch.tensor(0.0, dtype=total_combinations.dtype, device=total_combinations.device)
if target_ties is None:
target_ties = torch.tensor(0.0, dtype=total_combinations.dtype, device=total_combinations.device)
denominator = (total_combinations - preds_ties) * (total_combinations - target_ties)
return con_min_dis_pairs / torch.sqrt(denominator)
preds_unique = torch.tensor([len(p.unique()) for p in preds.T], dtype=preds.dtype, device=preds.device)
target_unique = torch.tensor([len(t.unique()) for t in target.T], dtype=target.dtype, device=target.device)
min_classes = torch.minimum(preds_unique, target_unique)
return 2 * con_min_dis_pairs / ((min_classes - 1) / min_classes * n_total**2)
def _get_p_value_for_t_value_from_dist(t_value: Tensor) -> Tensor:
"""Obtain p-value for a given Tensor of t-values. Handle ``nan`` which cannot be passed into torch distributions.
When t-value is ``nan``, a resulted p-value should be alson ``nan``.
"""
    # ``Tensor.to(t_value)`` matches both the device and dtype of the input t-values
    normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]).to(t_value), torch.tensor([1.0]).to(t_value))
is_nan = t_value.isnan()
t_value = t_value.nan_to_num()
p_value = normal_dist.cdf(t_value)
return p_value.where(~is_nan, torch.tensor(float("nan"), dtype=p_value.dtype, device=p_value.device))
def _calculate_p_value(
con_min_dis_pairs: Tensor,
n_total: Tensor,
preds_ties: Optional[Tensor],
preds_ties_p1: Optional[Tensor],
preds_ties_p2: Optional[Tensor],
target_ties: Optional[Tensor],
target_ties_p1: Optional[Tensor],
target_ties_p2: Optional[Tensor],
variant: _MetricVariant,
alternative: Optional[_TestAlternative],
) -> Tensor:
"""Calculate p-value for Kendall's tau from metric metadata."""
t_value_denominator_base = n_total * (n_total - 1) * (2 * n_total + 5)
if variant == _MetricVariant.A:
t_value = 3 * con_min_dis_pairs / torch.sqrt(t_value_denominator_base / 2)
else:
m = n_total * (n_total - 1)
t_value_denominator: Tensor = (
t_value_denominator_base
- (preds_ties_p2 if preds_ties_p2 is not None else 0)
- (target_ties_p2 if target_ties_p2 is not None else 0)
) / 18
t_value_denominator += (
2 * (preds_ties if preds_ties is not None else 0) * (target_ties if target_ties is not None else 0)
) / m
t_value_denominator += (
(preds_ties_p1 if preds_ties_p1 is not None else 0)
* (target_ties_p1 if target_ties_p1 is not None else 0)
/ (9 * m * (n_total - 2))
)
t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator)
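    # Sign handling: for 'two-sided' and 'greater' the z-score is negated so that the normal CDF
    # below yields the upper-tail probability (doubled for 'two-sided'); 'less' keeps the sign
    # and uses the lower tail directly.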
if alternative == _TestAlternative.TWO_SIDED:
t_value = torch.abs(t_value)
if alternative in [_TestAlternative.TWO_SIDED, _TestAlternative.GREATER]:
t_value *= -1
p_value = _get_p_value_for_t_value_from_dist(t_value)
if alternative == _TestAlternative.TWO_SIDED:
p_value *= 2
return p_value
def _kendall_corrcoef_update(
preds: Tensor,
target: Tensor,
concat_preds: Optional[List[Tensor]] = None,
concat_target: Optional[List[Tensor]] = None,
num_outputs: int = 1,
) -> tuple[List[Tensor], List[Tensor]]:
"""Update variables required to compute Kendall rank correlation coefficient.
Args:
preds: Sequence of data
target: Sequence of data
concat_preds: List of batches of preds sequence to be concatenated
concat_target: List of batches of target sequence to be concatenated
num_outputs: Number of outputs in multioutput setting
Raises:
RuntimeError: If ``preds`` and ``target`` do not have the same shape
"""
concat_preds = concat_preds or []
concat_target = concat_target or []
# Data checking
_check_same_shape(preds, target)
_check_data_shape_to_num_outputs(preds, target, num_outputs)
if num_outputs == 1:
preds = preds.unsqueeze(1)
target = target.unsqueeze(1)
concat_preds.append(preds)
concat_target.append(target)
return concat_preds, concat_target
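# Minimal sketch of how the update/compute pair is chained (hypothetical values; this mirrors the
# public ``kendall_rank_corrcoef`` function at the bottom of this module):
#
#     state_preds, state_target = _kendall_corrcoef_update(preds, target, [], [], num_outputs=1)
#     tau, p_value = _kendall_corrcoef_compute(
#         dim_zero_cat(state_preds), dim_zero_cat(state_target), _MetricVariant.B, _TestAlternative.TWO_SIDED
#     )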
def _kendall_corrcoef_compute(
preds: Tensor,
target: Tensor,
variant: _MetricVariant,
alternative: Optional[_TestAlternative] = None,
) -> tuple[Tensor, Optional[Tensor]]:
"""Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test.
Args:
preds: Sequence of data
target: Sequence of data
variant: Indication of which variant of Kendall's tau to be used
        alternative: Alternative hypothesis for t-test. Possible values:
- 'two-sided': the rank correlation is nonzero
- 'less': the rank correlation is negative (less than zero)
- 'greater': the rank correlation is positive (greater than zero)
"""
(
concordant_pairs,
discordant_pairs,
preds_ties,
preds_ties_p1,
preds_ties_p2,
target_ties,
target_ties_p1,
target_ties_p2,
n_total,
) = _get_metric_metadata(preds, target, variant)
con_min_dis_pairs = concordant_pairs - discordant_pairs
tau = _calculate_tau(
preds, target, concordant_pairs, discordant_pairs, con_min_dis_pairs, n_total, preds_ties, target_ties, variant
)
p_value = (
_calculate_p_value(
con_min_dis_pairs,
n_total,
preds_ties,
preds_ties_p1,
preds_ties_p2,
target_ties,
target_ties_p1,
target_ties_p2,
variant,
alternative,
)
if alternative
else None
)
# Squeeze tensor if num_outputs=1
if tau.shape[0] == 1:
tau = tau.squeeze()
p_value = p_value.squeeze() if p_value is not None else None
return tau.clamp(-1, 1), p_value
def kendall_rank_corrcoef(
preds: Tensor,
target: Tensor,
variant: Literal["a", "b", "c"] = "b",
t_test: bool = False,
alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided",
) -> Union[Tensor, tuple[Tensor, Tensor]]:
r"""Compute `Kendall Rank Correlation Coefficient`_.
.. math::
        \tau_a = \frac{C - D}{C + D}
where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs.
.. math::
        \tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}}
where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents
a total number of ties.
.. math::
        \tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}}
where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a total number
    of observations and :math:`m` is the smaller of the numbers of unique values in the ``preds``
    and ``target`` sequences. Definitions according to `The Treatment of Ties in Ranking Problems`_.
Args:
preds: Sequence of data of either shape ``(N,)`` or ``(N,d)``
target: Sequence of data of either shape ``(N,)`` or ``(N,d)``
variant: Indication of which variant of Kendall's tau to be used
t_test: Indication whether to run t-test
alternative: Alternative hypothesis for t-test. Possible values:
- 'two-sided': the rank correlation is nonzero
- 'less': the rank correlation is negative (less than zero)
- 'greater': the rank correlation is positive (greater than zero)
Return:
Correlation tau statistic
(Optional) p-value of corresponding statistical test (asymptotic)
Raises:
        ValueError: If ``t_test`` is not of type ``bool``
ValueError: If ``t_test=True`` and ``alternative=None``
Example (single output regression):
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = torch.tensor([3, -0.5, 2, 1])
>>> kendall_rank_corrcoef(preds, target)
tensor(0.3333)
Example (multi output regression):
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> target = torch.tensor([[3, -0.5], [2, 1]])
>>> kendall_rank_corrcoef(preds, target)
tensor([1., 1.])
    Example (single output regression with t-test):
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = torch.tensor([3, -0.5, 2, 1])
>>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
(tensor(0.3333), tensor(0.4969))
Example (multi output regression with t-test):
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> target = torch.tensor([[3, -0.5], [2, 1]])
>>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
(tensor([1., 1.]), tensor([nan, nan]))
"""
if not isinstance(t_test, bool):
raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.")
if t_test and alternative is None:
raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.")
_variant = _MetricVariant.from_str(str(variant))
_alternative = _TestAlternative.from_str(str(alternative)) if t_test else None
_preds, _target = _kendall_corrcoef_update(
preds, target, [], [], num_outputs=1 if preds.ndim == 1 else preds.shape[-1]
)
tau, p_value = _kendall_corrcoef_compute(
dim_zero_cat(_preds),
dim_zero_cat(_target),
_variant, # type: ignore[arg-type] # todo
_alternative, # type: ignore[arg-type] # todo
)
if p_value is not None:
return tau, p_value
return tau