import torch

import deepspeed
from deepspeed.runtime.utils import partition_uniform as partition


def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension. Adapted from Megatron-LM.

    Arguments:
        tensor: input tensor.
        partitions: list of partition sizes to supply to torch.split
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.
    """
    last_dim = tensor.dim() - 1
    tensor_list = torch.split(tensor, partitions, dim=last_dim)

    # torch.split returns views, which are not contiguous in general.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


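# Illustration only (not part of the original module): a minimal sketch of how the
# helper above is typically used. ``partition`` (partition_uniform) returns offsets
# with offsets[0] == 0 and offsets[-1] == num_items; the differences between
# consecutive offsets are the chunk sizes handed to ``torch.split``.
def _example_split_along_last_dim():
    x = torch.randn(2, 10)
    offsets = partition(num_items=10, num_parts=3)
    sizes = [offsets[i + 1] - offsets[i] for i in range(3)]
    chunks = split_tensor_along_last_dim(x, sizes, contiguous_split_chunks=True)
    # The chunk widths always sum back to the original last dimension.
    assert sum(chunk.shape[-1] for chunk in chunks) == x.shape[-1]
    return chunks

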
class TiledLinear(torch.nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 in_splits=1,
                 out_splits=1,
                 input_is_already_split=False,
                 combine_out_splits=True,
                 linear_cls=torch.nn.Linear,
                 init_linear=None,
                 **kwargs):
        """A replacement for ``torch.nn.Linear`` that works with ZeRO-3 to reduce
        memory requirements via tiling.

        TiledLinear breaks the input and output dimensions of a linear layer
        into tiles that are processed in sequence. This class enables huge
        linear layers when combined with ZeRO-3 because inactive tiles can be
        partitioned and offloaded.

        .. note::
            We recommend using as few tiles as necessary. Tiling
            significantly reduces memory usage, but can reduce throughput
            for inexpensive layers. This is because the smaller kernels have
            less parallelism and lower arithmetic intensity, while also
            introducing more frequent synchronization and communication.

        Args:
            in_features (int): See ``torch.nn.Linear``
            out_features (int): See ``torch.nn.Linear``
            bias (bool, optional): See ``torch.nn.Linear``
            in_splits (int, optional): The number of tiles along the input dimension. Defaults to 1.
            out_splits (int, optional): The number of tiles along the output dimension. Defaults to 1.
            input_is_already_split (bool, optional): If set to ``True``, assume that the ``input_``
                to ``forward()`` is already split into ``in_splits`` chunks. Defaults to ``False``.
            combine_out_splits (bool, optional): If set to ``False``, do not combine the ``out_splits`` outputs
                into a single tensor. Defaults to ``True``.
            linear_cls (class, optional): The underlying class used to build the individual tiles.
                Defaults to ``torch.nn.Linear``.
            init_linear (``torch.nn.Linear``, optional): If set, copy the parameters of
                ``init_linear``. Useful for debugging. Defaults to ``None``.
            kwargs (dict, optional): additional keyword arguments to provide to ``linear_cls()``.

        Raises:
            RuntimeError: ``in_splits`` must be within the range [1, in_features].
            RuntimeError: ``out_splits`` must be within the range [1, out_features].
        """
        super().__init__()

        if (in_splits < 1) or (in_splits > in_features):
            raise RuntimeError('in_splits must be in range [1, in_features].')
        if (out_splits < 1) or (out_splits > out_features):
            raise RuntimeError('out_splits must be in range [1, out_features].')

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias

        self.out_splits = out_splits
        self.in_splits = in_splits
        self.input_is_already_split = input_is_already_split
        self.combine_out_splits = combine_out_splits

        # Partition the input and output dimensions into contiguous ranges.
        # partition() returns num_parts + 1 offsets, beginning at 0 and ending
        # at the total number of items.
        self.in_parts = partition(num_items=in_features, num_parts=in_splits)
        self.out_parts = partition(num_items=out_features, num_parts=out_splits)

        assert len(self.out_parts) == out_splits + 1
        assert len(self.in_parts) == in_splits + 1
        assert self.out_parts[0] == 0
        assert self.out_parts[out_splits] == out_features
        assert self.in_parts[in_splits] == in_features

        self.linears = torch.nn.ModuleList()
        for out_id in range(out_splits):
            self.linears.append(torch.nn.ModuleList())

            local_out_dim = self.out_parts[out_id + 1] - self.out_parts[out_id]

            for in_id in range(in_splits):
                # Only the last input tile carries the bias, so the bias is
                # added exactly once per output tile.
                local_bias = bias if in_id == (in_splits - 1) else False

                local_in_dim = self.in_parts[in_id + 1] - self.in_parts[in_id]
                local = linear_cls(local_in_dim, local_out_dim, bias=local_bias, **kwargs)
                self.linears[out_id].append(local)

        if init_linear is not None:
            self.copy_params_from(init_linear)

    def forward(self, input_):
        if self.in_splits > 1 and not self.input_is_already_split:
            input_parts = partition(input_.shape[-1], self.in_splits)
            split_sizes = [input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)]
            inputs = self._split_global_input(input_, split_sizes)
        elif self.in_splits > 1:
            inputs = input_
            assert len(inputs) == self.in_splits, \
                f"in_splits {self.in_splits} does not match the number of input splits {len(inputs)}"
        else:
            # No input splitting is needed.
            inputs = [input_]

        outputs = [None] * self.out_splits
        for out_id in range(self.out_splits):
            for in_id in range(self.in_splits):
                local_output = self.linears[out_id][in_id](inputs[in_id])

                # Accumulate the partial result for this output tile.
                outputs[out_id] = self._reduce_local_output(in_id=in_id,
                                                            out_id=out_id,
                                                            current_out=outputs[out_id],
                                                            new_out=local_output)

        if self.combine_out_splits:
            return self._combine_output_splits(outputs)

        return outputs

    def _split_global_input(self, input, split_sizes):
        """Partition an input tensor along the last dimension, aligned with the given splits.

        Subclasses should override this method to account for new input types.

        Args:
            input (Tensor): The tensor to partition along the last dimension.
            split_sizes (List[int]): The size of each partition.

        Returns:
            List[Any]: A list of the chunks of ``input``.
        """
        return split_tensor_along_last_dim(input, split_sizes)

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduce (sum) a new local result into the existing local results.

        Subclasses should override this method.

        For a given ``out_id``, this method is called once per ``in_id``. The
        first input split is a simple assignment.

        Args:
            in_id (int): The input split that produced ``new_out``.
            out_id (int): The output split that produced ``new_out``.
            current_out (Any): The reduced form of all previous ``out_id`` results.
            new_out (Any): The local result from the forward pass of tile (``in_id``, ``out_id``).

        Returns:
            Any: The combined result of ``current_out`` and ``new_out``.
        """
        if current_out is None:
            # Clone the first partial result so the accumulated output does not
            # alias the tile's output tensor.
            return new_out.clone()
        else:
            return current_out + new_out

    def _combine_output_splits(self, outputs):
        """Join the splits of the output into a single result.

        Args:
            outputs (List[Any]): The reduced outputs for each output split.

        Returns:
            Any: The combined outputs.
        """
        assert len(outputs) == self.out_splits
        return torch.cat(outputs, dim=-1)

    @torch.no_grad()
    def copy_params_from(self, other):
        """Copy the weight and bias data from ``other``.

        This is especially useful for reproducible initialization and testing.

        Equivalent to:

        .. code-block:: python

            with torch.no_grad():
                self.weight.copy_(other.weight)
                if self.bias is not None:
                    self.bias.copy_(other.bias)

        .. note::
            If ZeRO-3 is enabled, this is a collective operation and the
            updated parameters of data-parallel rank 0 will be visible on all
            ranks. See :class:`deepspeed.zero.GatheredParameters` for more
            information.

        Args:
            other (``torch.nn.Linear``): the linear layer to copy from.
        """
        assert hasattr(other, 'weight')
        assert other.weight.size() == (self.out_features, self.in_features)
        if self.use_bias:
            assert hasattr(other, 'bias')
            assert other.bias is not None
            assert other.bias.size() == (self.out_features, )
        else:
            assert other.bias is None

        for row in range(self.out_splits):
            rstart = self.out_parts[row]
            rstop = self.out_parts[row + 1]

            for col in range(self.in_splits):
                cstart = self.in_parts[col]
                cstop = self.in_parts[col + 1]

                # Copy the corresponding block of the global weight matrix
                # into this tile's local parameters.
                local = self.linears[row][col]
                global_weight = other.weight[rstart:rstop, cstart:cstop]
                with deepspeed.zero.GatheredParameters(local.weight, modifier_rank=0):
                    local.weight.copy_(global_weight)

                if local.bias is not None:
                    with deepspeed.zero.GatheredParameters(local.bias, modifier_rank=0):
                        local.bias.copy_(other.bias[rstart:rstop])

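
# Illustration only (not part of the original module): a minimal usage sketch for
# ``TiledLinear``. The sizes are arbitrary. Outside of a ZeRO-3 context the tiles
# are ordinary ``torch.nn.Linear`` modules, so a layer initialized from an untiled
# linear via ``init_linear`` should reproduce its output up to floating-point
# error, since each output tile is the sum over input tiles of the corresponding
# weight blocks applied to the input chunks.
def _example_tiled_linear():
    torch.manual_seed(0)
    reference = torch.nn.Linear(64, 128)

    # 2 tiles along the input dimension, 4 tiles along the output dimension.
    tiled = TiledLinear(in_features=64,
                        out_features=128,
                        in_splits=2,
                        out_splits=4,
                        init_linear=reference)

    x = torch.randn(8, 64)
    y_tiled = tiled(x)  # combine_out_splits=True -> a single (8, 128) tensor
    y_ref = reference(x)

    assert y_tiled.shape == y_ref.shape
    assert torch.allclose(y_tiled, y_ref, atol=1e-5)
    return y_tiled

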
class TiledLinearReturnBias(TiledLinear):
    """Wrapper for a ``linear_cls`` that returns its bias parameter alongside its
    output rather than adding it, as done by Megatron-LM's parallel linear layers.
    """

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduces the output tensors, but not the returned bias."""
        if current_out is not None:
            old_tensor, old_bias = current_out
        else:
            old_tensor, old_bias = None, None

        assert isinstance(new_out, tuple)
        assert len(new_out) == 2

        tensor, bias = new_out
        assert tensor is not None

        tensor = super()._reduce_local_output(in_id=in_id, out_id=out_id, current_out=old_tensor, new_out=tensor)

        # Only the last input tile of each output tile carries a bias, so keep
        # whichever bias has been seen so far.
        if bias is None:
            bias = old_bias

        return tensor, bias

    def _combine_output_splits(self, outputs):
        # Concatenate the output tensors across the output tiles.
        tensors = [o[0] for o in outputs]
        tensor = super()._combine_output_splits(tensors)

        # Concatenate the per-tile biases, if any were returned.
        biases = [o[1] for o in outputs if o[1] is not None]
        if len(biases) > 0:
            bias = super()._combine_output_splits(biases)
        else:
            bias = None

        return tensor, bias

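
# Illustration only (not part of the original module): a hypothetical linear that
# returns its bias instead of adding it (the convention used by Megatron-LM-style
# layers), together with a sketch of how ``TiledLinearReturnBias`` combines the
# per-tile outputs and biases. The class and function names below are made up for
# this example.
class _ReturnBiasLinear(torch.nn.Linear):

    def forward(self, input_):
        # Defer the bias addition to the caller and return the bias alongside
        # the output; tiles built with bias=False simply return None here.
        return torch.nn.functional.linear(input_, self.weight), self.bias


def _example_tiled_linear_return_bias():
    tiled = TiledLinearReturnBias(in_features=64,
                                  out_features=128,
                                  in_splits=2,
                                  out_splits=4,
                                  linear_cls=_ReturnBiasLinear)

    x = torch.randn(8, 64)
    out, bias = tiled(x)  # bias is concatenated across the output tiles
    if bias is not None:
        out = out + bias
    assert out.shape == (8, 128)
    return out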