# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections.abc import Sequence
from typing import Any, List, Optional, Union
import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torchmetrics.utilities.exceptions import TorchMetricsUserWarning
from torchmetrics.utilities.imports import _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
from torchmetrics.utilities.prints import rank_zero_warn
METRIC_EPS = 1e-6
def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor:
"""Concatenation along the zero dimension."""
if isinstance(x, torch.Tensor):
return x
x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x]
if not x: # empty list
raise ValueError("No samples to concatenate")
return torch.cat(x, dim=0)
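# Example (illustrative inputs): scalar tensors are unsqueezed before concatenation, e.g.
#   >>> dim_zero_cat([torch.tensor(1.0), torch.tensor([2.0, 3.0])])
#   tensor([1., 2., 3.])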
def dim_zero_sum(x: Tensor) -> Tensor:
"""Summation along the zero dimension."""
return torch.sum(x, dim=0)
def dim_zero_mean(x: Tensor) -> Tensor:
"""Average along the zero dimension."""
return torch.mean(x, dim=0)
def dim_zero_max(x: Tensor) -> Tensor:
"""Max along the zero dimension."""
return torch.max(x, dim=0).values
def dim_zero_min(x: Tensor) -> Tensor:
"""Min along the zero dimension."""
return torch.min(x, dim=0).values
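# Example (illustrative): these reducers operate on a stacked state of shape [num_updates, ...], e.g.
#   >>> dim_zero_mean(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
#   tensor([2., 3.])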
def _flatten(x: Sequence) -> list:
"""Flatten list of list into single list."""
return [item for sublist in x for item in sublist]
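# Example (illustrative):
#   >>> _flatten([[1, 2], [3], []])
#   [1, 2, 3]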
def _flatten_dict(x: dict) -> tuple[dict, bool]:
"""Flatten dict of dicts into single dict and checking for duplicates in keys along the way."""
new_dict = {}
duplicates = False
for key, value in x.items():
if isinstance(value, dict):
for k, v in value.items():
if k in new_dict:
duplicates = True
new_dict[k] = v
else:
if key in new_dict:
duplicates = True
new_dict[key] = value
return new_dict, duplicates
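# Example (illustrative): nested keys are lifted to the top level and clashes are flagged, e.g.
#   >>> _flatten_dict({"a": {"x": 1}, "x": 2})
#   ({'x': 2}, True)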
def to_onehot(
label_tensor: Tensor,
num_classes: Optional[int] = None,
) -> Tensor:
"""Convert a dense label tensor to one-hot format.
Args:
label_tensor: dense label tensor, with shape [N, d1, d2, ...]
num_classes: number of classes C
Returns:
A sparse label tensor with shape [N, C, d1, d2, ...]
Example:
>>> x = torch.tensor([1, 2, 3])
>>> to_onehot(x)
tensor([[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]])
"""
if num_classes is None:
num_classes = int(label_tensor.max().detach().item() + 1)
tensor_onehot = torch.zeros(
label_tensor.shape[0],
num_classes,
*label_tensor.shape[1:],
dtype=label_tensor.dtype,
device=label_tensor.device,
)
index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot)
return tensor_onehot.scatter_(1, index, 1.0)
def _top_k_with_half_precision_support(x: Tensor, k: int = 1, dim: int = 1) -> Tensor:
"""torch.top_k does not support half precision on CPU."""
if x.dtype == torch.half and not x.is_cuda:
idx = torch.argsort(x, dim=dim, stable=True).flip(dim)
return idx.narrow(dim, 0, k)
return x.topk(k=k, dim=dim).indices
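# Example (illustrative): returns the indices of the k largest entries along ``dim``, e.g.
#   >>> _top_k_with_half_precision_support(torch.tensor([[0.1, 0.7, 0.2]]), k=2, dim=1)
#   tensor([[1, 2]])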
def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
"""Convert a probability tensor to binary by selecting top-k the highest entries.
Args:
prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the
position defined by the ``dim`` argument
topk: number of the highest entries to turn into 1s
dim: dimension on which to compare entries
Returns:
A binary tensor of the same shape as the input tensor of type ``torch.int32``
Example:
>>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
>>> select_topk(x, topk=2)
tensor([[0, 1, 1],
[1, 1, 0]], dtype=torch.int32)
"""
topk_tensor = torch.zeros_like(prob_tensor, dtype=torch.int)
if topk == 1: # argmax has better performance than topk
topk_tensor.scatter_(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0)
else:
topk_tensor.scatter_(dim, _top_k_with_half_precision_support(prob_tensor, k=topk, dim=dim), 1.0)
return topk_tensor.int()
def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor:
"""Convert a tensor of probabilities to a dense label tensor.
Args:
x: probabilities to get the categorical label [N, d1, d2, ...]
argmax_dim: dimension to apply
Return:
A tensor with categorical labels [N, d2, ...]
Example:
>>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
>>> to_categorical(x)
tensor([1, 0])
"""
return torch.argmax(x, dim=argmax_dim)
def _squeeze_scalar_element_tensor(x: Tensor) -> Tensor:
return x.squeeze() if x.numel() == 1 else x
def _squeeze_if_scalar(data: Any) -> Any:
return apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor)
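# Example (illustrative): single-element tensors anywhere in a collection are squeezed to scalars, e.g.
#   >>> _squeeze_if_scalar({"a": torch.tensor([1.0]), "b": torch.tensor([1.0, 2.0])})
#   {'a': tensor(1.), 'b': tensor([1., 2.])}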
def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""Implement custom bincount.
    PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU, on MPS
    devices, or on XLA devices. This implementation therefore falls back to using a combination of
    `torch.arange` and `torch.eq` in these scenarios. A small performance hit and higher memory consumption can be
    expected, as a `[batch_size, minlength]` tensor needs to be initialized compared to native ``torch.bincount``.
Args:
x: tensor to count
minlength: minimum length to count
Returns:
Number of occurrences for each unique element in x
Example:
>>> x = torch.tensor([0,0,0,1,1,2,2,2,2])
>>> _bincount(x, minlength=3)
tensor([3, 2, 4])
"""
if minlength is None:
minlength = len(torch.unique(x))
if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)
return torch.bincount(x, minlength=minlength)
def _cumsum(x: Tensor, dim: Optional[int] = 0, dtype: Optional[torch.dtype] = None) -> Tensor:
"""Implement custom cumulative summation for Torch versions which does not support it natively."""
is_cuda_fp_deterministic = torch.are_deterministic_algorithms_enabled() and x.is_cuda and x.is_floating_point()
if _TORCH_LESS_THAN_2_6 and is_cuda_fp_deterministic and sys.platform != "win32":
rank_zero_warn(
"You are trying to use a metric in deterministic mode on GPU that uses `torch.cumsum`, which is currently"
" not supported. The tensor will be copied to the CPU memory to compute it and then copied back to GPU."
" Expect some slowdowns.",
TorchMetricsUserWarning,
)
return x.cpu().cumsum(dim=dim, dtype=dtype).to(x.device)
return torch.cumsum(x, dim=dim, dtype=dtype)
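# Example (illustrative): behaves like ``torch.cumsum`` regardless of the fallback path, e.g.
#   >>> _cumsum(torch.tensor([1, 2, 3]), dim=0)
#   tensor([1, 3, 6])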
def _flexible_bincount(x: Tensor) -> Tensor:
"""Similar to `_bincount`, but works also with tensor that do not contain continuous values.
Args:
x: tensor to count
Returns:
Number of occurrences for each unique element in x
"""
# make sure elements in x start from 0
x = x - x.min()
unique_x = torch.unique(x)
output = _bincount(x, minlength=torch.max(unique_x) + 1) # type: ignore[arg-type]
# remove zeros from output tensor
return output[unique_x]
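# Example (illustrative): counts are returned only for the values actually present, e.g.
#   >>> _flexible_bincount(torch.tensor([5, 5, 7, 9]))
#   tensor([2, 1, 1])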
def allclose(tensor1: Tensor, tensor2: Tensor) -> bool:
"""Wrap torch.allclose to be robust towards dtype difference."""
if tensor1.dtype != tensor2.dtype:
tensor2 = tensor2.to(dtype=tensor1.dtype)
return torch.allclose(tensor1, tensor2)
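# Example (illustrative): dtype mismatches are resolved by casting the second tensor, e.g.
#   >>> allclose(torch.tensor([1.0, 2.0]), torch.tensor([1, 2]))
#   True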
def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
"""Interpolation function comparable to numpy.interp.
Args:
x: x-coordinates where to evaluate the interpolated values
xp: x-coordinates of the data points
fp: y-coordinates of the data points
"""
# Sort xp and fp based on xp for compatibility with np.interp
sorted_indices = torch.argsort(xp)
xp = xp[sorted_indices]
fp = fp[sorted_indices]
# Calculate slopes for each interval
slopes = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
# Identify where x falls relative to xp
indices = torch.searchsorted(xp, x) - 1
indices = torch.clamp(indices, 0, len(slopes) - 1)
# Compute interpolated values
return fp[indices] + slopes[indices] * (x - xp[indices])
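# Example (illustrative): linear interpolation between the given data points, e.g.
#   >>> interp(torch.tensor([0.5, 1.5]), torch.tensor([0.0, 1.0, 2.0]), torch.tensor([0.0, 10.0, 20.0]))
#   tensor([ 5., 15.])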