|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import queue |
|
import time |
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union |
|
|
|
import torch.multiprocessing as mp |
|
from typing_extensions import override |
|
|
|
from lightning.fabric.accelerators.xla import _XLA_AVAILABLE |
|
from lightning.fabric.strategies.launchers.launcher import _Launcher |
|
from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot |
|
from lightning.fabric.utilities.apply_func import move_data_to_device |
|
|
|
if TYPE_CHECKING: |
|
from lightning.fabric.strategies import XLAFSDPStrategy, XLAStrategy |
|
|
|
|
|
class _XLALauncher(_Launcher):

    r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the

    end.



    The main process in which this launcher is invoked creates N so-called worker processes (using the

    `torch_xla` :func:`xmp.spawn`) that run the given function.

    Worker processes have a rank that ranges from 0 to N - 1.



    Note:

        - This launcher requires all objects to be pickleable.

        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.



    Args:

        strategy: A reference to the strategy that is used together with this launcher



    """



    def __init__(self, strategy: Union["XLAStrategy", "XLAFSDPStrategy"]) -> None:

        # Fail fast when torch_xla is missing; ``_XLA_AVAILABLE`` doubles as the error description.
        if not _XLA_AVAILABLE:

            raise ModuleNotFoundError(str(_XLA_AVAILABLE))

        self._strategy = strategy

        # NOTE(review): "fork" appears to be chosen so the launcher works in interactive
        # environments (see ``is_interactive_compatible``) — confirm against torch_xla's
        # supported start methods.
        self._start_method = "fork"



    @property

    @override

    def is_interactive_compatible(self) -> bool:

        # Fork-based workers inherit the parent's state, so no ``__main__`` re-import is required.
        return True



    @override

    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:

        """Launches processes that run the given function in parallel.



        The function is allowed to have a return value. However, when all processes join, only the return value

        of worker process 0 gets returned from this `launch` method in the main process.



        Arguments:

            function: The entry point for all launched processes.

            *args: Optional positional arguments to be passed to the given function.

            **kwargs: Optional keyword arguments to be passed to the given function.



        """

        return_queue: Union[queue.Queue, mp.SimpleQueue]

        # A manager-backed queue lets the spawned workers ship the rank-0 result back
        # to this (main) process.
        return_queue = mp.Manager().Queue()



        # Imported lazily: torch_xla is only guaranteed present when this launcher is actually used
        # (availability was checked in ``__init__``).
        import torch_xla.distributed.xla_multiprocessing as xmp



        spawn_kwargs = {}

        nprocs = self._strategy.num_processes

        if nprocs == 1:

            # ``nprocs`` is passed explicitly only when it is 1; otherwise ``xmp.spawn`` decides the
            # process count itself (NOTE(review): presumably one per available XLA device — confirm
            # against the torch_xla version in use).

            spawn_kwargs["nprocs"] = nprocs



        xmp.spawn(

            self._wrapping_function,

            args=(function, args, kwargs, return_queue),

            start_method=self._start_method,

            **spawn_kwargs,

        )

        # Blocks until the rank-0 worker has put its result (see ``_wrapping_function``).
        return return_queue.get()



    def _wrapping_function(

        self,

        # ``process_idx`` is supplied by ``xmp.spawn`` as the index of this worker; the body

        # does not use it directly.

        process_idx: int,

        function: Callable,

        args: Any,

        kwargs: Any,

        return_queue: Union[mp.SimpleQueue, queue.Queue],

        global_states: Optional[_GlobalStateSnapshot] = None,

    ) -> None:

        """Entry point executed in each spawned worker: runs ``function`` and reports rank 0's result.

        Args:
            process_idx: Index of this worker, provided by ``xmp.spawn``.
            function: The user function to execute in this worker.
            args: Positional arguments for ``function``.
            kwargs: Keyword arguments for ``function``.
            return_queue: Queue onto which only the rank-0 worker puts its (CPU-moved) result.
            global_states: Unused here; NOTE(review): presumably kept for signature
                compatibility with the generic multiprocessing launcher — confirm.

        """
        import torch_xla.core.xla_model as xm



        if len(xm.get_xla_supported_devices()) > 1:

            # More than one logical XLA device is visible in this process, so the inputs are
            # deep-copied before use. NOTE(review): presumably to avoid shared-object mutation
            # when devices run in threads of the same process — confirm.

            import copy



            function, args, kwargs = copy.deepcopy((function, args, kwargs))



        results = function(*args, **kwargs)



        # Only the rank-0 worker reports a result; it is moved to CPU first so it can be
        # pickled through the manager queue.
        if self._strategy.local_rank == 0:

            return_queue.put(move_data_to_device(results, "cpu"))



        _rank_teardown(self._strategy.local_rank)
|
|
|
|
|
def _rank_teardown(rank: int) -> None:
    """Synchronize all workers at shutdown, with rank 0 lingering briefly before exit.

    Args:
        rank: The local rank of the calling process.

    """
    import torch_xla.core.xla_model as xm

    # Barrier: no process proceeds past this point until every process has reached it.
    xm.rendezvous("end-process")

    # Non-zero ranks exit immediately after the barrier.
    if rank != 0:
        return

    # Rank 0 waits a moment — presumably so it is the last process to terminate.
    time.sleep(1)
|
|