import sys
from collections.abc import Sequence
from typing import Any, List, Optional, Union

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor

from torchmetrics.utilities.exceptions import TorchMetricsUserWarning
from torchmetrics.utilities.imports import _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
from torchmetrics.utilities.prints import rank_zero_warn

METRIC_EPS = 1e-6


def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor:
    """Concatenation along the zero dimension.
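
    Example:
        >>> x = torch.tensor([1, 2, 3])
        >>> dim_zero_cat([x, x])
        tensor([1, 2, 3, 1, 2, 3])

    """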
    if isinstance(x, torch.Tensor):
        return x
    x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x]
    if not x:
        raise ValueError("No samples to concatenate")
    return torch.cat(x, dim=0)


def dim_zero_sum(x: Tensor) -> Tensor:
    """Summation along the zero dimension."""
    return torch.sum(x, dim=0)


def dim_zero_mean(x: Tensor) -> Tensor:
    """Average along the zero dimension."""
    return torch.mean(x, dim=0)


def dim_zero_max(x: Tensor) -> Tensor:
    """Max along the zero dimension.
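
    Example:
        >>> x = torch.tensor([[1, 4], [3, 2]])
        >>> dim_zero_max(x)
        tensor([3, 4])

    """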
    return torch.max(x, dim=0).values


def dim_zero_min(x: Tensor) -> Tensor:
    """Min along the zero dimension."""
    return torch.min(x, dim=0).values


def _flatten(x: Sequence) -> list:
    """Flatten a list of lists into a single list.
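
    Example:
        >>> _flatten([[1, 2], [3, 4]])
        [1, 2, 3, 4]

    """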
    return [item for sublist in x for item in sublist]


def _flatten_dict(x: dict) -> tuple[dict, bool]:
    """Flatten a dict of dicts into a single dict, checking for duplicate keys along the way.
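
    Example:
        >>> _flatten_dict({"a": {"x": 1}, "b": 2})
        ({'x': 1, 'b': 2}, False)

    """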
    new_dict = {}
    duplicates = False
    for key, value in x.items():
        if isinstance(value, dict):
            for k, v in value.items():
                if k in new_dict:
                    duplicates = True
                new_dict[k] = v
        else:
            if key in new_dict:
                duplicates = True
            new_dict[key] = value
    return new_dict, duplicates


def to_onehot(
    label_tensor: Tensor,
    num_classes: Optional[int] = None,
) -> Tensor:
    """Convert a dense label tensor to one-hot format.

    Args:
        label_tensor: dense label tensor, with shape [N, d1, d2, ...]
        num_classes: number of classes C

    Returns:
        A sparse label tensor with shape [N, C, d1, d2, ...]

    Example:
        >>> x = torch.tensor([1, 2, 3])
        >>> to_onehot(x)
        tensor([[0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]])

    """
    if num_classes is None:
        num_classes = int(label_tensor.max().detach().item() + 1)

    tensor_onehot = torch.zeros(
        label_tensor.shape[0],
        num_classes,
        *label_tensor.shape[1:],
        dtype=label_tensor.dtype,
        device=label_tensor.device,
    )
    index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot)
    return tensor_onehot.scatter_(1, index, 1.0)


def _top_k_with_half_precision_support(x: Tensor, k: int = 1, dim: int = 1) -> Tensor:
    """Return the indices of the top-k entries; ``torch.topk`` does not support half precision on CPU.
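
    Example:
        >>> x = torch.tensor([[0.1, 0.9, 0.5]])
        >>> _top_k_with_half_precision_support(x, k=2)
        tensor([[1, 2]])

    """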
    if x.dtype == torch.half and not x.is_cuda:
        idx = torch.argsort(x, dim=dim, stable=True).flip(dim)
        return idx.narrow(dim, 0, k)
    return x.topk(k=k, dim=dim).indices


def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
    """Convert a probability tensor to binary by selecting the ``topk`` highest entries.

    Args:
        prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the
            position defined by the ``dim`` argument
        topk: number of the highest entries to turn into 1s
        dim: dimension on which to compare entries

    Returns:
        A binary tensor of the same shape as the input tensor of type ``torch.int32``

    Example:
        >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
        >>> select_topk(x, topk=2)
        tensor([[0, 1, 1],
                [1, 1, 0]], dtype=torch.int32)

    """
    topk_tensor = torch.zeros_like(prob_tensor, dtype=torch.int)
    if topk == 1:  # argmax is cheaper than a full top-k for the top-1 case
        topk_tensor.scatter_(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0)
    else:
        topk_tensor.scatter_(dim, _top_k_with_half_precision_support(prob_tensor, k=topk, dim=dim), 1.0)
    return topk_tensor.int()


def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor:
    """Convert a tensor of probabilities to a dense label tensor.

    Args:
        x: probabilities to get the categorical label [N, d1, d2, ...]
        argmax_dim: dimension to apply

    Returns:
        A tensor with categorical labels [N, d2, ...]

    Example:
        >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
        >>> to_categorical(x)
        tensor([1, 0])

    """
    return torch.argmax(x, dim=argmax_dim)


def _squeeze_scalar_element_tensor(x: Tensor) -> Tensor:
    """Squeeze a tensor down to a scalar if it contains only a single element."""
    return x.squeeze() if x.numel() == 1 else x


def _squeeze_if_scalar(data: Any) -> Any:
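    """Squeeze all single-element tensors inside a (possibly nested) collection.

    Example:
        >>> _squeeze_if_scalar({"metric": torch.tensor([2.0])})
        {'metric': tensor(2.)}

    """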
    return apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor)


def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
    """Implement custom bincount.

    PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU, on MPS
    devices, or on XLA devices. In these scenarios this implementation therefore falls back to a combination of
    ``torch.arange`` and ``torch.eq``. A small performance hit and higher memory consumption can be expected, as a
    ``[len(x), minlength]`` tensor needs to be initialized compared to native ``torch.bincount``.

    Args:
        x: tensor to count
        minlength: minimum length to count

    Returns:
        Number of occurrences for each unique element in x

    Example:
        >>> x = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
        >>> _bincount(x, minlength=3)
        tensor([3, 2, 4])

    """
    if minlength is None:
        minlength = len(torch.unique(x))

    if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
        mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
        return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)

    return torch.bincount(x, minlength=minlength)


def _cumsum(x: Tensor, dim: Optional[int] = 0, dtype: Optional[torch.dtype] = None) -> Tensor:
    """Implement custom cumulative summation for Torch versions that do not support it natively in deterministic mode.
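
    Example:
        >>> _cumsum(torch.tensor([1, 2, 3]))
        tensor([1, 3, 6])

    """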
    is_cuda_fp_deterministic = torch.are_deterministic_algorithms_enabled() and x.is_cuda and x.is_floating_point()
    if _TORCH_LESS_THAN_2_6 and is_cuda_fp_deterministic and sys.platform != "win32":
        rank_zero_warn(
            "You are trying to use a metric in deterministic mode on GPU that uses `torch.cumsum`, which is currently"
            " not supported. The tensor will be copied to the CPU memory to compute it and then copied back to GPU."
            " Expect some slowdowns.",
            TorchMetricsUserWarning,
        )
        return x.cpu().cumsum(dim=dim, dtype=dtype).to(x.device)
    return torch.cumsum(x, dim=dim, dtype=dtype)


def _flexible_bincount(x: Tensor) -> Tensor:
    """Similar to ``_bincount``, but also works with tensors that do not contain consecutive values.

    Args:
        x: tensor to count

    Returns:
        Number of occurrences for each unique element in x
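
    Example:
        >>> x = torch.tensor([5, 5, 7, 9])
        >>> _flexible_bincount(x)
        tensor([2, 1, 1])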

    """
    x = x - x.min()
    unique_x = torch.unique(x)

    output = _bincount(x, minlength=torch.max(unique_x) + 1)

    return output[unique_x]


def allclose(tensor1: Tensor, tensor2: Tensor) -> bool:
    """Wrap ``torch.allclose`` to be robust towards dtype differences.
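
    Example:
        >>> allclose(torch.tensor([1.0, 2.0]), torch.tensor([1, 2]))
        True

    """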
    if tensor1.dtype != tensor2.dtype:
        tensor2 = tensor2.to(dtype=tensor1.dtype)
    return torch.allclose(tensor1, tensor2)


def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
    """One-dimensional linear interpolation, comparable to ``numpy.interp``.

    Args:
        x: x-coordinates at which to evaluate the interpolated values
        xp: x-coordinates of the data points
        fp: y-coordinates of the data points
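
    Example:
        >>> xp = torch.tensor([0.0, 1.0, 2.0])
        >>> fp = torch.tensor([10.0, 20.0, 30.0])
        >>> interp(torch.tensor([0.5, 1.5]), xp, fp)
        tensor([15., 25.])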
""" |
|
|
|
sorted_indices = torch.argsort(xp) |
|
xp = xp[sorted_indices] |
|
fp = fp[sorted_indices] |
|
|
|
|
|
slopes = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1]) |
|
|
|
|
|
indices = torch.searchsorted(xp, x) - 1 |
|
indices = torch.clamp(indices, 0, len(slopes) - 1) |
|
|
|
|
|
return fp[indices] + slopes[indices] * (x - xp[indices]) |
|
|