# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module
from einops import rearrange
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
from deepspeed.utils import groups
def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
"""
This function generates the parameters required for `permute` and `reshape` operations,
which are used to process data before and after `all2all` communication.
"""
if batch_dim_idx == 0:
if scatter_idx < 2:
bs, global_seq_len, num_local_head, head_dim = input.shape
pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
pre_all2all_permute_idx = (1, 0, 2, 3, 4)
post_all2all_permute_idx = (1, 2, 0, 3, 4)
post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
else:
bs, local_seq_len, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
post_all2all_permute_idx = (1, 0, 2, 3, 4)
post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
else:
if scatter_idx < 2:
global_seq_len, bs, num_local_head, head_dim = input.shape
pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
pre_all2all_permute_idx = None
post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [global_seq_len // seq_world_size, bs, seq_world_size * num_local_head, head_dim]
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
post_all2all_permute_idx = None
post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]
return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
def post_all2all(permute_idx, res_shape):
"""
Post-processing function for `all2all` communication.
"""
def post_func(input):
if permute_idx is not None:
input = input.permute(permute_idx).contiguous()
output = input.reshape(res_shape).contiguous()
return output
return post_func
def pre_all2all_fun(permute_idx, inp_shape, input):
"""
Pre-processing function for `all2all` communication.
"""
input_t = input.reshape(inp_shape).contiguous()
if permute_idx is not None:
input_t = input_t.permute(permute_idx).contiguous()
return input_t
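# These two helpers are composed around `dist.all_to_all_single` in the even-heads path
# of `single_all_to_all` below, roughly (a sketch; `output` is a pre-allocated buffer with
# the same shape and dtype as `input_t`):
#
#   input_t = pre_all2all_fun(pre_idx, pre_shape, input)    # reshape (+ optional permute)
#   dist.all_to_all_single(output, input_t, group=group)    # exchange chunks along dim 0
#   res = post_all2all(post_idx, res_shape)(output)         # (optional permute +) reshape back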
def _rotate_half(x):
"""
    Split the last dimension into two halves (x1, x2) and return the rotated pair [-x2, x1].
"""
x = rearrange(x, '... (j d) -> ... j d', j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
"""
input tensor t is of shape [seq_length, ..., dim]
    rotary positional embedding tensor freqs is of shape [seq_length, ..., dim]
check https://kexue.fm/archives/8265 for detailed formulas
"""
rot_dim = freqs_cos.shape[-1]
    # ideally t_pass is empty, so the rotary position embedding is applied to the entire tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin)
res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1)
return res
def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
seq_world_size = dist.get_world_size(group)
inp_shape = list(input.shape)
assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1"
    if scatter_idx >= 2:
        # scatter the head dimension across ranks and gather the sequence dimension
input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size)
input = input.transpose(0, scatter_idx).contiguous()
local_heads = input_splits[groups._get_sequence_parallel_rank()]
output_splits = [local_heads] * seq_world_size
output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:])
output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype)
        dist.all_to_all_single(output,
                               input,
                               output_split_sizes=output_splits,
                               input_split_sizes=input_splits,
                               group=group)
        # [seq_ws * local_heads, ...] -> [seq_ws, local_heads, ...]
        output = output.view(seq_world_size, local_heads, *output.shape[1:])
        # batch_dim_idx=0: [seq_ws, local_heads, seq_len, b, ...] -> [b, seq_ws, seq_len, local_heads, ...]
        # batch_dim_idx=1: [seq_ws, local_heads, b, seq_len, ...] -> [seq_ws, seq_len, b, local_heads, ...]
if batch_dim_idx == 0:
order = [3, 0, 2, 1] + list(range(4, len(output.shape)))
output = output.permute(order).contiguous()
            # [b, seq_ws * local_seq_len, local_heads, ...]
output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size,
*output.shape[3:]).contiguous()
elif batch_dim_idx == 1:
output = output.transpose(1, 3).contiguous()
            # [seq_ws * local_seq_len, b, local_heads, ...]
output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous()
else:
        # Handle both 4D and 3D inputs uniformly by flattening the trailing dims to 3D.
input = input.reshape(input.shape[0], input.shape[1], -1)
if batch_dim_idx == 0: #b,s,h
input = input.permute(1, 2, 0).contiguous() #s,h,b
elif batch_dim_idx == 1: #s,b,h
input = input.transpose(1, 2).contiguous() #s,h,b
seq_len, h, batch_size = input.shape
num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size)
local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()]
h_dim = h // local_heads
local_seq_len = seq_len // seq_world_size
input = input.view(seq_len * h, batch_size)
local_seq_len_with_heads = int(input.shape[0] / seq_world_size) # dim size of local_seq_len*local_heads*hdim
input_splits = [local_seq_len_with_heads] * seq_world_size
coeff = local_seq_len_with_heads // local_heads #per head: dim size of local_seq_len*hdim
        # scale factor for the uneven-heads case: total_heads / local_heads
heads_scale_coeff = get_num_kv_heads() / local_heads
output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list]
output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads)
total_h = int(inp_shape[gather_idx] * heads_scale_coeff)
output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype)
        dist.all_to_all_single(output,
                               input,
                               output_split_sizes=output_splits,
                               input_split_sizes=input_splits,
                               group=group)
        ##################
        # Example: suppose 7 heads are divided across 4 ranks as [2, 2, 2, 1]
        # chunk_num_heads_small = floor(7/4) = 1
        # chunk_num_heads_large = ceil(7/4) = 2
        # num_chunk_heads_large = len([2, 2, 2]) = 3, all2all_buffer_counts
        # num_chunk_heads_small = len([1]) = 1, all2all_buffer_counts
        # total_num_large_heads = sum([2, 2, 2]) = 6
        # total_num_small_heads = sum([1]) = 1
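        # Continuing the example: with coeff = local_seq_len * h_dim (the dim-0 extent
        # contributed by one head), output_splits above becomes [2 * coeff, 2 * coeff, 2 * coeff, 1 * coeff].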
        chunk_num_heads_small = get_num_kv_heads() // seq_world_size  # also correct when heads divide evenly
chunk_num_heads_large = chunk_num_heads_small + 1
num_chunk_heads_large = get_num_kv_heads() % seq_world_size
num_chunk_heads_small = seq_world_size - num_chunk_heads_large
total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large
total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small
heads_large_combine_size = coeff * total_num_large_heads
heads_small_combine_size = coeff * total_num_small_heads
heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size],
dim=0)
heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim,
batch_size)
heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim,
batch_size)
if batch_dim_idx == 0:
            # [all2all_buffer_counts, local_seq_len, n_heads, dim, batch] -> [batch, local_seq_len, all2all_buffer_counts * n_heads, dim]
order = [4, 1, 0, 2, 3]
heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
total_num_large_heads, h_dim)
heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
total_num_small_heads, h_dim)
elif batch_dim_idx == 1:
            # [all2all_buffer_counts, local_seq_len, n_heads, dim, batch] -> [local_seq_len, batch, all2all_buffer_counts * n_heads, dim]
order = [1, 4, 0, 2, 3]
heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
total_num_large_heads, h_dim)
heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
total_num_small_heads, h_dim)
output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous()
inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
        output_shape = inp_shape[:gather_idx] + [total_h] + inp_shape[gather_idx + 1:]
output = output.view(output_shape)
return output
def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
seq_world_size = dist.get_world_size(group)
    # num_heads only needs to be recorded once (see set_num_kv_heads below)
num_heads = input.shape[2]
if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and not scatter_idx < 2):
# Assuming here that the number of heads for q is consistent with kv
# If not, additional logic is required for cases like GQA
if get_num_kv_heads() is None:
assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
            # Record the number of heads on the first call (num_total_heads);
            # afterwards ``get_num_kv_heads() is not None`` re-enters the uneven path.
            set_num_kv_heads(num_heads)
        assert not async_op, "uneven head sp does not support async op"
return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
scatter_idx, batch_dim_idx, seq_world_size, input)
input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)
post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
output = torch.empty_like(input_t)
work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
if async_op:
if type in ('dq', 'dk'):
handle[type + '_work'] = work
handle[type + '_grad'] = output
handle[type + '_post_all2all_func'] = post_all2all_fun
return output.view(post_all2all_res_shape)
res = post_all2all_fun(output)
return res
class _DimZeroAllToAll(torch.autograd.Function):
"""Differentiable All2All across dimension 0."""
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:
world_size = dist.get_world_size(group)
        assert input.shape[0] == world_size, f"Dim 0 ({input.shape[0]}) must equal the world size ({world_size})"
ctx.group = group
output = torch.empty_like(input).contiguous()
# torch.distributed.nn.functional.all_to_all_single(output, input.contiguous(), group=group)
dist.all_to_all_single(output, input.contiguous(), group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _DimZeroAllToAll.apply(ctx.group, *grad_output))
class _SeqAllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx: Any,
group: dist.ProcessGroup,
input: Tensor,
scatter_idx: int,
gather_idx: int,
batch_dim_idx: int,
stream=None,
handle=None,
type=None,
is_fwd=True) -> Tensor:
ctx.group = group
ctx.scatter_idx = scatter_idx
ctx.gather_idx = gather_idx
ctx.stream = stream
ctx.handle = handle
ctx.type = type
ctx.batch_dim_idx = batch_dim_idx
if ctx.handle is None:
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
else:
# overlap communication path
if not is_fwd and type == 'o':
                assert ctx.stream is not None
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
get_accelerator().current_stream().wait_stream(ctx.stream)
# The computation of d o_weight can overlap with the communication of d o_input
elif not is_fwd and type in ('q', 'k'):
# Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv
type = 'd' + type
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type)
elif is_fwd and type in ('q', 'k'):
# Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v
type = 'fwd_' + type
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type)
else:
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
return res
@staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None, None, None, None, None]:
return (None,
_SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx,
ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None)
class DistributedAttention(torch.nn.Module):
"""Initialization.
Arguments:
local_attention (Module): local attention with q,k,v
sequence_process_group (ProcessGroup): sequence parallel process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
local_attention: Module,
sequence_process_group: dist.ProcessGroup,
scatter_idx: int = 2,
gather_idx: int = 0,
sp_stream=None,
) -> None:
super(DistributedAttention, self).__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
self.sp_overlap_comm = False
self.overlap_handles = None
self.sp_stream = sp_stream
if sp_stream is not None:
self.overlap_handles = {}
self.sp_overlap_comm = True
self.default_stream = get_accelerator().default_stream()
def layer_sync(self, layer):
if self.sp_overlap_comm and hasattr(layer, 'done_event'):
self.default_stream.wait_event(layer.done_event)
def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
batch_dim_idx: int,
rotary_pos_emb=None,
*args: Any,
**kwargs) -> Tensor:
""" forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
            batch_dim_idx (int): indicating which dim is batch
            rotary_pos_emb: optional pair of (cos, sin) rotary position embedding tensors
            args: other args
Returns:
* output (Tensor): context output
"""
# TODO Merge three alltoall calls into one
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
#in shape : e.g., [s/p:h:]
def bwd_hook(layer_type):
def pre_hook_fun(grad):
type = 'd' + layer_type
self.overlap_handles[type + '_work'].wait()
self.sp_stream.wait_stream(self.default_stream)
all2all_output = self.overlap_handles[type + '_grad']
                grad = list(grad)
                grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output)
                # Return the adjusted gradients so the prehook actually replaces grad[0];
                # without the return the post-all2all layout fix would be discarded.
                return tuple(grad)
            return pre_hook_fun
self.layer_sync(query)
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'q')
self.layer_sync(key)
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'k')
if self.sp_overlap_comm:
self.default_stream.wait_stream(self.sp_stream)
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'v')
if self.sp_overlap_comm:
# Register a hook to synchronize dq and dk after the all-to-all
# operation when the gradient data is used.
            # Registering the hooks after the q, k, v all-to-all calls keeps the interpreter
            # free to call and launch the forward all-to-all communication as early as possible.
grad_fn_q = query.grad_fn.next_functions[0][0]
grad_fn_q.register_prehook(bwd_hook(layer_type='q'))
grad_fn_k = key.grad_fn.next_functions[0][0]
grad_fn_k.register_prehook(bwd_hook(layer_type='k'))
#out shape : e.g., [s:h/p:]
if rotary_pos_emb is not None:
pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3)
query_layer = apply_rotary_pos_emb(query_layer, pos_emb_cos, pos_emb_sin)
key_layer = apply_rotary_pos_emb(key_layer, pos_emb_cos, pos_emb_sin)
context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx,
self.sp_stream, self.overlap_handles, 'o')
#out e.g., [s/p::h]
return output