"""Useful functions for parallel training.""" |

import sys
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Sequence

import torch
import torch.distributed as dist
import torch.nn.functional as F

if TYPE_CHECKING:
    from torch.distributed import ProcessGroup


def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
    """Chunk a given Tensor into num_chunks parts and add any necessary padding."""
    chunks = list(torch.flatten(tensor).chunk(num_chunks))
    # The last chunk may be smaller than the rest; zero-pad it to the full chunk size.
    num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
    if num_pad_for_partial_chunk > 0:
        chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
    # chunk() can return fewer than num_chunks chunks; pad the list with zero tensors.
    if len(chunks) < num_chunks:
        chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
    return chunks
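
# Usage sketch (illustrative only, not part of the module's API; values assume a
# 10-element input tensor):
#
#   >>> chunks = chunk_and_pad(torch.arange(10.0), 4)
#   >>> [c.tolist() for c in chunks]
#   [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 0.0, 0.0]]
#
# Every chunk ends up with the same numel, which is what collectives such as
# reduce_scatter expect.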


def validate_process_group(device: torch.device, process_group: "ProcessGroup") -> None:
    """Do a quick test in case the user called FSDP without calling torch.cuda.set_device()
    correctly. This can easily happen in the cpu_offload case, where the model resides on
    the CPU.
    """
    if not hasattr(process_group, "allgather"):
        # Likely a dummy process group (e.g. in unit tests); nothing to validate.
        return

    world_size = process_group.size()
    if "cuda" in str(device):
        # All-gather a single one from every rank. If a rank is using the wrong CUDA
        # device, the gathered values (and hence the sum) will not match world_size.
        input_tensor = torch.ones(1).to(device)
        output = list(torch.zeros(world_size).to(device).chunk(world_size))
        dist.all_gather(output, input_tensor, group=process_group)
        assert torch.cat(output).sum() == float(world_size), (
            f"found {torch.cat(output).sum()} devices in process group but "
            f"world_size={world_size}. Check torch.cuda.set_device is called properly"
        )
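
# Typical call site (a hedged sketch; `pg` is assumed to be an already-initialized
# process group and the device should match torch.cuda.current_device()):
#
#   >>> device = torch.device("cuda", torch.cuda.current_device())
#   >>> validate_process_group(device, pg)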


def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
    """Call _specify_ddp_gpu_num for all pytorch SyncBN layers so that they
    run correctly even without DDP. For example, this is used by FSDP.
    """
    for layer in module.modules():
        if isinstance(layer, torch.nn.modules.SyncBatchNorm) and hasattr(layer, "_specify_ddp_gpu_num"):
            # The "1" is the number of GPUs per DDP process. This call is normally made
            # by DDP; newer PyTorch versions no longer have the method, hence the
            # hasattr check above.
            layer._specify_ddp_gpu_num(1)
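
# Usage sketch (hedged; assumes the model's BatchNorm layers have first been converted
# to SyncBatchNorm, e.g. with torch.nn.SyncBatchNorm.convert_sync_batchnorm, before
# wrapping with FSDP):
#
#   >>> model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
#   >>> enable_pytorch_sync_bn(model)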


class ProcessGroupName(str, Enum):
    default = "default"
    reduce_scatter = "reduce_scatter"


def get_process_group_cached(
    name: ProcessGroupName = ProcessGroupName.default, ranks: Optional[Sequence[int]] = None
) -> "ProcessGroup":
    """
    Singleton PyTorch distributed group cache. Inspired by the code from fairseq.

    Just like torch.distributed.new_group, this method needs to be called on all ranks
    at the same time when a new group is created. This is true for all ranks irrespective
    of their group membership status.

    For FSDP, it is important to use the same group between outer and inner FSDP instances,
    otherwise, inner FSDP instances will not share the gradient reduction bucket buffer with
    the root instance. This will result in increased GPU memory utilization.

    Each separate process group also uses a separate NCCL library instance, which will have
    a significant effect on GPU memory use if too many process groups are created and used.
    Setting the NCCL_BUFFSIZE=102400 env variable is a useful technique to check whether NCCL
    memory is causing GPU OOM. Note that the NCCL buffers are not allocated through the
    PyTorch caching allocator, therefore you may see GPU OOM even when
    torch.cuda.reserved_memory() is still way below the total amount of GPU memory.

    Extra process groups can also reduce training speed (observed on VISSL models).

    Args:
        name (ProcessGroupName):
            There are two process groups when reduce_scatter overlap is enabled. The "default"
            process group is the default process group. The other group is the "reduce_scatter"
            group.
            Default: ProcessGroupName.default
        ranks (Optional[List[int]]):
            Ranks requested in the target group. None for all ranks.
            Default: None

    Returns:
        (ProcessGroup):
            The requested process group. Raises RuntimeError if the torch.distributed module
            is not yet initialized.
    """
    if not dist.is_initialized():
        # Likely a dummy process group requested from a unit test; return None in that case.
        if name == ProcessGroupName.reduce_scatter and "pytest" in sys.modules:
            return None
        else:
            raise RuntimeError("torch.distributed is not yet initialized but process group is requested.")

    # Init the cache if needed.
    if not hasattr(get_process_group_cached, "_global_group_cache"):
        get_process_group_cached._global_group_cache = {}
        # Populate it with the default process group.
        cache = get_process_group_cached._global_group_cache
        default_pg = dist.new_group(ranks=ranks)
        cache[None] = default_pg
        cache[(ProcessGroupName.default, None)] = default_pg
        cache[(ProcessGroupName.default, frozenset(list(range(dist.get_world_size()))))] = default_pg

    # Look up the cache and fill it if needed.
    cache = get_process_group_cached._global_group_cache
    if ranks is not None:
        # Take care of ordering and duplicates in the ranks list; use a tuple so that
        # ranks can be used as a cache key.
        ranks = tuple(sorted(list(set(ranks))))
    if (name, ranks) not in cache:
        cache[(name, ranks)] = dist.new_group(ranks=ranks)

    return cache[(name, ranks)]
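
# Usage sketch (hedged; assumes torch.distributed has already been initialized, e.g. via
# dist.init_process_group, and that the calls are made on all ranks at the same time):
#
#   >>> pg_default = get_process_group_cached()
#   >>> pg_rs = get_process_group_cached(ProcessGroupName.reduce_scatter)
#   >>> get_process_group_cached() is pg_default  # cached: the same group object is reused
#   True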