# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/23.08/megatron/core/tensor_parallel/layers.py

from typing import Callable

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
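
# Tensor-parallel process group shared by the Domino layers in this module;
# it is set once, by the first DominoAsyncColumnParallelLinear constructed.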
TP_group = None


class DominoAsyncColumnParallelLinearImpl(torch.autograd.Function):
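    """Autograd function for a column-parallel linear layer whose backward
    pass overlaps the tensor-parallel all-reduce of the input gradient with
    the weight-gradient GEMM. The async communication handle is stored in
    ``handle_dic`` under key ``h_id`` so the caller can wait on it later.
    """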

    @staticmethod
    def forward(ctx, inp, weight, bias, handle_dic, h_id):
        # inp: (b, s, k), weight: (m, k), bias: (m)
        ctx.save_for_backward(inp, weight, bias)
        ctx.handle_dic = handle_dic
        ctx.h_id = h_id
        output = torch.matmul(inp, weight.t())  # (b, s, k) @ (k, m) -> (b, s, m)
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
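        # Shapes: grad_output (b, s, m) -> grad_input (b, s, k),
        # grad_weight (m, k), grad_bias (m).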
        inp, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input = torch.matmul(grad_output, weight)  # (b, s, m) @ (m, k) -> (b, s, k)
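        # Launch the tensor-parallel all-reduce of grad_input asynchronously
        # and publish its handle: the weight/bias gradient math below runs
        # while the communication is in flight, and the caller is expected to
        # wait on handle_dic[h_id] before consuming the input gradient.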
        handle = dist.all_reduce(grad_input, group=TP_group, async_op=True)
        ctx.handle_dic[ctx.h_id] = handle
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])  # (b*s, m)
        inp = inp.view(inp.shape[0] * inp.shape[1], inp.shape[2])  # (b*s, k)
        grad_weight = torch.matmul(grad_output.t(), inp)  # (m, b*s) @ (b*s, k) -> (m, k)
        if bias is not None:
            grad_bias = grad_output.sum(dim=0)  # (b*s, m) -> (m)
        return grad_input, grad_weight, grad_bias, None, None


class DominoAsyncColumnParallelLinear(torch.nn.Module):
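    """Column-parallel linear layer that defers the backward input-gradient
    all-reduce to an async handle so Domino can overlap it with compute.

    A minimal usage sketch (``handle_dic`` and ``h_id`` are caller-managed;
    the caller waits on the stored handle before using the input gradient):

        handle_dic = {}
        output, bias = layer(hidden_states, handle_dic, h_id=0)
        ...
        # later, at the right point in the backward schedule:
        handle_dic[0].wait()
    """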

    def __init__(self,
                 input_size,
                 output_size,
                 _tp_group,
                 config,
                 init_method: Callable,
                 bias=True,
                 skip_bias_add=False):
        super().__init__()
        self.skip_bias_add = skip_bias_add

        global TP_group
        if TP_group is None:
            TP_group = _tp_group

        self.weight = Parameter(
            torch.empty(
                output_size,
                input_size,
                device=get_accelerator().current_device_name(),
                dtype=config.params_dtype,
            ))
        if config.perform_initialization:
            init_method(self.weight)

        if bias:
            self.bias = Parameter(
                torch.empty(output_size, device=get_accelerator().current_device_name(), dtype=config.params_dtype))
            if config.perform_initialization:
                with torch.no_grad():
                    self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_: torch.Tensor, handle_dic, h_id):
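        # When skip_bias_add is set, the bias is returned to the caller
        # (for a later fused add) instead of being applied here.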
        bias = self.bias if not self.skip_bias_add else None
        output = DominoAsyncColumnParallelLinearImpl.apply(input_, self.weight, bias, handle_dic, h_id)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class RowParallelLinearNoComm(torch.nn.Module):
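    """Row-parallel linear layer with the communication stripped out: it only
    computes the local GEMM, leaving the cross-rank reduction to the caller so
    Domino can schedule (and overlap) that all-reduce itself. ``stride`` is
    accepted for signature compatibility but is not used by this layer.
    """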

    def __init__(
        self,
        input_size: int,
        output_size: int,
        config,
        init_method: Callable,
        bias: bool = True,
        stride: int = 1,
        skip_bias_add: bool = False,
    ):
        super().__init__()
        self.skip_bias_add = skip_bias_add

        self.weight = Parameter(
            torch.empty(
                output_size,
                input_size,
                device=get_accelerator().current_device_name(),
                dtype=config.params_dtype,
            ))
        if config.perform_initialization:
            init_method(self.weight)

        if bias:
            self.bias = Parameter(
                torch.empty(
                    output_size,
                    device=get_accelerator().current_device_name(),
                    dtype=config.params_dtype,
                ))
            if config.perform_initialization:
                with torch.no_grad():
                    self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
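        # Returns the local partial result only; the caller must all-reduce
        # the output across the tensor-parallel group to get the full sum.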
        bias = self.bias if not self.skip_bias_add else None
        output = F.linear(input_, self.weight, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias