|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities used for collections.""" |
|
|
|
from abc import ABC |
|
from functools import partial |
|
from typing import TYPE_CHECKING, Any, Callable, Union |
|
|
|
import torch |
|
from lightning_utilities.core.apply_func import apply_to_collection |
|
from torch import Tensor |
|
|
|
from lightning_fabric.utilities.imports import _NUMPY_AVAILABLE |
|
from lightning_fabric.utilities.types import _DEVICE |
|
|
|
if TYPE_CHECKING: |
|
import numpy as np |
|
|
|
# Device types for which ``move_data_to_device`` does not request ``non_blocking=True`` transfers.
_BLOCKING_DEVICE_TYPES = ("cpu", "mps")
|
|
|
|
|
def _from_numpy(value: "np.ndarray", device: _DEVICE) -> Tensor:
    """Create a tensor from a NumPy array and move it to ``device``.

    Args:
        value: The array to convert. It is wrapped with :func:`torch.from_numpy`.
        device: The target device, forwarded unchanged to :meth:`torch.Tensor.to`.

    Return:
        The tensor produced by ``torch.from_numpy(value).to(device)``.

    """
    return torch.from_numpy(value).to(device)
|
|
|
|
|
# Pairs of ``(source type, conversion function)`` that ``convert_to_tensors`` applies in list order.
# ``bool`` is listed before ``int`` (``bool`` is a subclass of ``int`` in Python).
CONVERSION_DTYPES: list[tuple[Any, Callable[[Any, Any], Tensor]]] = [
    # NOTE(review): booleans are stored as ``torch.uint8`` rather than ``torch.bool`` - presumably for
    # compatibility with downstream consumers; confirm before changing the dtype.
    (bool, partial(torch.tensor, dtype=torch.uint8)),
    (int, partial(torch.tensor, dtype=torch.int)),
    (float, partial(torch.tensor, dtype=torch.float)),
]
|
|
|
if _NUMPY_AVAILABLE:
    import numpy as np

    # Register the ndarray conversion only when the availability flag is set, so that NumPy is
    # never imported at runtime otherwise.
    CONVERSION_DTYPES.append((np.ndarray, _from_numpy))
|
|
|
|
|
class _TransferableDataType(ABC):
    """A custom type for data that can be moved to a torch device via ``.to(...)``.

    Membership is structural: ``isinstance``/``issubclass`` checks against this class succeed for any
    type that exposes a callable ``to`` attribute, without the type having to inherit from it.

    Example:

        >>> isinstance(dict, _TransferableDataType)
        False
        >>> isinstance(torch.rand(2, 3), _TransferableDataType)
        True
        >>> class CustomObject:
        ...     def __init__(self):
        ...         self.x = torch.rand(2, 2)
        ...     def to(self, device):
        ...         self.x = self.x.to(device)
        ...         return self
        >>> isinstance(CustomObject(), _TransferableDataType)
        True

    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        """Report ``subclass`` as a virtual subclass when it has a callable ``to`` attribute.

        Args:
            subclass: The class being tested by ``issubclass``/``isinstance``.

        Return:
            ``True`` or ``False`` when the check is made against ``_TransferableDataType`` itself;
            ``NotImplemented`` for any derived class so that the default ABC machinery decides.

        """
        if cls is _TransferableDataType:
            to = getattr(subclass, "to", None)
            return callable(to)
        return NotImplemented
|
|
|
|
|
def move_data_to_device(batch: Any, device: _DEVICE) -> Any:
    """Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be
    moved and all other objects in the collection will be left untouched.

    Args:
        batch: A tensor or collection of tensors or anything that has a method ``.to(...)``.
            See :func:`apply_to_collection` for a list of supported collection types.
        device: The device to which the data should be moved. A string is converted to a
            :class:`torch.device` first; any other value is handed to ``.to(...)`` unchanged.

    Return:
        the same collection but with all contained tensors residing on the new device.

    See Also:
        - :meth:`torch.Tensor.to`
        - :class:`torch.device`

    """
    if isinstance(device, str):
        device = torch.device(device)

    def batch_to(data: Any) -> Any:
        kwargs: dict[str, Any] = {}
        # ``non_blocking=True`` is requested only for tensors, and only when the target is a ``torch.device``
        # whose type is not listed in ``_BLOCKING_DEVICE_TYPES``.
        if isinstance(data, Tensor) and isinstance(device, torch.device) and device.type not in _BLOCKING_DEVICE_TYPES:
            kwargs["non_blocking"] = True
        data_output = data.to(device, **kwargs)
        if data_output is not None:
            return data_output
        # ``to`` returned ``None`` (e.g. a custom implementation that moved its contents in place but did not
        # return ``self``): fall back to the object that was passed in.
        return data

    return apply_to_collection(batch, dtype=_TransferableDataType, function=batch_to)
|
|
|
|
|
def convert_to_tensors(data: Any, device: _DEVICE) -> Any:
    """Convert supported non-tensor values in a collection to tensors and move the result to ``device``.

    Every ``(type, conversion function)`` pair in ``CONVERSION_DTYPES`` is applied to ``data`` in list order
    through :func:`apply_to_collection`, with ``device`` forwarded to the conversion function as a keyword
    argument. The outcome is then passed through :func:`move_data_to_device`.

    Args:
        data: A single value or a collection of values.
        device: The target device.

    Return:
        The converted collection as returned by :func:`move_data_to_device`.

    """
    # Convert non-tensor values first; each pass rebinds ``data`` to the converted collection.
    for src_dtype, conversion_func in CONVERSION_DTYPES:
        data = apply_to_collection(data, src_dtype, conversion_func, device=device)
    return move_data_to_device(data, device)
|
|
|
|
|
def convert_tensors_to_scalars(data: Any) -> Any:
    """Recursively walk through a collection and convert single-item tensors to scalar values.

    Args:
        data: A tensor or a collection containing tensors.

    Return:
        The collection with every tensor replaced by the value of its ``Tensor.item()`` call.

    Raises:
        ValueError:
            If a tensor inside ``data`` does not contain exactly one element, hence preventing conversion
            to a scalar.

    """

    def to_item(value: Tensor) -> Union[int, float, bool]:
        # ``numel() != 1`` rejects empty tensors as well as tensors with several elements.
        if value.numel() != 1:
            raise ValueError(
                f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
            )
        return value.item()

    return apply_to_collection(data, Tensor, to_item)
|
|