# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import enum

import torch
import torch.nn.functional as F

import deepspeed.comm as dist

from .async_linear import DominoAsyncColumnParallelLinear, RowParallelLinearNoComm
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class DominoUtil:
BATCH_0 = "BATCH0"
BATCH_1 = "BATCH1"
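    # Handles for asynchronous tensor-parallel communication, keyed by micro-batch;
    # NoOper.backward waits on the matching handle so communication overlaps with compute.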
HANDLE_DIC = {"BATCH0": None, "BATCH1": None}
class DominoModule(torch.nn.Module):
    """Extension of torch.nn.Module used as the base class for Domino layers."""

    def __init__(self):
        super(DominoModule, self).__init__()
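# Identity in the forward pass; in the backward pass it waits on the asynchronous
# communication handle stored in dic_ under h_id (see NoOper below).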
def _Wait_bwd_comm(input_, dic_, h_id):
return NoOper.apply(input_, dic_, h_id)
class NoOper(torch.autograd.Function):
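    """Autograd identity: forward returns the input unchanged; backward waits on the
    async communication handle registered under h_id before passing the gradient on."""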
@staticmethod
def symbolic(graph, input_, handle_dic, h_id):
return input_
@staticmethod
def forward(ctx, input_, handle_dic, h_id):
ctx.handle_dic = handle_dic
ctx.h_id = h_id
return input_
@staticmethod
def backward(ctx, grad_output):
handle = ctx.handle_dic[ctx.h_id]
handle.wait()
return grad_output, None, None
class CoreAttention(DominoModule):
def __init__(self, config, tp_world_size, attn_mask_type=AttnMaskType.causal):
super(CoreAttention, self).__init__()
self.attn_mask_type = attn_mask_type
projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values.
        assert projection_size % tp_world_size == 0, f"projection size {projection_size} should be a multiple of TP world size {tp_world_size}"
self.hidden_size_per_partition = projection_size // tp_world_size
self.attention_dropout_rate = config.attention_dropout
def forward(self, query_layer, key_layer, value_layer, attention_mask):
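        # attention_mask is accepted for interface compatibility but not used here:
        # scaled_dot_product_attention builds the causal mask via is_causal=True.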
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,
attn_mask=None,
dropout_p=self.attention_dropout_rate,
is_causal=True,
scale=None)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class ShardedAttention(DominoModule):
    """Sharded self-attention layer class.

    Only self-attention with a causal attention mask is supported for now.
    """
def __init__(self,
config,
mpu,
apply_rotary_pos_emb,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.causal):
super(ShardedAttention, self).__init__()
assert attention_type == AttnType.self_attn, "Only support self_attn for now!"
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = config.params_dtype
self.apply_rotary_pos_emb = apply_rotary_pos_emb
query_projection_size = config.kv_channels * config.num_attention_heads
kv_projection_size = config.kv_channels * config.num_attention_heads
tp_world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = query_projection_size // config.num_attention_heads
self.num_attention_heads_per_partition = config.num_attention_heads // tp_world_size
qkv_projection_per_partition = (query_projection_size + 2 * kv_projection_size) // tp_world_size
self.query_key_value = DominoAsyncColumnParallelLinear(config.hidden_size,
qkv_projection_per_partition,
mpu.get_tensor_model_parallel_group(),
config=config,
init_method=config.init_method,
bias=config.add_bias_linear)
self.core_attention = CoreAttention(config, tp_world_size, self.attn_mask_type)
query_projection_size_per_partition = query_projection_size // tp_world_size
# Output.
self.dense = RowParallelLinearNoComm(query_projection_size_per_partition,
config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=config.add_bias_linear,
skip_bias_add=True)
def forward(self, hidden_states, attention_mask, micro_batch_num, rotary_pos_emb=None):
# hidden_states: [sq, b, h]
mixed_x_layer, _ = self.query_key_value(hidden_states, DominoUtil.HANDLE_DIC, micro_batch_num)
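        # [sq, b, 3*h/tp] --> [sq, b, np, 3*hn]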
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
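        # [sq, b, np, 3*hn] --> [b, np, sq, 3*hn]; the split below yields
        # query, key, and value of shape [b, np, sq, hn] each.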
mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()
(query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head
],
dim=3)
query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1,
self.hidden_size_per_attention_head)
        if rotary_pos_emb is not None:
            # A single rotary embedding is broadcast to (query, key) if needed.
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = (rotary_pos_emb, ) * 2
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb)
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
output, bias = self.dense(context_layer)
return output, bias
def domino_core_attention_forward(self, mixed_x_layer, attention_mask, rotary_pos_emb=None):
        # mixed_x_layer: [sq, b, 3*h/tp], already produced by query_key_value.
        # In the inter-layer overlap path, the QKV projection is launched by the caller
        # (see DominoTransformer.inter_layer_overlap_forward) instead of here:
        # mixed_x_layer, _ = self.query_key_value(hidden_states, handle_dic, micro_batch_num)
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()
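        # mixed_x_layer is now [b, np, sq, 3*hn]; the split below yields q, k, v of [b, np, sq, hn] each.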
(query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head
],
dim=3)
query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1,
self.hidden_size_per_attention_head)
        if rotary_pos_emb is not None:
            # A single rotary embedding is broadcast to (query, key) if needed.
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = (rotary_pos_emb, ) * 2
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb)
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
        # The output projection (self.dense) is likewise applied by the caller:
        # output, bias = self.dense(context_layer)
        # return output, bias
        return context_layer
class bias_dropout_add(torch.nn.Module):
def __init__(self, prob: float):
super(bias_dropout_add, self).__init__()
self.dropout = torch.nn.Dropout(prob)
def forward(self, x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
if bias is not None:
x = x + bias
out = self.dropout(x)
out = out + residual
return out
class DominoTransformerLayer(DominoModule):
    """A single Domino transformer layer.

    The forward pass takes a pair of micro-batch activations, each of shape
    [s, b, h], and returns a pair of the same shape, overlapping the
    tensor-parallel communication of one micro-batch with the computation of
    the other.
    """
def __init__(self,
config,
mpu,
apply_rotary_pos_emb,
layer_number,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.causal,
drop_path_rate=0.):
super(DominoTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= config.apply_residual_connection_post_layernorm
self.llama_model = False
self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = ShardedAttention(config,
mpu,
apply_rotary_pos_emb,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = config.hidden_dropout
self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
# MLP
ffn_hidden_size = config.ffn_hidden_size
if config.gated_linear_unit:
ffn_hidden_size *= 2
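        # NOTE: the doubled ffn_hidden_size is not used below; the MLP sizes are taken
        # directly from config.ffn_hidden_size (the activation is plain GeLU, no gating).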
self.output_size_c = config.ffn_hidden_size
self.input_size_c = config.hidden_size
self.input_size_r = config.ffn_hidden_size
self.output_size_r = self.input_size_c
tp_world_size = mpu.get_tensor_model_parallel_world_size()
self.TP_group = mpu.get_tensor_model_parallel_group()
self.output_size_per_partition = self.output_size_c // tp_world_size
self.input_size_per_partition = self.input_size_r // tp_world_size
self.linear_fc1 = DominoAsyncColumnParallelLinear(self.input_size_c,
self.output_size_per_partition,
mpu.get_tensor_model_parallel_group(),
config=config,
init_method=config.init_method,
bias=config.add_bias_linear)
self.mlp_activation_func = F.gelu
self.linear_fc2 = RowParallelLinearNoComm(self.input_size_per_partition,
self.output_size_r,
config=config,
init_method=config.output_layer_init_method,
bias=config.add_bias_linear,
skip_bias_add=True)
self.bias_dropout_add_func = bias_dropout_add(self.hidden_dropout)
def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
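        # hidden_states is a pair of micro-batch activations, each of shape [s, b, h].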
hidden_states0, hidden_states1 = hidden_states
layernorm_output0 = self.input_layernorm(hidden_states0)
layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
# Micro batch 0: attention
attention_output0, attention_bias0 = self.self_attention(layernorm_output0,
attention_mask,
DominoUtil.BATCH_0,
rotary_pos_emb=rotary_pos_emb)
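        # Launch the TP all-reduce for micro-batch 0 asynchronously; it overlaps
        # with micro-batch 1's attention below.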
fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True)
# End of Micro batch 0: attention
# Micro batch 1: attention
layernorm_output1 = self.input_layernorm(hidden_states1)
layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
attention_output1, attention_bias1 = self.self_attention(layernorm_output1,
attention_mask,
DominoUtil.BATCH_1,
rotary_pos_emb=rotary_pos_emb)
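        # This all-reduce overlaps with micro-batch 0's residual add and first MLP GEMM below.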
fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True)
# Micro batch 0: Residual connection.
fwd_handle0.wait()
if self.apply_residual_connection_post_layernorm:
residual0 = layernorm_output0
else:
residual0 = hidden_states0
layernorm_input0 = self.bias_dropout_add_func(attention_output0, attention_bias0, residual0)
layernorm_output0 = self.post_attention_layernorm(layernorm_input0)
layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
if self.apply_residual_connection_post_layernorm:
residual0 = layernorm_output0
else:
residual0 = layernorm_input0
# End of Micro batch 0: Residual connection.
# ------------ MLP ------------
# Micro batch 0: MLP
output0, _ = self.linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
output0 = self.mlp_activation_func(output0)
# Micro batch 1: Residual connection.
fwd_handle1.wait()
if self.apply_residual_connection_post_layernorm:
residual1 = layernorm_output1
else:
residual1 = hidden_states1
layernorm_input1 = self.bias_dropout_add_func(attention_output1, attention_bias1, residual1)
layernorm_output1 = self.post_attention_layernorm(layernorm_input1)
layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
if self.apply_residual_connection_post_layernorm:
residual1 = layernorm_output1
else:
residual1 = layernorm_input1
# End of Micro batch 1: Residual connection.
hidden_states0, last_mlp_bias = self.linear_fc2(output0)
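        # The MLP all-reduce for micro-batch 0 overlaps with micro-batch 1's MLP below.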
fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True)
# End of Micro batch 0: MLP
# Micro batch 1: MLP
output1, _ = self.linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
output1 = self.mlp_activation_func(output1)
hidden_states1, last_mlp_bias = self.linear_fc2(output1)
fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True)
# End of Micro batch 1: MLP
# ------------ End of MLP ------------
fwd_handle0.wait()
hidden_states0 = self.bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0)
fwd_handle1.wait()
hidden_states1 = self.bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1)
return hidden_states0, hidden_states1
class DominoTransformer(DominoModule):
    """Domino transformer: a stack of DominoTransformerLayer blocks with an
    optional final layernorm."""
def __init__(self,
config,
mpu,
apply_rotary_pos_emb,
model_type,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.causal,
post_layer_norm=True,
pre_process=True,
post_process=True,
drop_path_rate=0.0):
super(DominoTransformer, self).__init__()
self.layer_type = layer_type
self.model_type = model_type
self.post_layer_norm = post_layer_norm
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
self.TP_group = mpu.get_tensor_model_parallel_group()
if not dist.is_initialized():
dist.init_distributed()
assert dist.is_initialized(), "deepspeed.comm failed to initialize!"
self.num_layers = config.num_layers
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, config.num_layers)]
def build_layer(layer_number):
current_layer_type = layer_type
return DominoTransformerLayer(config,
mpu,
apply_rotary_pos_emb,
layer_number,
layer_type=current_layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
if self.post_process and self.post_layer_norm:
self.final_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
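        # Select the overlap strategy: intra-layer overlap is handled inside each
        # DominoTransformerLayer; inter-layer overlap pipelines micro-batches across layers.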
self._forward_impl = self.inter_layer_overlap_forward
if config.domino_intra_layer_overlap:
self._forward_impl = self.intra_layer_overlap_forward
def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
return self._forward_impl(hidden_states, attention_mask, rotary_pos_emb)
def inter_layer_overlap_forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
# hidden_states: [s, b, h]
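        # Split the batch dimension into two micro-batches that are pipelined across layers.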
hidden_states0, hidden_states1 = torch.chunk(hidden_states, chunks=2, dim=1)
last_mlp_bias = None
fwd_handle0, fwd_handle1 = None, None
residual0, residual1 = None, None
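        # Prologue for the pipelined loop: pre-compute the first layer's input layernorm
        # for micro-batch 0; the remaining layernorms are issued inside the loop.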
layernorm_output0 = self.layers[0].input_layernorm(hidden_states0)
layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
for index in range(self.num_layers):
# Micro batch 0: attention
attention_output0, _ = self.layers[index].self_attention.query_key_value(
layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
attention_output0 = self.layers[index].self_attention.domino_core_attention_forward(
attention_output0, attention_mask, rotary_pos_emb=rotary_pos_emb)
# Micro batch 1: Residual connection
if index > 0:
fwd_handle1.wait()
hidden_states1 = self.layers[index - 1].bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1)
layernorm_output1 = self.layers[index].input_layernorm(hidden_states1)
layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
# End of Micro batch 1: Residual connection
attention_output0, attention_bias0 = self.layers[index].self_attention.dense(attention_output0)
fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True)
# End of Micro batch 0: attention
# Micro batch 1: attention
attention_output1, _ = self.layers[index].self_attention.query_key_value(
layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
attention_output1 = self.layers[index].self_attention.domino_core_attention_forward(
attention_output1, attention_mask, rotary_pos_emb=rotary_pos_emb)
# Micro batch 0: Residual connection.
fwd_handle0.wait()
if self.layers[index].apply_residual_connection_post_layernorm:
residual0 = layernorm_output0
else:
residual0 = hidden_states0
layernorm_input0 = self.layers[index].bias_dropout_add_func(attention_output0, attention_bias0, residual0)
layernorm_output0 = self.layers[index].post_attention_layernorm(layernorm_input0)
layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
if self.layers[index].apply_residual_connection_post_layernorm:
residual0 = layernorm_output0
else:
residual0 = layernorm_input0
# End of Micro batch 0: Residual connection.
attention_output1, attention_bias1 = self.layers[index].self_attention.dense(attention_output1)
fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True)
# End of Micro batch 1: attention
# ------------ MLP ------------
# Micro batch 0: MLP
output0, _ = self.layers[index].linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
output0 = self.layers[index].mlp_activation_func(output0)
# Micro batch 1: Residual connection.
fwd_handle1.wait()
if self.layers[index].apply_residual_connection_post_layernorm:
residual1 = layernorm_output1
else:
residual1 = hidden_states1
layernorm_input1 = self.layers[index].bias_dropout_add_func(attention_output1, attention_bias1, residual1)
layernorm_output1 = self.layers[index].post_attention_layernorm(layernorm_input1)
layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
if self.layers[index].apply_residual_connection_post_layernorm:
residual1 = layernorm_output1
else:
residual1 = layernorm_input1
# End of Micro batch 1: Residual connection.
hidden_states0, last_mlp_bias = self.layers[index].linear_fc2(output0)
fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True)
# End of Micro batch 0: MLP
# Micro batch 1: MLP
output1, _ = self.layers[index].linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1)
output1 = self.layers[index].mlp_activation_func(output1)
# Micro batch 0: Residual connection.
fwd_handle0.wait()
hidden_states0 = self.layers[index].bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0)
if index < self.num_layers - 1:
layernorm_output0 = self.layers[index + 1].input_layernorm(hidden_states0)
layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0)
# End of Micro batch 0: Residual connection.
hidden_states1, last_mlp_bias = self.layers[index].linear_fc2(output1)
fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True)
# End of Micro batch 1: MLP
# ------------ End of MLP ------------
if self.post_process and self.post_layer_norm:
hidden_states0 = self.final_layernorm(hidden_states0)
index = self.num_layers - 1
fwd_handle1.wait()
hidden_states1 = self.layers[index].bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1)
if self.post_process and self.post_layer_norm:
hidden_states1 = self.final_layernorm(hidden_states1)
hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
return hidden_states
def intra_layer_overlap_forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
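        # Each DominoTransformerLayer overlaps communication and computation internally
        # between its two micro-batches.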
hidden_states = torch.chunk(hidden_states, chunks=2, dim=1)
for index in range(self.num_layers):
layer = self.layers[index]
hidden_states = layer(hidden_states, attention_mask, rotary_pos_emb)
hidden_states0, hidden_states1 = hidden_states
if self.post_process and self.post_layer_norm:
hidden_states0 = self.final_layernorm(hidden_states0)
hidden_states1 = self.final_layernorm(hidden_states1)
hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
return hidden_states