# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Useful functions for parallel training."""
from enum import Enum
import sys
from typing import TYPE_CHECKING, List, Optional, Sequence
import torch
import torch.distributed as dist
import torch.nn.functional as F
if TYPE_CHECKING:
    # See comments in the FSDP code for the reason behind this import.
    from torch.distributed import ProcessGroup

def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
    """Chunk a given Tensor into num_chunks parts and add any necessary padding."""
    chunks = list(torch.flatten(tensor).chunk(num_chunks))
    # torch.chunk may return fewer than num_chunks chunks, pad accordingly.
    num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
    if num_pad_for_partial_chunk > 0:
        chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
    if len(chunks) < num_chunks:
        chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
    return chunks

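# A minimal usage sketch for chunk_and_pad (the tensor size and chunk count below are
# made up for illustration). Five elements do not split evenly into two chunks, so the
# last chunk is zero-padded until every chunk holds the same number of elements.
def _example_chunk_and_pad() -> None:
    t = torch.arange(5, dtype=torch.float32)
    chunks = chunk_and_pad(t, 2)
    # torch.chunk returns pieces of 3 and 2 elements; the second is padded with one zero.
    assert len(chunks) == 2
    assert chunks[0].numel() == chunks[1].numel() == 3
    assert chunks[1][-1] == 0.0
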
def validate_process_group(device: torch.device, process_group: "ProcessGroup") -> None:
    """Do a quick test in case the user called FSDP without calling torch.cuda.set_device()
    correctly. This can easily happen in the cpu_offload case, where the model resides on
    the CPU.
    """
    if not hasattr(process_group, "allgather"):
        # Likely a dummy pg for unit test, skip checking.
        return

    world_size = process_group.size()
    if "cuda" in str(device):
        input_tensor = torch.ones(1).to(device)
        output = list(torch.zeros(world_size).to(device).chunk(world_size))
        dist.all_gather(output, input_tensor, group=process_group)
        assert torch.cat(output).sum() == float(world_size), (
            f"found {torch.cat(output).sum()} devices in process group but "
            f"world_size={world_size}. Check that torch.cuda.set_device() is called properly"
        )

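# A minimal sketch of how validate_process_group might be called (assuming
# torch.distributed has been initialized with the NCCL backend and that this rank owns
# one GPU). The helper name and the rank-to-device mapping are illustrative assumptions.
def _example_validate_process_group() -> None:
    if dist.is_initialized() and torch.cuda.is_available():
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
        device = torch.device("cuda", torch.cuda.current_device())
        validate_process_group(device, dist.new_group())
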
def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
    """Call _specify_ddp_gpu_num for all PyTorch SyncBN layers so that they
    run happily even without DDP. E.g. this is used by FSDP.
    """
    for layer in module.modules():
        if isinstance(layer, torch.nn.modules.SyncBatchNorm) and hasattr(layer, "_specify_ddp_gpu_num"):
            # The number "1" below is meant to be the number of GPUs for each DDP worker
            # (i.e. "device_ids" in DDP). As far as I can see, the value is not actually
            # used, but this call needs to be made to avoid an exception.
            # This function was removed from PyTorch in 1.9.
            layer._specify_ddp_gpu_num(1)  # type: ignore

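# A minimal sketch showing enable_pytorch_sync_bn on a toy model (the layer sizes are
# arbitrary). BatchNorm layers are first converted to SyncBatchNorm and then patched so
# that they can run under FSDP, i.e. without DDP, on PyTorch versions before 1.9.
def _example_enable_pytorch_sync_bn() -> None:
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.BatchNorm2d(8))
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    enable_pytorch_sync_bn(model)
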
class ProcessGroupName(str, Enum):
    default = "default"
    reduce_scatter = "reduce_scatter"

def get_process_group_cached(
    name: ProcessGroupName = ProcessGroupName.default, ranks: Optional[Sequence[int]] = None
) -> "ProcessGroup":
    """
    Singleton PyTorch distributed group cache. Inspired by the code from fairseq.

    Just like torch.distributed.new_group, this method needs to be called on all ranks
    at the same time when a new group is created. This is true for all ranks irrespective
    of their group membership status.

    For FSDP, it is important to use the same group between outer and inner FSDP instances;
    otherwise, inner FSDP instances will not share the gradient reduction bucket buffer with
    the root instance, which results in increased GPU memory utilization.

    Each separate process group also uses a separate NCCL library instance, which will have
    a significant effect on GPU memory use if too many process groups are created and used.
    Setting the NCCL_BUFFSIZE=102400 env variable is a useful technique to check whether
    NCCL memory is causing GPU OOM. Note that the NCCL buffers are not allocated
    through the PyTorch caching allocator; therefore, you may see GPU OOM even when
    torch.cuda.memory_reserved() is still way below the total amount of GPU memory.

    Extra process groups can also reduce training speed (observed on VISSL models).

    Args:
        name (ProcessGroupName):
            There are two process groups when reduce_scatter overlap is enabled. The "default"
            group is the regular, default process group; the other is the "reduce_scatter" group.
            Default: ProcessGroupName.default
        ranks (Optional[List[int]]):
            Ranks requested in the target group. None for all ranks.
            Default: None

    Returns:
        (ProcessGroup):
            Return the requested process group. Throws RuntimeError if the torch.distributed
            module is not yet initialized.
    """
    if not dist.is_initialized():
        # Likely caused by initiating a dummy pg for unit test, skip checking.
        if name == ProcessGroupName.reduce_scatter and "pytest" in sys.modules:
            return None
        else:
            raise RuntimeError("torch.distributed is not yet initialized but process group is requested.")

    # Init the cache if needed.
    if not hasattr(get_process_group_cached, "_global_group_cache"):
        get_process_group_cached._global_group_cache = {}  # type: ignore
        # Populate with the default process group.
        cache = get_process_group_cached._global_group_cache  # type: ignore
        default_pg = dist.new_group(ranks=ranks)
        cache[None] = default_pg
        cache[(ProcessGroupName.default, None)] = default_pg
        cache[(ProcessGroupName.default, frozenset(list(range(dist.get_world_size()))))] = default_pg

    # Lookup and fill the cache if needed.
    cache = get_process_group_cached._global_group_cache  # type: ignore
    if ranks is not None:
        # Take care of ordering and duplicates in the ranks list. Use a tuple so that ranks
        # can be used as a cache index.
        ranks = tuple(sorted(list(set(ranks))))
    if (name, ranks) not in cache:
        cache[(name, ranks)] = dist.new_group(ranks=ranks)
    return cache[(name, ranks)]

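# A minimal sketch of the caching behavior (assuming torch.distributed has been
# initialized on every rank): repeated calls with the same name and ranks return the
# same cached ProcessGroup instead of creating a new NCCL communicator each time.
def _example_get_process_group_cached() -> None:
    if dist.is_initialized():
        pg_a = get_process_group_cached()
        pg_b = get_process_group_cached(ProcessGroupName.default)
        assert pg_a is pg_b  # same cached group, no extra NCCL instance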