"""The Pipe interface.""" |
|
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
import warnings

import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda

from fairscale.internal import torch_version

from . import microbatch
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .stream import AbstractStream, new_stream

__all__ = ["Pipe"]
|
|
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]

Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]

if TYPE_CHECKING:
    Module = nn.Module[TensorOrTensors]
    NamedModules = OrderedDict[str, Module]
else:
    Module = nn.Module
    NamedModules = OrderedDict
|
|
def recommend_auto_balance(message: str) -> str:
    """Expands a message with a recommendation to :mod:`fairscale.nn.pipe.balance`."""
    return f"""{message}

If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
naive automatic balancing:

  from fairscale.nn import Pipe
  from fairscale.nn.pipe.balance import balance_by_time

  partitions = torch.cuda.device_count()
  sample = torch.empty(...)
  balance = balance_by_time(partitions, model, sample)

  model = Pipe(model, balance, ...)
"""
|
|
def verify_module(module: nn.Sequential) -> None:
    if not isinstance(module, nn.Sequential):
        raise TypeError("module must be nn.Sequential to be partitioned")

    named_children = list(module.named_children())
    if len(named_children) != len(module):
        raise ValueError("module with duplicate children is not supported")
|
|
def verify_splitting(
    module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device]
) -> None:
    num_parameters = len(list(module.parameters()))
    num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
    if num_parameters == num_child_parameters:
        # No parameter is shared between children, so no parameter can end up
        # on two distinct devices.
        return

    # Look for a parameter shared by two partitions placed on different devices.
    for i in range(len(partitions)):
        for j in range(i + 1, len(partitions)):
            parti = partitions[i]
            partj = partitions[j]
            if devices[i] == devices[j]:
                continue
            for p in parti.parameters():
                for q in partj.parameters():
                    if p is q:
                        raise ValueError("module with duplicate parameters on distinct devices is not supported")
|
|
class BalanceError(ValueError):
    pass
|
|
def split_module(
    module: nn.Sequential,
    balance: Iterable[int],
    devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
    """Splits a module into multiple partitions.

    Returns:
        A tuple of (partitions, balance, devices).

        Partitions are represented as a :class:`~torch.nn.ModuleList` whose
        items are partitions. All layers in a partition are placed on the
        same device.
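
        For example (a sketch, with hypothetical layers ``a, b, c, d``),
        ``balance=[2, 2]`` over a four-layer module yields two partitions of
        two layers each::

            partitions, balance, devices = split_module(
                nn.Sequential(a, b, c, d),
                balance=[2, 2],
                devices=[torch.device("cuda:0"), torch.device("cuda:1")],
            )
            # partitions[0] holds (a, b) on cuda:0,
            # partitions[1] holds (c, d) on cuda:1.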
|
|
|
    Raises:
        BalanceError:
            wrong balance
        IndexError:
            the number of devices is fewer than the number of partitions.

    """
    balance = list(balance)
|
    if len(module) != sum(balance):
        raise BalanceError(
            "module and sum of balance have different length "
            f"(module: {len(module)}, sum of balance: {sum(balance)})"
        )

    if any(x <= 0 for x in balance):
        raise BalanceError(f"all balance numbers must be positive integers (balance: {balance})")

    if len(balance) > len(devices):
        raise IndexError(
            f"too few devices to hold given partitions (devices: {len(devices)}, partitions: {len(balance)})"
        )
|
    j = 0
    partitions = []
    layers: NamedModules = OrderedDict()
|
    for name, layer in module.named_children():
        layers[name] = layer

        if len(layers) == balance[j]:
            # Group buffered layers as a partition.
            partition = nn.Sequential(layers)

            device = devices[j]
            partition.to(device)

            partitions.append(partition)

            # Prepare for the next partition.
            layers.clear()
            j += 1

    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
    del devices[j:]

    return partitions, balance, devices
|
|
MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
|
|
class Pipe(Module):
    """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
    to train on Pipe_. If the module requires lots of memory, Pipe will be
    very efficient.
    ::

        model = nn.Sequential(a, b, c, d)
        model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
        output = model(input)

    .. _Pipe: https://arxiv.org/abs/1811.06965

    Pipe combines pipeline parallelism with checkpointing to reduce peak
    memory required to train while minimizing device under-utilization.

    You should determine the balance when defining a :class:`Pipe` module, as
    balancing will not be done automatically. The module will be partitioned
    into multiple devices according to the given balance. You may rely on
    heuristics to find your own optimal configuration.

    Args:
        module (torch.nn.Sequential):
            sequential module to be parallelized
        balance (ints):
            list of number of layers in each partition

    Keyword Args:
        devices (iterable of devices):
            devices to use (default: all CUDA devices)
        chunks (int):
            number of micro-batches (default: ``1``)
        checkpoint (str):
            when to enable checkpointing, one of ``'always'``,
            ``'except_last'``, or ``'never'`` (default: ``'except_last'``)
        deferred_batch_norm (bool):
            whether to use deferred BatchNorm moving statistics (default:
            :data:`False`, see :class:`Deferred Batch Normalization
            <DeferredBatchNorm>` for more details)
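
    Typically the output of :class:`Pipe` lives on the device of the last
    partition, so the target should be moved there before computing the loss.
    A minimal sketch (assuming a cross-entropy objective and ``F`` bound to
    :mod:`torch.nn.functional`)::

        out_device = model.devices[-1]

        for input, target in loader:
            target = target.to(out_device, non_blocking=True)
            output = model(input)
            loss = F.cross_entropy(output, target)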
|
|
|
    Raises:
        TypeError:
            the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
        ValueError:
            invalid arguments, or wrong balance
        IndexError:
            the number of devices is fewer than the number of partitions.

    """
|
    #: The number of layers in each partition.
    balance: List[int] = []

    #: The devices mapped to each partition.
    #:
    #: ``devices[-1]`` refers to the device of the last partition, which means
    #: it is the output device.
    devices: List[torch.device] = []

    #: The number of micro-batches.
    chunks: int = 1

    #: The checkpoint mode to determine when to enable checkpointing. One of
    #: ``'always'``, ``'except_last'``, or ``'never'``.
    checkpoint: str = "except_last"
|
    def __init__(
        self,
        module: nn.Sequential,
        balance: Optional[Iterable[int]] = None,
        *,
        devices: Optional[Devices] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
    ) -> None:
        super().__init__()
|
        if torch_version()[:2] >= (1, 8):
            warnings.warn(
                "fairscale.nn.Pipe has been upstreamed to PyTorch as torch.distributed.pipeline.sync.Pipe. "
                "It is now deprecated and will be removed in a future version of fairscale. "
                "The PyTorch API has minor changes. Please see https://pytorch.org/docs/stable/pipeline.html for details.",
                DeprecationWarning,
            )

        chunks = int(chunks)
        checkpoint = str(checkpoint)
|
        if balance is None:
            raise ValueError(recommend_auto_balance("balance is required"))
        if chunks <= 0:
            raise ValueError("number of chunks must be a positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
|
        verify_module(module)

        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
        verify_skippables(module)
|
        self.chunks = chunks
        self.checkpoint = checkpoint
|
        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)

        if devices is None:
            devices = range(torch.cuda.device_count())
        devices = [torch.device(d) for d in devices]
        devices = cast(List[torch.device], devices)
|
        try:
            self.partitions, self.balance, self.devices = split_module(module, balance, devices)
        except BalanceError as exc:
            raise ValueError(recommend_auto_balance(str(exc)))

        verify_splitting(module, self.partitions, self.balance, self.devices)

        self._copy_streams: List[List[AbstractStream]] = []
        self._skip_layout = inspect_skip_layout(self.partitions)

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams()

        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]

        self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
|
    def __len__(self) -> int:
        """Counts the length of the underlying sequential module."""
        return sum(len(p) for p in self.partitions)
|
    def __getitem__(self, index: int) -> nn.Module:
        """Gets a layer in the underlying sequential module."""
        partitions = self.partitions
        if index < 0:
            # Iterate partitions from the rear for a negative index.
            partitions = partitions[::-1]

        for partition in partitions:
            try:
                return partition[index]
            except IndexError:
                pass

            # Shift the index by the length of the partition just searched.
            shift = len(partition)
            if index < 0:
                index += shift
            else:
                index -= shift

        raise IndexError
|
    def __iter__(self) -> Iterable[nn.Module]:
        """Iterates over children of the underlying sequential module."""
        for partition in self.partitions:
            yield from partition
|
    # Pipe should manage the device of each partition.
    # Deny cuda(), cpu(), and to() with a device, by TypeError.
    def cuda(self, device: Optional[Device] = None) -> "Pipe":
        raise MOVING_DENIED

    def cpu(self) -> "Pipe":
        raise MOVING_DENIED
|
    def to(self, *args: Any, **kwargs: Any) -> "Pipe":
        """Deny these usages:

        - to(device[, dtype, non_blocking])
        - to(tensor[, non_blocking])

        But allow this:

        - to(dtype[, non_blocking])

        For example, ``pipe.to(torch.double)`` is allowed, while
        ``pipe.to("cuda:0")`` raises :exc:`TypeError`.
        """
        if "device" in kwargs or "tensor" in kwargs:
            raise MOVING_DENIED

        if args:
            if isinstance(args[0], (torch.device, int, str)):
                raise MOVING_DENIED
            if torch.is_tensor(args[0]):
                raise MOVING_DENIED

        return super().to(*args, **kwargs)
|
    def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
        """Ensures that :class:`Pipe` caches CUDA streams for copy.

        It's worth caching CUDA streams although PyTorch already manages a
        pool of pre-allocated CUDA streams, because it may reduce GPU memory
        fragmentation when the number of micro-batches is small.

        """
        if not self._copy_streams:
            for device in self.devices:
                self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])

        return self._copy_streams
|
    def forward(self, input: TensorOrTensors) -> TensorOrTensors:
        """:class:`Pipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's a type restriction. Input and output have to be a
        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
        applied at partition boundaries too.
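
        For example (a sketch; ``a`` through ``d`` are hypothetical layers on
        a two-device setup), a mini-batch flows through as a single tensor
        while being split into micro-batches internally::

            model = Pipe(nn.Sequential(a, b, c, d), balance=[2, 2], chunks=4)
            input = torch.randn(32, 128)  # one mini-batch ...
            output = model(input)         # ... processed as 4 micro-batches of 8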
|
|
|
        Args:
            input (torch.Tensor or tensors): input mini-batch

        Returns:
            tensor or tensors: output mini-batch

        Raises:
            TypeError: input is not a tensor or tensors.

        """
        microbatch.check(input)

        if not self.devices:
            # An empty sequential module is not illegal.
            return input

        # Divide a mini-batch into micro-batches.
        batches = microbatch.scatter(input, self.chunks)

        # Run pipeline parallelism.
        self.pipeline.run(batches)

        # Merge the micro-batches into one mini-batch.
        output = microbatch.gather(batches)
        return output
|
|