import logging
import math
import threading
from functools import reduce
from itertools import chain
from typing import Optional, TYPE_CHECKING, Union

import torch
from torch.distributed import is_available
from torch.utils._typing_utils import not_none


__all__ = ["init_device_mesh", "DeviceMesh"]


if not is_available():
    import sys

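    # Define no-op stubs so that importing ``torch.distributed.device_mesh``
    # does not fail when the distributed package is not available.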
    class _DeviceMeshStub:
        pass

    def _init_device_mesh_stub():
        pass

    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub
    sys.modules[
        "torch.distributed.device_mesh"
    ].init_device_mesh = _init_device_mesh_stub
|
|
|
|
|
else: |
|
from torch._C._distributed_c10d import Backend as C10dBackend |
|
from torch.distributed.distributed_c10d import ( |
|
_find_pg_by_ranks_and_tag, |
|
_get_default_group, |
|
_get_group_tag, |
|
get_backend, |
|
get_process_group_ranks, |
|
get_rank, |
|
get_world_size, |
|
init_process_group, |
|
is_initialized, |
|
new_group, |
|
ProcessGroup, |
|
split_group, |
|
) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
if TYPE_CHECKING: |
|
try: |
|
from numpy.typing import ArrayLike |
|
except ImportError: |
|
logger.warning( |
|
"DeviceMesh requires numpy >= 1.21 to be installed for type checking" |
|
) |
|
|
|
class _MeshEnv(threading.local): |
|
def __init__(self) -> None: |
|
self.mesh_stack: list[DeviceMesh] = [] |
|
self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {} |
|
self.mesh_dim_group_options: dict[ |
|
int, tuple[str, Optional[C10dBackend.Options]] |
|
] = {} |
|
self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {} |
|
|
|
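            # Maps a root mesh to {flattened mesh dim name: root mesh dims it was flattened from}.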
self.flatten_name_to_root_dims: dict[ |
|
DeviceMesh, dict[str, tuple[int, ...]] |
|
] = {} |
|
|
|
def get_current_mesh(self) -> "DeviceMesh": |
|
if len(self.mesh_stack) == 0: |
|
raise RuntimeError("No device mesh is currently active!") |
|
return self.mesh_stack[-1] |
|
|
|
def create_sub_mesh( |
|
self, |
|
device_mesh: "DeviceMesh", |
|
submesh_dim_names: tuple[str, ...], |
|
submesh_dims: list[tuple[int, ...]], |
|
) -> "DeviceMesh": |
|
|
|
|
|
|
|
|
|
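            # Compute the size of each submesh dim. A flattened dim's size is the product
            # of the root mesh dims it covers: e.g. for a mesh of shape (2, 2, 2) with dim
            # names ("dp", "cp", "tp"), slicing ["dp_cp", "tp"] gives submesh_dims
            # [(0, 1), (2,)] and slice_dim_size [4, 2].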
slice_dim_size = [ |
|
reduce( |
|
lambda x, y: x * device_mesh.mesh.size(y), |
|
mesh_dim, |
|
1, |
|
) |
|
for mesh_dim in submesh_dims |
|
] |
|
|
|
mesh_tensor = device_mesh.mesh |
|
|
|
slice_dim_idx = [] |
|
slice_dim_group_info = [] |
|
|
|
|
|
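            # Track how many dims have already been flattened so that the remaining slice
            # indices can be shifted into the (partially) flattened mesh tensor.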
num_dims_flatten = 0 |
|
for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names): |
|
|
|
|
|
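                # A submesh dim backed by more than one root dim is a flattened dim;
                # flatten the corresponding (contiguous) dims of the mesh tensor.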
if len(mesh_dim_indices) > 1: |
|
|
|
mesh_tensor = mesh_tensor.flatten( |
|
start_dim=mesh_dim_indices[0] - num_dims_flatten, |
|
end_dim=mesh_dim_indices[-1] - num_dims_flatten, |
|
) |
|
|
|
|
|
|
|
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) |
|
num_dims_flatten += len(mesh_dim_indices) - 1 |
|
slice_dim_group_info.append( |
|
self.root_to_flatten_mapping[device_mesh][ |
|
mesh_dim_name |
|
]._dim_group_infos[0] |
|
) |
|
else: |
|
slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) |
|
slice_dim_group_info.append( |
|
device_mesh._dim_group_infos[mesh_dim_indices[0]] |
|
) |
|
|
|
|
|
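            # mesh_tensor may have fewer dims than device_mesh.mesh at this point because
            # flattened dims have been collapsed; collect the dims not being sliced out.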
mesh_dims_remained_idx = list(range(mesh_tensor.ndim)) |
|
for idx in slice_dim_idx: |
|
mesh_dims_remained_idx.remove(idx) |
|
|
|
|
|
|
|
|
|
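            # Move the sliced dims to the end and reshape so that each row holds the ranks
            # of one candidate submesh; the row containing the current rank is the submesh
            # this rank belongs to.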
pg_ranks_by_dim = mesh_tensor.permute( |
|
*mesh_dims_remained_idx, *slice_dim_idx |
|
).reshape(-1, *slice_dim_size) |
|
|
|
cur_rank = device_mesh.get_rank() |
|
for mesh_nd in pg_ranks_by_dim: |
|
submesh = DeviceMesh( |
|
device_mesh.device_type, |
|
mesh_nd, |
|
mesh_dim_names=submesh_dim_names, |
|
_init_backend=False, |
|
) |
|
if cur_rank in mesh_nd: |
|
res_submesh = submesh |
|
|
|
res_submesh._dim_group_infos = slice_dim_group_info |
|
self.child_to_root_mapping[res_submesh] = device_mesh |
|
|
|
return res_submesh |
|
|
|
def create_flatten_mesh( |
|
self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None |
|
) -> "DeviceMesh": |
|
root_mesh = _mesh_resources.get_root_mesh(device_mesh) |
|
|
|
flatten_dims_in_root = [ |
|
not_none(root_mesh.mesh_dim_names).index(flattened_mesh_dim_name) |
|
for flattened_mesh_dim_name in not_none(device_mesh.mesh_dim_names) |
|
] |
|
|
|
if not mesh_dim_name: |
|
mesh_dim_name = "_".join( |
|
[ |
|
not_none(root_mesh.mesh_dim_names)[dim] |
|
for dim in flatten_dims_in_root |
|
] |
|
) |
|
|
|
|
|
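            # Check that the (possibly auto-generated) mesh_dim_name does not collide with an
            # existing root mesh dim name or a previously created flattened dim name.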
self.flatten_name_to_root_dims.setdefault(root_mesh, {}) |
|
            invalid_dim_names = list(
                chain(
                    not_none(root_mesh.mesh_dim_names),
                    self.flatten_name_to_root_dims[root_mesh].keys(),
                )
            )
            if mesh_dim_name in invalid_dim_names:
                raise RuntimeError(
                    f"{mesh_dim_name} already exists for submesh of the {root_mesh}. "
                    f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. "
                    "Please specify another valid mesh_dim_name."
                )
|
|
|
|
|
|
|
|
|
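            # Quick return if a flattened mesh with this name has already been created from
            # this root mesh.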
if ( |
|
root_mesh in self.root_to_flatten_mapping |
|
and mesh_dim_name in self.root_to_flatten_mapping[root_mesh] |
|
): |
|
return self.root_to_flatten_mapping[root_mesh][mesh_dim_name] |
|
|
|
flattened_mesh_dim_size = math.prod(device_mesh.mesh.size()) |
|
|
|
remained_dims_in_root = list(range(root_mesh.mesh.ndim)) |
|
for flatten_dim_in_root in flatten_dims_in_root: |
|
remained_dims_in_root.remove(flatten_dim_in_root) |
|
|
|
pg_ranks_by_dim = root_mesh.mesh.permute( |
|
*remained_dims_in_root, *flatten_dims_in_root |
|
).reshape(-1, flattened_mesh_dim_size) |
|
|
|
cur_rank = root_mesh.get_rank() |
|
for mesh_nd in pg_ranks_by_dim: |
|
|
|
flattened_mesh = DeviceMesh( |
|
root_mesh.device_type, |
|
mesh_nd, |
|
mesh_dim_names=(mesh_dim_name,), |
|
) |
|
if cur_rank in mesh_nd: |
|
res_flattened_mesh = flattened_mesh |
|
self.child_to_root_mapping[res_flattened_mesh] = root_mesh |
|
self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = ( |
|
res_flattened_mesh |
|
) |
|
self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple( |
|
flatten_dims_in_root |
|
) |
|
|
|
return res_flattened_mesh |
|
|
|
def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh": |
|
|
|
|
|
|
|
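            # A mesh with no entry in child_to_root_mapping was not created through slicing
            # or flattening, so it is its own root.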
root_mesh = self.child_to_root_mapping.get(device_mesh, None) |
|
return device_mesh if not root_mesh else root_mesh |
|
|
|
def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: |
|
""" |
|
Returns the index of the mesh dim in the root mesh. |
|
            The device_mesh passed in needs to be sliced out from the root mesh
            or from a submesh of the root mesh.
|
""" |
|
root_mesh = self.get_root_mesh(device_mesh) |
|
child_mesh_dim_names = device_mesh.mesh_dim_names |
|
if root_mesh and child_mesh_dim_names: |
|
assert len(child_mesh_dim_names) == 1, ( |
|
"The submesh can only be a 1D mesh." |
|
) |
|
child_mesh_dim_name = child_mesh_dim_names[0] |
|
return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name) |
|
return None |
|
|
|
@staticmethod |
|
def num_devices_per_host(device_type: str) -> int: |
|
return _get_device_handle(device_type).device_count() |
|
|
|
@staticmethod |
|
def num_hosts(device_type: str) -> int: |
|
|
|
|
|
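            # ProcessGroup cannot report the host count directly, so infer it assuming
            # homogeneous hosts with the same number of devices per host.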
return get_world_size() // _MeshEnv.num_devices_per_host(device_type) |
|
|
|
def get_mesh_dim_by_name( |
|
self, device_mesh: "DeviceMesh", mesh_dim_name: str |
|
) -> int: |
|
if ( |
|
device_mesh.mesh_dim_names is None |
|
or len(device_mesh.mesh_dim_names) == 0 |
|
): |
|
raise KeyError( |
|
"No `mesh_dim_names` found.", |
|
) |
|
if mesh_dim_name not in device_mesh.mesh_dim_names: |
|
raise KeyError( |
|
f"Mesh dimension '{mesh_dim_name}' does not exist.", |
|
f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}", |
|
) |
|
return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) |
|
|
|
def _set_mesh_dim_group_options( |
|
self, |
|
dim: int, |
|
backend: str, |
|
pg_options: Optional[C10dBackend.Options] = None, |
|
) -> None: |
|
self.mesh_dim_group_options[dim] = (backend, pg_options) |
|
|
|
def _get_slice_mesh_dims( |
|
self, device_mesh, mesh_dim_names |
|
) -> list[tuple[int, ...]]: |
|
""" |
|
            Validate whether the given mesh_dim_names are valid for slicing the given device_mesh.
            If valid, return the dim indices of the slice mesh in the device mesh.
|
""" |
|
if device_mesh != self.get_root_mesh(device_mesh): |
|
raise RuntimeError("Cannot create a submesh from a submesh.") |
|
|
|
|
|
|
|
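            # Valid slice names are the root mesh's own dim names plus any flattened dim
            # names that were created from this root mesh.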
self.flatten_name_to_root_dims.setdefault(device_mesh, {}) |
|
flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh] |
|
valid_mesh_dim_names = [ |
|
*device_mesh.mesh_dim_names, |
|
*flatten_name_to_root_dims, |
|
] |
|
|
|
if not all( |
|
mesh_dim_name in valid_mesh_dim_names |
|
for mesh_dim_name in mesh_dim_names |
|
): |
|
raise KeyError( |
|
f"Invalid mesh_dim_names {mesh_dim_names} specified. " |
|
f"Valid mesh_dim_names are {valid_mesh_dim_names}." |
|
) |
|
|
|
|
|
|
|
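            # Resolve each requested dim name to root mesh dim indices and validate that the
            # requested dims appear in ascending order of those indices.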
curr_idx = -1 |
|
slice_mesh_dims = [] |
|
for mesh_dim_name in mesh_dim_names: |
|
if mesh_dim_name in flatten_name_to_root_dims: |
|
mesh_indices = flatten_name_to_root_dims[mesh_dim_name] |
|
|
|
|
|
next_idx = mesh_indices[-1] |
|
slice_mesh_dims.append(mesh_indices) |
|
else: |
|
next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name) |
|
slice_mesh_dims.append((next_idx,)) |
|
if next_idx <= curr_idx: |
|
raise KeyError( |
|
f"Invalid mesh_dim_names {mesh_dim_names} specified. ", |
|
f"Found mesh dim indices to slice: {slice_mesh_dims}. ", |
|
"Mesh dim indices should be in ascending order.", |
|
) |
|
curr_idx = next_idx |
|
|
|
return slice_mesh_dims |
|
|
|
def _get_all_submeshes( |
|
self, device_mesh: "DeviceMesh", mesh_dim_name: str |
|
) -> list["DeviceMesh"]: |
|
""" |
|
Return all the submeshes of a given mesh dimension of the device mesh. |
|
""" |
|
mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name) |
|
pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( |
|
-1, device_mesh.mesh.size(mesh_dim) |
|
) |
|
|
|
cur_rank = device_mesh.get_rank() |
|
res_submeshes = [] |
|
for mesh_1d in pg_ranks_by_dim: |
|
submesh = DeviceMesh( |
|
device_mesh.device_type, |
|
mesh_1d, |
|
mesh_dim_names=(mesh_dim_name,), |
|
_init_backend=False, |
|
) |
|
submesh._dim_group_infos = ( |
|
[device_mesh._dim_group_infos[mesh_dim]] |
|
if cur_rank in mesh_1d |
|
else [] |
|
) |
|
res_submeshes.append(submesh) |
|
|
|
return res_submeshes |
|
|
|
_mesh_resources: _MeshEnv = _MeshEnv() |
|
|
|
def _get_device_handle(device_type: str = "cuda"): |
|
""" |
|
        Get the module corresponding to the device_type, which is cuda or a cuda-like device.
        For example, when the device_type is cuda, the module `torch.cuda` is returned.
        Return None when there is no corresponding module for device_type; otherwise,
        return the corresponding module.
|
""" |
|
return getattr(torch, device_type, None) |
|
|
|
class DeviceMesh: |
|
""" |
|
        DeviceMesh represents a mesh of devices, where the layout of devices can be
        represented as an n-dimensional array, and each value of the n-dimensional
        array is the global rank of a device in the default process group.

        DeviceMesh can be used to describe the layout of devices across the cluster,
        and serves as a proxy for communication among the device lists within the cluster.

        DeviceMesh can be used as a context manager.

        .. note::
            DeviceMesh follows the SPMD programming model, which means the same PyTorch Python program
            runs on all processes/ranks in the cluster. Therefore, users need to make sure the
            `mesh` array (which describes the layout of devices) is identical across all ranks.
            An inconsistent `mesh` will lead to a silent hang.
|
|
|
Args: |
|
device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". |
|
mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout |
|
of devices, where the IDs are global IDs of the default process group. |
|
|
|
Returns: |
|
DeviceMesh: A :class:`DeviceMesh` object representing the device layout. |
|
|
|
        The following program runs on each process/rank in an SPMD manner. In this example, we have 2
        hosts with 4 GPUs each.
        A reduction over the first dimension of the mesh will reduce across
        columns (0, 4), ..., (3, 7), while a reduction over the second dimension
        of the mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
|
|
|
Example:: |
|
>>> # xdoctest: +SKIP("no rank") |
|
>>> from torch.distributed.device_mesh import DeviceMesh |
|
>>> |
|
>>> # Initialize device mesh as (2, 4) to represent the topology |
|
>>> # of cross-host(dim 0), and within-host (dim 1). |
|
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) |
|
""" |
|
|
|
device_type: str |
|
mesh: torch.Tensor |
|
mesh_dim_names: Optional[tuple[str, ...]] |
|
|
|
def __init__( |
|
self, |
|
device_type: str, |
|
mesh: Union[torch.Tensor, "ArrayLike"], |
|
*, |
|
mesh_dim_names: Optional[tuple[str, ...]] = None, |
|
_init_backend: bool = True, |
|
) -> None: |
|
self.device_type = device_type |
|
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": |
|
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") |
|
self.mesh = ( |
|
mesh.detach().to(dtype=torch.int) |
|
if isinstance(mesh, torch.Tensor) |
|
else torch.tensor(mesh, device="cpu", dtype=torch.int) |
|
) |
|
self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None |
|
|
|
|
|
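            # Cache the flattened rank list and the creating thread id; both are part of
            # this DeviceMesh's hash and equality checks.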
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) |
|
self._thread_id = None |
|
|
|
|
|
|
|
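            # Process group setup and rank coordinates below are skipped entirely for the
            # "xla" device type.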
if device_type != "xla": |
|
|
|
|
|
|
|
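                # Always try to create the default (world) process group if requested, even
                # if it is not initialized already, then create the per-dim subgroups.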
if _init_backend: |
|
self._get_or_create_default_group() |
|
self._init_process_groups() |
|
|
|
if is_initialized() and get_backend() == "threaded": |
|
self._thread_id = threading.get_ident() |
|
|
|
|
|
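                # Calculate the coordinates of the current global rank on the mesh; an empty
                # result means this rank is not part of the mesh.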
rank_coords = (self.mesh == get_rank()).nonzero() |
|
assert rank_coords.size(0) in (0, 1) |
|
self._coordinate_on_dim: Optional[list[int]] = ( |
|
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None |
|
) |
|
|
|
def _get_or_create_default_group(self): |
|
default_initialized = is_initialized() |
|
if not default_initialized: |
|
init_process_group() |
|
|
|
world_size = get_world_size() |
|
if self.mesh.numel() > world_size: |
|
raise RuntimeError( |
|
f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" |
|
) |
|
|
|
device_handle = _get_device_handle(self.device_type) |
|
|
|
if not default_initialized and device_handle: |
|
|
|
|
|
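                # When this DeviceMesh created the default process group itself, pick the
                # device for this rank heuristically, assuming homogeneous hosts with the
                # same number of devices per host.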
num_devices_per_host = device_handle.device_count() |
|
if ( |
|
world_size > num_devices_per_host |
|
and world_size % num_devices_per_host != 0 |
|
): |
|
                    raise RuntimeError(
                        f"DeviceMesh only supports homogeneous hardware, but found "
                        f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
                    )
|
device_handle.set_device(get_rank() % num_devices_per_host) |
|
|
|
return _get_default_group() |
|
|
|
def _init_process_groups(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
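            # For each mesh dimension, record the (tag, ranks, group_name) of the process
            # group that the current rank belongs to along that dimension.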
dim_group_infos: list[tuple[str, list[int], str]] = [] |
|
default_group = _get_default_group() |
|
|
|
if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size(): |
|
|
|
|
|
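                # The mesh is 1D and covers the whole world: reuse the default process group,
                # unless CUDA is available and the default group is gloo-only, in which case a
                # new "cpu:gloo,cuda:nccl" group is created.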
ranks = list(range(get_world_size())) |
|
dim_group = ( |
|
new_group( |
|
backend="cpu:gloo,cuda:nccl", |
|
ranks=ranks, |
|
group_desc="mesh_default", |
|
) |
|
if torch.cuda.is_available() |
|
and get_backend(default_group) == "gloo" |
|
else default_group |
|
) |
|
dim_group_infos.append( |
|
( |
|
_get_group_tag(dim_group), |
|
ranks, |
|
dim_group.group_name, |
|
) |
|
) |
|
else: |
|
|
|
for dim in range(self.mesh.ndim): |
|
|
|
|
|
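                    # Swap the current dim to the last dim and reshape to flatten out the other
                    # dims, so that each row contains the ranks of one subgroup along this mesh
                    # dimension.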
pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( |
|
-1, self.mesh.size(dim) |
|
) |
|
|
|
|
|
|
|
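                    # Respect any backend / ProcessGroup options registered for this mesh
                    # dimension via _set_mesh_dim_group_options.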
if dim in _mesh_resources.mesh_dim_group_options: |
|
( |
|
backend, |
|
pg_options, |
|
) = _mesh_resources.mesh_dim_group_options[dim] |
|
else: |
|
backend, pg_options = None, None |
|
|
|
|
|
|
|
|
|
|
|
group_desc = ( |
|
f"mesh_{self.mesh_dim_names[dim]}" |
|
if self.mesh_dim_names |
|
else f"mesh_dim_{dim}" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
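                    # If the default process group is bound to a device and CUDA is available,
                    # the communicator can be split directly with split_group, which creates all
                    # subgroups for this dimension in one call.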
dim_group = None |
|
has_split_group = False |
|
if ( |
|
bound_device_id := getattr( |
|
default_group, "bound_device_id", None |
|
) |
|
) is not None and torch.cuda.is_available(): |
|
dim_group = split_group( |
|
parent_pg=default_group, |
|
pg_options=pg_options, |
|
split_ranks=pg_ranks_by_dim.tolist(), |
|
group_desc=group_desc, |
|
) |
|
has_split_group = True |
|
|
|
|
|
|
|
|
|
|
|
|
|
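                    # Otherwise, fall back to creating a new group per subgroup of ranks. Either
                    # way, record the group info for the subgroup containing the current rank.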
for dim_mesh in pg_ranks_by_dim: |
|
subgroup_ranks = dim_mesh.tolist() |
|
|
|
|
|
|
|
|
|
if bound_device_id is None or not has_split_group: |
|
dim_group = new_group( |
|
ranks=subgroup_ranks, |
|
backend=backend, |
|
pg_options=pg_options, |
|
group_desc=group_desc, |
|
) |
|
|
|
|
|
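                        # Only record the subgroup that the current rank is part of.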
if self.get_rank() in subgroup_ranks: |
|
if len(dim_group_infos) > dim: |
|
raise RuntimeError( |
|
f"Each device mesh dimension should get only one process group, but got {self.get_rank()} " |
|
f"in {subgroup_ranks}!" |
|
) |
|
dim_group_infos.append( |
|
( |
|
_get_group_tag(not_none(dim_group)), |
|
subgroup_ranks, |
|
dim_group.group_name, |
|
) |
|
) |
|
self._dim_group_infos = dim_group_infos |
|
|
|
def __enter__(self) -> "DeviceMesh": |
|
|
|
_mesh_resources.mesh_stack.append(self) |
|
return self |
|
|
|
|
|
def __exit__(self, exc_type, exc_value, exc_traceback) -> None: |
|
|
|
_mesh_resources.mesh_stack.pop() |
|
|
|
def __repr__(self) -> str: |
|
device_mesh_repr = ( |
|
f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})" |
|
if not self.mesh_dim_names |
|
else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})" |
|
) |
|
return device_mesh_repr |
|
|
|
def __hash__(self): |
|
|
|
self._hash = getattr(self, "_hash", None) |
|
if not self._hash: |
|
self._hash = hash( |
|
( |
|
self._flatten_mesh_list, |
|
self.mesh.shape, |
|
self.device_type, |
|
self.mesh_dim_names, |
|
self._thread_id, |
|
) |
|
) |
|
return self._hash |
|
|
|
def __eq__(self, other: object) -> bool: |
|
if not isinstance(other, DeviceMesh): |
|
return False |
|
if id(self) == id(other): |
|
return True |
|
else: |
|
return ( |
|
self._flatten_mesh_list == other._flatten_mesh_list |
|
and self.mesh.shape == other.mesh.shape |
|
and self.device_type == other.device_type |
|
and self.mesh_dim_names == other.mesh_dim_names |
|
and self._thread_id == other._thread_id |
|
) |
|
|
|
def __getitem__( |
|
self, mesh_dim_names: Union[str, tuple[str, ...]] |
|
) -> "DeviceMesh": |
|
""" |
|
Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. |
|
The submesh created consists of the dimensions and the communicators indicated by |
|
``mesh_dim_names`` |
|
|
|
Args: |
|
mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the |
|
mesh dimension of the DeviceMesh to create the submesh for. |
|
Returns: |
|
A :class:`DeviceMesh` object |
|
|
|
The following program runs on each process/rank in an SPMD manner in a world size of 8. |
|
In the first example: |
|
Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]). |
|
Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]). |
|
Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]). |
|
Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]). |
|
Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]). |
|
Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]). |
|
|
|
In the second example: |
|
Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]). |
|
Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]). |
|
Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]). |
|
Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]). |
|
|
|
Example:: |
|
>>> # xdoctest: +SKIP("no rank") |
|
>>> from torch.distributed.device_mesh import DeviceMesh |
|
>>> |
|
>>> # Initialize a 2D device mesh as (2, 4) to represent the topology |
|
>>> # of cross-host(dim 0), and within-host (dim 1). |
|
                >>> mesh_2d = init_device_mesh(device_type="cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))
|
>>> tp_mesh = mesh_2d["tp"] |
|
>>> dp_mesh = mesh_2d["dp"] |
|
>>> |
|
>>> # Initialize a 3D mesh. |
|
                >>> mesh_3d = init_device_mesh(device_type="cuda", mesh_shape=(2, 2, 2), mesh_dim_names=("dp", "pp", "cp"))
                >>> # The order of the mesh_dim_names provided determines the order of dimensions in the submesh.
|
>>> dp_cp_mesh = mesh_3d["dp", "cp"] |
|
>>> cp_dp_mesh = mesh_3d["cp", "dp"] |
|
""" |
|
if not self.mesh_dim_names: |
|
raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") |
|
|
|
mesh_dim_names = ( |
|
(mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names |
|
) |
|
|
|
if mesh_dim_names == self.mesh_dim_names: |
|
return self |
|
else: |
|
slice_mesh_dims = _mesh_resources._get_slice_mesh_dims( |
|
self, mesh_dim_names |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
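                # Creating the submesh requires manipulating a real mesh tensor, so temporarily
                # disable fake tensor mode (e.g. when tracing the model under FakeTensorMode)
                # while slicing.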
with torch._subclasses.fake_tensor.unset_fake_temporarily(): |
|
submesh = _mesh_resources.create_sub_mesh( |
|
self, mesh_dim_names, slice_mesh_dims |
|
) |
|
return submesh |
|
|
|
def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: |
|
""" |
|
Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the |
|
DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. |
|
|
|
Args: |
|
mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index |
|
of the mesh dimension. Default is None. |
|
|
|
Returns: |
|
A :class:`ProcessGroup` object. |
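
            Example (illustrative; assumes a world size of 8 with a 2x4 mesh)::
                >>> # xdoctest: +SKIP("no rank")
                >>> from torch.distributed.device_mesh import init_device_mesh
                >>> mesh_2d = init_device_mesh(device_type="cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))
                >>> tp_group = mesh_2d.get_group("tp")
                >>> dp_group = mesh_2d.get_group(0)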
|
""" |
|
if not hasattr(self, "_dim_group_infos"): |
|
raise RuntimeError("DeviceMesh process groups not initialized!") |
|
|
|
if self.mesh.ndim > 1 and mesh_dim is None: |
|
                raise RuntimeError(
                    f"Found that the DeviceMesh has {self.mesh.ndim} dimensions. "
                    "The optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1. "
                    "If you want to get the list of all the ProcessGroups in the DeviceMesh, "
                    "please use `get_all_groups()` instead."
                )
|
|
|
|
|
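            # Quick path: a 1D mesh with no mesh_dim specified has exactly one dim group.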
if self.mesh.ndim == 1 and mesh_dim is None: |
|
return not_none( |
|
_find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2]) |
|
) |
|
|
|
root_mesh = _mesh_resources.get_root_mesh(self) |
|
root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get( |
|
root_mesh, None |
|
) |
|
if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): |
|
dim_group_infos = root_to_flatten_mapping[ |
|
mesh_dim |
|
]._dim_group_infos[0][:2] |
|
return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos)) |
|
else: |
|
mesh_dim = ( |
|
_mesh_resources.get_mesh_dim_by_name(self, mesh_dim) |
|
if isinstance(mesh_dim, str) |
|
else mesh_dim |
|
) |
|
return not_none( |
|
_find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) |
|
) |
|
|
|
def get_all_groups(self) -> list[ProcessGroup]: |
|
""" |
|
Returns a list of ProcessGroups for all mesh dimensions. |
|
|
|
Returns: |
|
                A list of :class:`ProcessGroup` objects.
|
""" |
|
return [self.get_group(i) for i in range(self.mesh.ndim)] |
|
|
|
@staticmethod |
|
def from_group( |
|
group: Union[ProcessGroup, list[ProcessGroup]], |
|
device_type: str, |
|
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, |
|
*, |
|
mesh_dim_names: Optional[tuple[str, ...]] = None, |
|
) -> "DeviceMesh": |
|
""" |
|
Constructs a :class:`DeviceMesh` with ``device_type`` from an |
|
existing :class:`ProcessGroup` or a list of existing :class:`ProcessGroup`. |
|
|
|
            The constructed device mesh has a number of dimensions equal to the
            number of groups passed. For example, if a single process group is passed in,
            the resulting DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in,
            the resulting DeviceMesh is a 2D mesh.
|
|
|
If more than one group is passed, then the ``mesh`` and ``mesh_dim_names`` arguments |
|
are required. The order of the process groups passed in determines the topology of |
|
the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. |
|
The `mesh` tensor passed in must have the same number of dimensions as the number of process |
|
groups passed in, and the order of the dimensions in the `mesh` tensor must match the order |
|
in the process groups passed in. |
|
|
|
Args: |
|
group (ProcessGroup or list[ProcessGroup]): the existing ProcessGroup |
|
or a list of existing ProcessGroups. |
|
device_type (str): The device type of the mesh. Currently supports: "cpu", |
|
"cuda/cuda-like". Passing in a device type with a GPU index, such as "cuda:0", |
|
is not allowed. |
|
mesh (torch.Tensor or ArrayLike, optional): A multi-dimensional array or an |
|
integer tensor describing the layout of devices, where the IDs are global IDs |
|
of the default process group. Default is None. |
|
mesh_dim_names (tuple[str], optional): A tuple of mesh dimension names to assign |
|
to each dimension of the multi-dimensional array describing the layout of devices. |
|
                    Its length must match the number of dimensions of `mesh`. Each string in `mesh_dim_names`
|
must be unique. Default is None. |
|
|
|
Returns: |
|
DeviceMesh: A :class:`DeviceMesh` object representing the device layout. |
|
""" |
|
|
|
|
|
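            # 1D case: a single ProcessGroup maps directly to a 1D DeviceMesh whose mesh is
            # the group's ranks.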
if isinstance(group, ProcessGroup): |
|
group_ranks = get_process_group_ranks(group) |
|
if ( |
|
isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks |
|
) or ( |
|
mesh is not None |
|
and not isinstance(mesh, torch.Tensor) |
|
and mesh != group_ranks |
|
): |
|
raise ValueError( |
|
f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" |
|
) |
|
mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int) |
|
device_mesh = DeviceMesh( |
|
device_type, |
|
mesh, |
|
mesh_dim_names=mesh_dim_names, |
|
_init_backend=False, |
|
) |
|
device_mesh._dim_group_infos = [ |
|
(_get_group_tag(group), group_ranks, group.group_name) |
|
] |
|
return device_mesh |
|
|
|
|
|
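            # nD case: one ProcessGroup per mesh dimension; `mesh` and `mesh_dim_names` are
            # required to describe the layout.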
groups = list(group) |
|
if len(groups) == 0: |
|
raise ValueError("Expects at least one ProcessGroup to be passed") |
|
if mesh is None: |
|
raise ValueError("Must pass mesh if passing multiple ProcessGroups") |
|
if mesh_dim_names is None: |
|
raise ValueError( |
|
"Must pass mesh_dim_names if passing multiple ProcessGroups" |
|
) |
|
mesh = ( |
|
mesh.detach().to(dtype=torch.int, device="cpu") |
|
if isinstance(mesh, torch.Tensor) |
|
else torch.tensor(mesh, device="cpu", dtype=torch.int) |
|
) |
|
if mesh.ndim != len(groups): |
|
raise ValueError( |
|
"Expects mesh with ndim equal to number of ProcessGroups but got " |
|
f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" |
|
) |
|
device_mesh = DeviceMesh( |
|
device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False |
|
) |
|
device_mesh._dim_group_infos = [ |
|
( |
|
_get_group_tag(group), |
|
get_process_group_ranks(group), |
|
group.group_name, |
|
) |
|
for group in groups |
|
] |
|
return device_mesh |
|
|
|
def size(self, mesh_dim: Optional[int] = None) -> int: |
|
return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) |
|
|
|
@property |
|
def ndim(self) -> int: |
|
return self.mesh.ndim |
|
|
|
@property |
|
def shape(self) -> tuple[int, ...]: |
|
return tuple(self.mesh.shape) |
|
|
|
def get_rank(self) -> int: |
|
""" |
|
Returns the current global rank. |
|
""" |
|
return get_rank() |
|
|
|
def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: |
|
""" |
|
Returns the local rank of the given mesh_dim of the DeviceMesh. |
|
|
|
Args: |
|
mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index |
|
of the mesh dimension. Default is None. |
|
|
|
Returns: |
|
                An integer denoting the local rank.
|
|
|
The following program runs on each process/rank in an SPMD manner. In this example, we have 2 |
|
hosts with 4 GPUs each. |
|
Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. |
|
Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. |
|
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. |
|
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. |
|
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. |
|
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. |
|
|
|
Example:: |
|
>>> # xdoctest: +SKIP("no rank") |
|
>>> from torch.distributed.device_mesh import DeviceMesh |
|
>>> |
|
>>> # Initialize device mesh as (2, 4) to represent the topology |
|
>>> # of cross-host(dim 0), and within-host (dim 1). |
|
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) |
|
""" |
|
if self.ndim > 1 and mesh_dim is None: |
|
                raise RuntimeError(
                    f"Found that the DeviceMesh has {self.mesh.ndim} dimensions. "
                    "The optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
                )
|
elif mesh_dim is None: |
|
mesh_dim = 0 |
|
|
|
mesh_dim_group = not_none(self.get_group(mesh_dim)) |
|
assert isinstance(mesh_dim_group, ProcessGroup), ( |
|
"We expect ProcessGroup before calling `get_rank`!" |
|
) |
|
return not_none(get_rank(mesh_dim_group)) |
|
|
|
def get_coordinate(self) -> Optional[list[int]]: |
|
""" |
|
            Return the coordinates of this rank with respect to all
            dimensions of the mesh. If this rank is not part of the mesh, return None.
|
""" |
|
return self._coordinate_on_dim if self._coordinate_on_dim else None |
|
|
|
def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": |
|
""" |
|
Returns a 1D DeviceMesh by flattening the current DeviceMesh. |
|
|
|
            If no mesh_dim_name is provided, the default is a string concatenating the mesh_dim_names of the
            given submesh, with each mesh_dim_name separated by "_". For example, if we have a 3D mesh
|
DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling |
|
mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 1, 2, 3], mesh_dim_names=("dp_cp",)) |
|
on rank 0, 1, 2, 3 and a 1D submesh DeviceMesh([4, 5, 6, 7], mesh_dim_names=("dp_cp",)) on rank 4, 5, 6, 7. |
|
|
|
            After the flattened dimension is created, to access the flattened dimension in mesh_3d, one can use the
            existing slicing method to obtain the flattened mesh by calling mesh_3d["dp_cp"].
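
            Example (illustrative; assumes a world size of 8 with a 2x2x2 mesh)::
                >>> # xdoctest: +SKIP("no rank")
                >>> from torch.distributed.device_mesh import init_device_mesh
                >>> mesh_3d = init_device_mesh(device_type="cuda", mesh_shape=(2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))
                >>> dp_cp_mesh = mesh_3d["dp", "cp"]._flatten()
                >>> # The flattened dimension can then be accessed by name:
                >>> dp_cp_mesh = mesh_3d["dp_cp"]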
|
""" |
|
if not self.mesh_dim_names: |
|
raise RuntimeError( |
|
"Cannot flatten a DeviceMesh without mesh_dim_names!" |
|
) |
|
|
|
return _mesh_resources.create_flatten_mesh(self, mesh_dim_name) |
|
|
|
def init_device_mesh( |
|
device_type: str, |
|
mesh_shape: tuple[int, ...], |
|
*, |
|
mesh_dim_names: Optional[tuple[str, ...]] = None, |
|
) -> DeviceMesh: |
|
""" |
|
Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. |
|
|
|
This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`. |
|
If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`. |
|
|
|
.. note:: |
|
            `init_device_mesh` follows the SPMD programming model, meaning the same PyTorch Python program
            runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
            describing the device layout) is identical across all ranks. An inconsistent `mesh_shape` may lead to hanging.

        .. note::
            If no process group is found, init_device_mesh will initialize the distributed process group(s)
            required for distributed communication behind the scenes.
|
|
|
Args: |
|
device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". |
|
Passing in a device type with a GPU index, such as "cuda:0", is not allowed. |
|
mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array |
|
describing the layout of devices. |
|
mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension |
|
of the multi-dimensional array describing the layout of devices. Its length must match the length |
|
of `mesh_shape`. Each string in `mesh_dim_names` must be unique. |
|
|
|
Returns: |
|
DeviceMesh: A :class:`DeviceMesh` object representing the device layout. |
|
|
|
Example:: |
|
>>> # xdoctest: +SKIP("no rank") |
|
>>> from torch.distributed.device_mesh import init_device_mesh |
|
>>> |
|
>>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) |
|
>>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) |
|
|
|
""" |
|
if mesh_dim_names is not None: |
|
if len(set(mesh_dim_names)) != len(mesh_dim_names): |
|
                raise RuntimeError(
                    "Each mesh_dim_name must be unique. "
                    f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}."
                )

            if len(mesh_shape) != len(mesh_dim_names):
                raise RuntimeError(
                    "mesh_shape and mesh_dim_names should have the same length! "
                    f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape): {len(mesh_shape)}."
                )
|
|
|
|
|
        if device_type and not device_type.isalpha():
            raise RuntimeError(
                f"Device type with an index is not supported but got {device_type}. "
                "If you have a 'torch.device' object, it's recommended to pass in 'device.type'."
            )
|
|
|
|
|
|
|
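        # Always create the mesh tensor on CPU, regardless of the device_type of the mesh.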
with torch.device("cpu"): |
|
mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape) |
|
device_mesh = DeviceMesh( |
|
device_type=device_type, |
|
mesh=mesh, |
|
mesh_dim_names=mesh_dim_names, |
|
) |
|
|
|
return device_mesh |
|
|