import functools
from typing import Any, Union

import torch
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override

from lightning_fabric.accelerators.accelerator import Accelerator
from lightning_fabric.accelerators.registry import _AcceleratorRegistry
from lightning_fabric.utilities.device_parser import _check_data_type
|
|
class XLAAccelerator(Accelerator):
    """Accelerator for XLA devices, normally TPUs.

    .. warning:: Use of this accelerator beyond import and instantiation is experimental.

    """
|
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        if not _using_pjrt():
            raise RuntimeError("The XLA XRT runtime is not supported anymore.")
        super().__init__(*args, **kwargs)
|
    @override
    def setup_device(self, device: torch.device) -> None:
        pass
|
    @override
    def teardown(self) -> None:
        pass
|
    @staticmethod
    @override
    def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
        """Accelerator device parsing logic."""
        return _parse_tpu_devices(devices)
|
    @staticmethod
    @override
    def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
        """Gets parallel devices for the Accelerator."""
        devices = _parse_tpu_devices(devices)
        if isinstance(devices, int):
            return [torch.device("xla", i) for i in range(devices)]
        # a list selects a specific core; validation guarantees it contains exactly one index
        return [torch.device("xla", devices[0])]
|
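    # Illustrative sketch only, not part of the module: assuming a host that exposes 8 TPU cores,
    # `get_parallel_devices` maps the parsed device selection to `torch.device` objects.
    #
    #     XLAAccelerator.get_parallel_devices(8)    # -> [xla:0, xla:1, ..., xla:7]
    #     XLAAccelerator.get_parallel_devices([3])  # -> [xla:3]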
|
    @staticmethod
    @override
    @functools.lru_cache(maxsize=1)
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        if not _XLA_AVAILABLE:
            return 0
        if _XLA_GREATER_EQUAL_2_1:
            from torch_xla._internal import tpu

            return tpu.num_available_devices()
        from torch_xla.experimental import tpu

        # fallback for older torch_xla: map the detected TPU version to its per-host device count
        device_count_on_version = {2: 8, 3: 8, 4: 4}
        return device_count_on_version.get(tpu.version(), 8)
|
    @staticmethod
    @override
    @functools.lru_cache(maxsize=1)
    def is_available() -> bool:
        try:
            return XLAAccelerator.auto_device_count() > 0
        except (ValueError, AssertionError, OSError):
            # XLA may raise these exceptions if it's not properly configured; treat that as
            # "not available" so that importing `torch_xla` without using it does not fail
            return False
|
    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        accelerator_registry.register("tpu", cls, description=cls.__name__)
|
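# Illustrative sketch only, not part of the module: the class registers itself under the name "tpu",
# so a Fabric instance can select it either by name or by passing an instance (assuming TPUs are
# available on the host).
#
#     from lightning_fabric import Fabric
#
#     fabric = Fabric(accelerator="tpu", devices="auto")
#     # equivalently:
#     fabric = Fabric(accelerator=XLAAccelerator(), devices="auto")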
|
# PJRT support requires this minimum version
_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
|
|
def _using_pjrt() -> bool:
    # `runtime.using_pjrt` was removed in torch_xla 2.5; PJRT is the only runtime there, so checking
    # that a device type is resolvable is enough
    if _XLA_GREATER_EQUAL_2_5:
        from torch_xla import runtime as xr

        return xr.device_type() is not None
    if _XLA_GREATER_EQUAL_2_1:
        from torch_xla import runtime as xr

        return xr.using_pjrt()
    from torch_xla.experimental import pjrt

    return pjrt.using_pjrt()
|
|
def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
    """Parses the TPU devices given in the format as accepted by the
    :class:`~pytorch_lightning.trainer.trainer.Trainer` and :class:`~lightning_fabric.Fabric`.

    Args:
        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used.
            An int of 8 or string '8' indicates that all 8 cores with multi-processing should be used.
            A single-element list of int or string can be used to indicate the specific TPU core to use.

    Returns:
        The number of TPU cores to use, or a single-element list identifying the specific core to use.

    """
    _check_data_type(devices)
    if isinstance(devices, str):
        devices = _parse_tpu_devices_str(devices)
    _check_tpu_devices_valid(devices)
    return devices
|
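# Illustrative sketch only, not part of the module: assuming a host that exposes 8 TPU cores, the
# parser returns either a core count or a single-element list selecting one core.
#
#     _parse_tpu_devices(1)      # -> 1    (one core, chosen by the runtime)
#     _parse_tpu_devices("8")    # -> 8    (all 8 cores)
#     _parse_tpu_devices([3])    # -> [3]  (specifically core 3)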
|
def _check_tpu_devices_valid(devices: object) -> None:
    device_count = XLAAccelerator.auto_device_count()
    if (
        # support a number of devices: either 1 core or all of them
        isinstance(devices, int)
        and devices in {1, device_count}
        # support picking a specific core by index
        or isinstance(devices, (list, tuple))
        and len(devices) == 1
        and 0 <= devices[0] <= device_count - 1
    ):
        return
    raise ValueError(
        f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
    )
|
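# Illustrative sketch only, not part of the module: with `auto_device_count()` reporting 8 cores, the
# check accepts 1, 8, or a one-element list/tuple indexing cores 0-7; anything else raises ValueError.
#
#     _check_tpu_devices_valid(4)      # ValueError: an int must be 1 or the full device count
#     _check_tpu_devices_valid([8])    # ValueError: core index out of range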
|
def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
    devices = devices.strip()
    try:
        return int(devices)
    except ValueError:
        try:
            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
        except ValueError:
            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
|
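# Illustrative sketch only, not part of the module: the string parser first tries a plain integer and
# then falls back to a comma-separated list of core indices.
#
#     _parse_tpu_devices_str(" 8 ")   # -> 8
#     _parse_tpu_devices_str("5,")    # -> [5]
#     _parse_tpu_devices_str("a")     # raises ValueError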
|