import torch

from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module

from einops import rearrange

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
from deepspeed.utils import groups
|
|
def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
    """
    Generate the `permute` and `reshape` parameters used to lay out data
    before and after the `all2all` communication.
    """
    if batch_dim_idx == 0:
        if scatter_idx < 2:
            bs, global_seq_len, num_local_head, head_dim = input.shape
            pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
            pre_all2all_permute_idx = (1, 0, 2, 3, 4)

            post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
        else:
            bs, local_seq_len, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
            pre_all2all_permute_idx = (2, 0, 1, 3, 4)

            post_all2all_permute_idx = (1, 0, 2, 3, 4)
            post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
    else:
        if scatter_idx < 2:
            global_seq_len, bs, num_local_head, head_dim = input.shape
            pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
            pre_all2all_permute_idx = None

            post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [global_seq_len // seq_world_size, bs, seq_world_size * num_local_head, head_dim]
        else:
            local_seq_len, bs, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
            post_all2all_permute_idx = None
            post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]

    return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
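
# Worked example (illustrative values): for batch_dim_idx=0, scatter_idx=0 and
# a [bs=2, seq=8, heads=4, dim=16] input on a 2-rank group:
#
#   pre_permute, pre_shape, post_permute, post_shape = _generate_layout_params(
#       scatter_idx=0, batch_dim_idx=0, seq_world_size=2,
#       input=torch.empty(2, 8, 4, 16))
#   # pre_shape  == [2, 2, 4, 4, 16]  -> all2all exchanges sequence chunks on dim 0
#   # post_shape == [2, 4, 8, 16]     -> local sequence, gathered heads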
|
|
def post_all2all(permute_idx, res_shape):
    """
    Post-processing function for `all2all` communication.
    """

    def post_func(input):
        if permute_idx is not None:
            input = input.permute(permute_idx).contiguous()
        output = input.reshape(res_shape).contiguous()
        return output

    return post_func
|
|
def pre_all2all_fun(permute_idx, inp_shape, input):
    """
    Pre-processing function for `all2all` communication.
    """
    input_t = input.reshape(inp_shape).contiguous()
    if permute_idx is not None:
        input_t = input_t.permute(permute_idx).contiguous()
    return input_t
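
# Round-trip sketch (illustrative, continuing the example above): the pre/post
# helpers bracket the collective so the scattered dimension leads:
#
#   input_t = pre_all2all_fun(pre_permute, pre_shape, q)  # [2, 2, 4, 4, 16]
#   # dist.all_to_all_single(output, input_t, group=group) exchanges dim 0
#   restore = post_all2all(post_permute, post_shape)
#   res = restore(output)                                 # [2, 4, 8, 16]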
|
|
def _rotate_half(x):
    """
    Swap the two halves of the last dimension and flip the sign of the new
    first half: [x1, x2] -> [-x2, x1].
    """
    x = rearrange(x, '... (j d) -> ... j d', j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
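
# Example (illustrative): the last dimension is split in half and the halves
# are swapped with a sign flip, e.g.
#
#   _rotate_half(torch.tensor([1., 2., 3., 4.]))  # tensor([-3., -4., 1., 2.])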
|
|
def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
    """
    Input tensor t is of shape [seq_length, ..., dim];
    rotary positional embedding tensors freqs_cos/freqs_sin are of shape [seq_length, ..., dim].
    See https://kexue.fm/archives/8265 for the detailed formulas.
    """
    rot_dim = freqs_cos.shape[-1]

    # Ideally t_pass is empty, so the rotary position embedding is applied to all of t.
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # The cosine term keeps the original components; the sine term mixes in the
    # sign-flipped, half-rotated components.
    t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin)

    res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1)
    return res
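
# Elementwise form (for reference): writing t = [a, b] as the two halves of the
# rotary dimensions, the update computes the planar rotation
#   a' = a * cos - b * sin
#   b' = b * cos + a * sin
# which is exactly `t * freqs_cos + _rotate_half(t) * freqs_sin`.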
|
|
def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1"

    if not (scatter_idx < 2):
        # Head-scatter direction: scatter heads across ranks, gather sequence.
        input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size)
        input = input.transpose(0, scatter_idx).contiguous()
        local_heads = input_splits[groups._get_sequence_parallel_rank()]
        output_splits = [local_heads] * seq_world_size

        output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:])
        output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype)
        dist.all_to_all_single(output, input, output_split_sizes=output_splits, input_split_sizes=input_splits,
                               group=group)
        # [seq_ws * local_heads, ...] -> [seq_ws, local_heads, ...]
        output = output.view(seq_world_size, local_heads, *output.shape[1:])

        if batch_dim_idx == 0:
            # [seq_ws, local_heads, seq_len, b, ...] -> [b, seq_ws, seq_len, local_heads, ...]
            order = [3, 0, 2, 1] + list(range(4, len(output.shape)))
            output = output.permute(order).contiguous()
            # -> [b, seq_ws * local_seq_len, local_heads, ...]
            output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size,
                                 *output.shape[3:]).contiguous()
        elif batch_dim_idx == 1:
            output = output.transpose(1, 3).contiguous()
            # -> [seq_ws * local_seq_len, b, local_heads, ...]
            output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous()
    else:
        # Compatibility handling for 4D and 3D tensors, standardizing on 3D.
        input = input.reshape(input.shape[0], input.shape[1], -1)

        if batch_dim_idx == 0:  # [b, s, h]
            input = input.permute(1, 2, 0).contiguous()  # -> [s, h, b]
        elif batch_dim_idx == 1:  # [s, b, h]
            input = input.transpose(1, 2).contiguous()  # -> [s, h, b]
        seq_len, h, batch_size = input.shape
        num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size)
        local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()]
        h_dim = h // local_heads
        local_seq_len = seq_len // seq_world_size

        input = input.view(seq_len * h, batch_size)
        local_seq_len_with_heads = int(input.shape[0] / seq_world_size)
        input_splits = [local_seq_len_with_heads] * seq_world_size
        coeff = local_seq_len_with_heads // local_heads  # per-head size: local_seq_len * h_dim

        # Uneven scale factor: total_heads / local_heads.
        heads_scale_coeff = get_num_kv_heads() / local_heads

        output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list]
        output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads)
        total_h = int(inp_shape[gather_idx] * heads_scale_coeff)
        output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype)
        dist.all_to_all_single(output, input, output_split_sizes=output_splits, input_split_sizes=input_splits,
                               group=group)

        # Suppose 7 heads are divided over 4 ranks as [2, 2, 2, 1]:
        #   chunk_num_heads_small = 7 // 4 = 1
        #   chunk_num_heads_large = chunk_num_heads_small + 1 = 2
        #   num_chunk_heads_large = 7 % 4 = 3, all other chunks are small
        #   num_chunk_heads_small = 4 - 3 = 1
        #   total_num_large_heads = 3 * 2 = 6, total_num_small_heads = 1 * 1 = 1
        chunk_num_heads_small = get_num_kv_heads() // seq_world_size  # even-heads compatible
        chunk_num_heads_large = chunk_num_heads_small + 1
        num_chunk_heads_large = get_num_kv_heads() % seq_world_size
        num_chunk_heads_small = seq_world_size - num_chunk_heads_large
        total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large
        total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small

        heads_large_combine_size = coeff * total_num_large_heads
        heads_small_combine_size = coeff * total_num_small_heads
        heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size],
                                                            dim=0)
        heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim,
                                                   batch_size)
        heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim,
                                                   batch_size)
        if batch_dim_idx == 0:
            # -> [b, s, h]
            order = [4, 1, 0, 2, 3]
            heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
                                                                                   total_num_large_heads, h_dim)
            heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
                                                                                   total_num_small_heads, h_dim)
        elif batch_dim_idx == 1:
            # -> [s, b, h]
            order = [1, 4, 0, 2, 3]
            heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
                                                                                   total_num_large_heads, h_dim)
            heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
                                                                                   total_num_small_heads, h_dim)

        output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous()

        inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
        output_shape = inp_shape[:gather_idx] + [total_h] + inp_shape[gather_idx + 1:]

        output = output.view(output_shape)

    return output
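
# Sharding sketch (illustrative): with 7 KV heads on a 4-rank group,
# get_shard_size_list(7, 4) yields [2, 2, 2, 1], so the even-heads fast path
# cannot be used. In the head-scatter direction (scatter_idx=2, gather_idx=0,
# batch_dim_idx=1) each rank trades heads for sequence:
#
#   ranks 0-2 input: [s/4, b, 7, d]  ->  output: [s, b, 2, d]
#   rank  3   input: [s/4, b, 7, d]  ->  output: [s, b, 1, d]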
|
|
def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)
    # The 4D input is either [b, s, h, d] or [s, b, h, d], so the head count is always dim 2.
    num_heads = input.shape[2]

    # Uneven-heads path: taken once the KV head count has been registered, or when
    # the heads do not divide evenly across ranks in the head-scatter direction.
    if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and not scatter_idx < 2):
        # Assume the number of heads for q is consistent with kv;
        # if not, additional logic is required for cases like GQA.
        if get_num_kv_heads() is None:
            assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
            # Register the head count on the first call; afterwards
            # ``get_num_kv_heads() is not None`` re-enters the uneven path.
            set_num_kv_heads(num_heads)
        assert not async_op, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
        scatter_idx, batch_dim_idx, seq_world_size, input)

    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
    output = torch.empty_like(input_t)
    work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

    if async_op:
        if type in ('dq', 'dk'):
            # Stash the handle so the backward prehook can wait on the comm
            # and apply the post-processing once the gradient is consumed.
            handle[type + '_work'] = work
            handle[type + '_grad'] = output
            handle[type + '_post_all2all_func'] = post_all2all_fun
            return output.view(post_all2all_res_shape)

    res = post_all2all_fun(output)
    return res
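
# Shape sketch (illustrative, even heads): DeepSpeed-Ulysses calls this with
# scatter_idx=2, gather_idx=0 for q/k/v and with the roles swapped for the
# attention output, e.g. with batch_dim_idx=1:
#
#   out = single_all_to_all(q, scatter_idx=2, gather_idx=0, batch_dim_idx=1, group=spg)
#   # q: [s/p, b, h, d] -> out: [s, b, h/p, d]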
|
|
class _DimZeroAllToAll(torch.autograd.Function):
    """Differentiable All2All across dimension 0."""

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:
        world_size = dist.get_world_size(group)
        assert input.shape[0] == world_size, f"Dim 0 {input.shape[0]} is not world size"

        ctx.group = group

        output = torch.empty_like(input).contiguous()

        dist.all_to_all_single(output, input.contiguous(), group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _DimZeroAllToAll.apply(ctx.group, *grad_output))
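
# Note: with equal per-rank chunks, all-to-all is its own inverse across a fixed
# group, so the backward pass is simply the same collective applied to the
# incoming gradients.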
|
|
class _SeqAllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any,
                group: dist.ProcessGroup,
                input: Tensor,
                scatter_idx: int,
                gather_idx: int,
                batch_dim_idx: int,
                stream=None,
                handle=None,
                type=None,
                is_fwd=True) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        ctx.stream = stream
        ctx.handle = handle
        ctx.type = type
        ctx.batch_dim_idx = batch_dim_idx
        if ctx.handle is None:
            res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
        else:
            # Communication-overlap path.
            if not is_fwd and type == 'o':
                assert ctx.stream is not None
                res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
                get_accelerator().current_stream().wait_stream(ctx.stream)

            elif not is_fwd and type in ('q', 'k'):
                # Gradients for q and k are communicated asynchronously so the
                # all2all can overlap with the remaining backward computation.
                type = 'd' + type
                res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type)

            elif is_fwd and type in ('q', 'k'):
                type = 'fwd_' + type
                res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type)

            else:
                res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)

        return res

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # In backward, scatter_idx and gather_idx swap roles so the gradient
        # all2all retraces the forward exchange.
        return (None,
                _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx,
                                   ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None)
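
# Usage sketch (illustrative): the autograd wrapper is invoked functionally; the
# trailing overlap arguments default to the non-overlapping path, e.g.
#
#   q_layer = _SeqAllToAll.apply(spg, q, 2, 0, 1)  # [s/p, b, h, d] -> [s, b, h/p, d]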
|
|
class DistributedAttention(torch.nn.Module):
    """Initialization.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 0,
        sp_stream=None,
    ) -> None:

        super(DistributedAttention, self).__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.sp_overlap_comm = False
        self.overlap_handles = None
        self.sp_stream = sp_stream
        if sp_stream is not None:
            self.overlap_handles = {}
            self.sp_overlap_comm = True
            self.default_stream = get_accelerator().default_stream()

    def layer_sync(self, layer):
        if self.sp_overlap_comm and hasattr(layer, 'done_event'):
            self.default_stream.wait_event(layer.done_event)
|
    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                batch_dim_idx: int,
                rotary_pos_emb=None,
                *args: Any,
                **kwargs) -> Tensor:
        """ forward

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            batch_dim_idx (int): indicates which dimension is the batch dimension
            rotary_pos_emb: optional (cos, sin) rotary position embedding tensors
            args: other args

        Returns:
            * output (Tensor): context output
        """

        def bwd_hook(layer_type):

            def pre_hook_fun(grad):
                type = 'd' + layer_type
                self.overlap_handles[type + '_work'].wait()
                self.sp_stream.wait_stream(self.default_stream)
                all2all_output = self.overlap_handles[type + '_grad']
                grad = list(grad)
                grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output)
                # Return the processed gradients so the prehook actually replaces
                # the incoming grad; otherwise the post-all2all result is dropped.
                return tuple(grad)

            return pre_hook_fun

        self.layer_sync(query)
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
                                         self.overlap_handles, 'q')
        self.layer_sync(key)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
                                       self.overlap_handles, 'k')
        if self.sp_overlap_comm:
            self.default_stream.wait_stream(self.sp_stream)

        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
                                         self.overlap_handles, 'v')

        if self.sp_overlap_comm:
            # Register backward hooks that synchronize dq and dk after their
            # asynchronous all2all; placed after the q/k all2all calls so the
            # forward communication is launched first.
            grad_fn_q = query.grad_fn.next_functions[0][0]
            grad_fn_q.register_prehook(bwd_hook(layer_type='q'))
            grad_fn_k = key.grad_fn.next_functions[0][0]
            grad_fn_k.register_prehook(bwd_hook(layer_type='k'))

        if rotary_pos_emb is not None:
            pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3)
            query_layer = apply_rotary_pos_emb(query_layer, pos_emb_cos, pos_emb_sin)
            key_layer = apply_rotary_pos_emb(key_layer, pos_emb_cos, pos_emb_sin)

        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx,
                                    self.sp_stream, self.overlap_handles, 'o')

        return output
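
# Minimal usage sketch (illustrative; `local_attn` and the group helper are
# assumed to exist in the surrounding training code):
#
#   spg = groups._get_sequence_parallel_group()
#   dist_attn = DistributedAttention(local_attn, spg, scatter_idx=2, gather_idx=0)
#   context = dist_attn(q, k, v, batch_dim_idx=1)  # q, k, v: [s/p, b, h, d]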