from typing import Callable

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

# Tensor-parallel process group shared by the layers in this module.
# It is set once, by the first DominoAsyncColumnParallelLinear that is constructed.
TP_group = None


class DominoAsyncColumnParallelLinearImpl(torch.autograd.Function):
    """Column-parallel linear autograd function with asynchronous gradient all-reduce.

    The backward pass launches the tensor-parallel all-reduce of the input
    gradient as an async op and records its handle in ``handle_dic[h_id]``,
    so the caller can overlap that communication with the weight-gradient GEMM
    and wait on the handle only when the input gradient is actually needed.
    """

    @staticmethod
    def forward(ctx, inp, weight, bias, handle_dic, h_id):
        ctx.save_for_backward(inp, weight, bias)
        ctx.handle_dic = handle_dic
        ctx.h_id = h_id

        # Local matmul only; the forward pass of a column-parallel linear layer
        # needs no communication.
        output = torch.matmul(inp, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight, bias = ctx.saved_tensors
        grad_bias = None

        # Compute the input gradient, then start its tensor-parallel all-reduce
        # asynchronously. The handle is stored so the caller can wait on it later,
        # overlapping the communication with the weight-gradient GEMM below.
        grad_input = torch.matmul(grad_output, weight)
        handle = dist.all_reduce(grad_input, group=TP_group, async_op=True)
        ctx.handle_dic[ctx.h_id] = handle

        # Collapse [sequence, batch, hidden] to 2D for the weight-gradient GEMM.
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
        inp = inp.view(inp.shape[0] * inp.shape[1], inp.shape[2])
        grad_weight = torch.matmul(grad_output.t(), inp)

        if bias is not None:
            grad_bias = grad_output.sum(dim=0)
        return grad_input, grad_weight, grad_bias, None, None
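
# The helper below is only an illustrative sketch and is not part of the Domino
# API: it shows how code that shares ``handle_dic`` with the function above is
# expected to consume the stored handle, i.e. wait on the async all-reduce before
# the fully reduced input gradient is used by the preceding layer.
def _wait_for_input_grad(handle_dic, h_id):
    handle = handle_dic.get(h_id)
    if handle is not None:
        # Blocks until the all-reduce launched in backward() has completed.
        handle.wait()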


class DominoAsyncColumnParallelLinear(torch.nn.Module):
    """Column-parallel linear layer whose backward gradient all-reduce is asynchronous.

    The forward pass is a purely local matmul; the backward input-gradient
    all-reduce over the tensor-parallel group is launched asynchronously and
    exposed to the caller through ``handle_dic`` / ``h_id``
    (see DominoAsyncColumnParallelLinearImpl).
    """

    def __init__(self,
                 input_size,
                 output_size,
                 _tp_group,
                 config,
                 init_method: Callable,
                 bias=True,
                 skip_bias_add=False):
        super(DominoAsyncColumnParallelLinear, self).__init__()

        self.skip_bias_add = skip_bias_add

        # Record the tensor-parallel group once, for use in the backward all-reduce.
        global TP_group
        if TP_group is None:
            TP_group = _tp_group

        self.weight = Parameter(
            torch.empty(
                output_size,
                input_size,
                device=get_accelerator().current_device_name(),
                dtype=config.params_dtype,
            ))
        if config.perform_initialization:
            init_method(self.weight)

        if bias:
            self.bias = Parameter(
                torch.empty(output_size, device=get_accelerator().current_device_name(), dtype=config.params_dtype))
            if config.perform_initialization:
                # The bias always starts at zero.
                with torch.no_grad():
                    self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_: torch.Tensor, handle_dic, h_id):
        # When the bias is to be returned separately, it is not fused into the matmul.
        bias = self.bias if not self.skip_bias_add else None

        output = DominoAsyncColumnParallelLinearImpl.apply(input_, self.weight, bias, handle_dic, h_id)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
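
# The function below is only a usage sketch, not part of the Domino API; it
# assumes the caller already has a tensor-parallel group, a config object with
# ``params_dtype`` / ``perform_initialization``, and an input ``x`` shaped
# [sequence, batch, hidden].
def _example_column_parallel_usage(x, tp_group, config, input_size, output_size_per_partition):
    handles = {}
    layer = DominoAsyncColumnParallelLinear(input_size,
                                            output_size_per_partition,
                                            tp_group,
                                            config,
                                            init_method=torch.nn.init.xavier_normal_,
                                            bias=True)
    # Forward is a purely local matmul; the async all-reduce only appears in backward.
    out, _ = layer(x, handles, h_id=0)
    # After the backward pass has run, handles[0] holds the async all-reduce handle
    # that must be waited on before the gradient w.r.t. x is read.
    return out, handles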


class RowParallelLinearNoComm(torch.nn.Module):
    """Row-parallel linear layer that performs no communication itself.

    It computes only the local product; reducing the partial outputs across the
    tensor-parallel group is left to the caller, which is free to schedule (and
    overlap) that all-reduce however it likes.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        config,
        init_method: Callable,
        bias: bool = True,
        stride: int = 1,
        skip_bias_add: bool = False,
    ):
        super(RowParallelLinearNoComm, self).__init__()

        self.skip_bias_add = skip_bias_add

        self.weight = Parameter(
            torch.empty(
                output_size,
                input_size,
                device=get_accelerator().current_device_name(),
                dtype=config.params_dtype,
            ))
        if config.perform_initialization:
            init_method(self.weight)

        if bias:
            self.bias = Parameter(
                torch.empty(
                    output_size,
                    device=get_accelerator().current_device_name(),
                    dtype=config.params_dtype,
                ))
            if config.perform_initialization:
                # The bias always starts at zero.
                with torch.no_grad():
                    self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Local matmul only; the caller is responsible for reducing the partial
        # outputs across the tensor-parallel group.
        bias = self.bias if not self.skip_bias_add else None

        output = F.linear(input_, self.weight, bias)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
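
# Sketch of how the two layers might be paired in a tensor-parallel MLP block;
# illustrative only. ``tp_group`` is the tensor-parallel group, ``col_layer`` /
# ``row_layer`` are instances of the classes above, and ``row_layer`` is assumed
# to have been built with skip_bias_add=True so its bias is added exactly once,
# after the reduction that this caller performs (RowParallelLinearNoComm itself
# deliberately does no communication).
def _example_mlp_forward(x, col_layer, row_layer, tp_group, handles, h_id):
    intermediate, _ = col_layer(x, handles, h_id)  # column-parallel: local matmul only
    intermediate = F.gelu(intermediate)
    partial_out, bias = row_layer(intermediate)    # row-parallel: still no communication
    # The caller reduces the partial outputs; scheduling this all-reduce itself
    # (possibly asynchronously) is exactly the freedom the "NoComm" design leaves open.
    dist.all_reduce(partial_out, group=tp_group)
    return partial_out if bias is None else partial_out + bias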