# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
import torch.distributed as dist

if TYPE_CHECKING:
    from torch.distributed import ProcessGroup

# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
    enable_nccl_base_collectives = False
else:
    enable_nccl_base_collectives = True


class Bucket:
    def __init__(self, data: Tensor, group: "ProcessGroup"):
        self.data = data
        self.group = group
        self.offset = 0
        self.callbacks: List[Callable] = []
        self.output_shard = torch.zeros_like(data[0])

    def flush(self) -> None:
        """Flush content of the bucket."""
        if self.offset == 0:
            assert len(self.callbacks) == 0
            return
        # reduce-scatter bucket
        if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
            dist._reduce_scatter_base(
                self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group
            )
        else:
            dist.reduce_scatter(
                self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
            )
        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()
        # reuse input bucket but allocate a fresh output shard
        self.data[:, : self.offset].zero_()
        self.offset = 0
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.data[0])

    def setup(self) -> None:
        """Set up the buffers if they are not allocated.

        Using ``setup`` and ``teardown``, we can ensure that the bucket
        buffers are only allocated during the backward pass, hence saving
        memory for other parts of the training process, such as the forward
        pass and its activation memory.
        """
        for tensor in [self.data, self.output_shard]:
            if tensor.storage().size() == 0:
                tensor.storage().resize_(tensor.size().numel())

    def teardown(self) -> None:
        """Tear down the bucket by freeing its memory."""
        assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
        for tensor in [self.data, self.output_shard]:
            tensor.storage().resize_(0)


class ReduceScatterBucketer:
    """
    Helper for bucketing multiple reduce-scatter operations on small tensors
    into larger reduce-scatter ops to improve communication efficiency.

    Usage::

        bucketer = ReduceScatterBucketer()
        bucketer.reduce_scatter_async(
            small_tensors, group, callback_fn=lambda result: print("small")
        )
        bucketer.reduce_scatter_async(
            big_tensors, group, callback_fn=lambda result: print("big")
        )
        bucketer.reduce_scatter_async(
            more_small_tensors, group, callback_fn=lambda result: print("small2")
        )
        bucketer.flush()  # callbacks only guaranteed to be called after flush()
        # Example output (note that it is out of order, due to bucketing):
        # big
        # small
        # small2

    Args:
        bucket_cap_mb (int, Optional): bucket size (in MB) used for communication.
            Buckets are sub-divided based on world_size. Values <= 0 disable
            bucketing.
    """

    def __init__(self, bucket_cap_mb: int = 25):
        self.bucket_cap_mb = bucket_cap_mb
        self.buckets: Dict[Tuple[torch.dtype, torch.device, "ProcessGroup"], Bucket] = {}

    @torch.no_grad()
    def reduce_scatter_async(
        self,
        input_list: List[Tensor],
        group: "ProcessGroup",
        callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()

        assert (
            len(input_list) == world_size
        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
        if first_input_size > bucket_shard_size:
            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
            if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened, group=group)
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.data.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()

        # copy data from input_list into bucket
        stacked_input = torch.stack(input_list).view(world_size, first_input_size)
        offset = bucket.offset
        bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
        bucket.offset += first_input_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input)
            bucket.callbacks.append(functools.partial(callback_fn, result_view))

    @torch.no_grad()
    def flush(self) -> None:
        """Reduce-scatter any partial buckets."""
        for bucket in self.buckets.values():
            bucket.flush()

    @torch.no_grad()
    def teardown(self) -> None:
        """Free buffers from all buckets."""
        for bucket in self.buckets.values():
            bucket.teardown()

    @functools.lru_cache()
    def _get_shard_size(self, element_size: int, num_shards: int) -> int:
        if self.bucket_cap_mb <= 0:  # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_cap_mb * MB / element_size
        return int(bucket_size // num_shards)
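
    # Worked example for ``_get_shard_size`` (illustrative numbers): with the
    # default bucket_cap_mb=25, fp16 elements (element_size == 2) and
    # num_shards == 8, it returns 25 * 1024 * 1024 / 2 // 8 == 1_638_400
    # elements per shard, so the full (world_size, shard_size) bucket spans the
    # requested 25 MiB.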

    def _get_bucket(self, tensor: Tensor, group: "ProcessGroup") -> Bucket:
        # TODO (Min): the `group` used here in the key is the object hash, not the content
        # hash. That means if FSDP instances are initialized with different process groups,
        # even when the group members are in fact the same, we end up creating different
        # buckets here.
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
            world_size = group.size()
            shard_size = self._get_shard_size(tensor.element_size(), world_size)
            data = tensor.new_zeros((world_size, shard_size))
            self.buckets[key] = Bucket(data, group)
        self.buckets[key].setup()
        return self.buckets[key]
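

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical helper, not part of the public API). It
# assumes the default process group has already been initialized, e.g. via
# torch.distributed.init_process_group(...) under a multi-process launcher,
# and that every rank calls it with the same sequence of shard lists, each
# inner list holding ``group.size()`` same-shaped tensors.
# ---------------------------------------------------------------------------
def _example_reduce_shards(shard_lists: List[List[Tensor]]) -> List[Tensor]:
    group = dist.group.WORLD
    bucketer = ReduceScatterBucketer(bucket_cap_mb=25)
    results: List[Tensor] = []
    for shards in shard_lists:
        # Small inputs are queued into a bucket; large ones are reduced immediately.
        bucketer.reduce_scatter_async(shards, group, callback_fn=results.append)
    bucketer.flush()  # run any still-bucketed reductions and their callbacks
    bucketer.teardown()  # free the bucket buffers
    return results  # one reduced shard per call, possibly out of input order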