import math
import functools

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn.modules.module import Module

from deepspeed.runtime.utils import noop_decorator
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator


def print_rank_0(message, debug=False, force=False):
    """Print a message on global rank 0 only, and only when debug or force is set."""
    if dist.get_rank() == 0 and (debug or force):
        print(message)


# Prefer the unified torch.amp custom_fwd/custom_bwd decorators (bound to this
# accelerator's device type) when available, fall back to the accelerator's amp
# module, and finally to a no-op decorator if neither is usable.
try:
    if hasattr(torch, 'amp') and hasattr(torch.amp, 'custom_fwd') and hasattr(torch.amp, 'custom_bwd'):
        autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
        autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())
    else:
        autocast_custom_fwd = get_accelerator().amp().custom_fwd
        autocast_custom_bwd = get_accelerator().amp().custom_bwd
except (ImportError, AttributeError):
    autocast_custom_fwd = noop_decorator
    autocast_custom_bwd = noop_decorator


class LinearFunctionForZeroStage3(torch.autograd.Function):
    """Custom autograd implementation of the linear transform :math:`y = xA^T + b`
    used with ZeRO Stage 3."""

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, input, weight, bias=None):
        # Save the tensors needed by backward.
        ctx.save_for_backward(input, weight, bias)

        if input.dim() == 2 and bias is not None:
            # Fused bias add + matmul for the common 2D case.
            ret = torch.addmm(bias, input, weight.t())
        else:
            output = input.matmul(weight.t())
            if bias is not None:
                output += bias
            ret = output

        return ret

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors

        # Initialize all gradients to None; autograd ignores trailing Nones, so the
        # return statement stays simple even though bias is optional.
        grad_input = grad_weight = grad_bias = None
        dim = grad_output.dim()

        # The needs_input_grad checks are purely an efficiency optimization;
        # returning gradients for inputs that do not require them is not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)

        if ctx.needs_input_grad[1]:
            if dim > 2:
                # Flatten all leading dimensions before reducing over them.
                grad_weight = grad_output.reshape(-1,
                                                  grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1]))
            else:
                grad_weight = grad_output.t().matmul(input)

        if bias is not None and ctx.needs_input_grad[2]:
            if dim > 2:
                grad_bias = grad_output.sum([i for i in range(dim - 1)])
            else:
                grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


def zero3_linear_wrap(input, weight, bias=None):
    # Functional wrapper with the same signature as torch.nn.functional.linear.
    if bias is None:
        return LinearFunctionForZeroStage3.apply(input, weight)
    else:
        return LinearFunctionForZeroStage3.apply(input, weight, bias)

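
# zero3_linear_wrap mirrors the torch.nn.functional.linear signature, so it can be
# substituted for F.linear (e.g. by the ZeRO Stage 3 runtime) to route linear layers
# through LinearFunctionForZeroStage3. A minimal, illustrative call pattern with
# hypothetical tensors:
#
#     x = torch.randn(8, 16)
#     w = torch.randn(32, 16)
#     b = torch.randn(32)
#     y = zero3_linear_wrap(x, w, b)  # equivalent to torch.nn.functional.linear(x, w, b)

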
class LinearModuleForZeroStage3(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    The weights are pre-transposed and stored as :math:`A^T` instead of being
    transposed during each forward pass. Memory savings are proportional to the
    parameter size.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = LinearModuleForZeroStage3(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(LinearModuleForZeroStage3, self).__init__()
        print("Building ZeRO module")
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Match torch.nn.Linear's default initialization.
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features,
                                                                  self.bias is not None)
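

if __name__ == "__main__":
    # Minimal, optional sanity check (a sketch for local debugging, not used by the
    # runtime): verify that the custom autograd path matches torch.nn.functional.linear
    # for a small input and that gradients flow. Assumes a single process on which the
    # DeepSpeed imports above succeed.
    torch.manual_seed(0)
    layer = LinearModuleForZeroStage3(20, 30)
    x = torch.randn(128, 20, requires_grad=True)
    out = layer(x)
    ref = torch.nn.functional.linear(x, layer.weight, layer.bias)
    assert torch.allclose(out, ref, atol=1e-6), "forward mismatch with F.linear"
    out.sum().backward()
    assert x.grad is not None and layer.weight.grad is not None
    print("LinearModuleForZeroStage3 sanity check passed:", tuple(out.shape))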