# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module
from einops import rearrange
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
from deepspeed.utils import groups
def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
"""
This function generates the parameters required for `permute` and `reshape` operations,
which are used to process data before and after `all2all` communication.
"""
if batch_dim_idx == 0:
if scatter_idx < 2:
bs, global_seq_len, num_local_head, head_dim = input.shape
pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
pre_all2all_permute_idx = (1, 0, 2, 3, 4)
post_all2all_permute_idx = (1, 2, 0, 3, 4)
post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
else:
bs, local_seq_len, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
post_all2all_permute_idx = (1, 0, 2, 3, 4)
post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
else:
if scatter_idx < 2:
global_seq_len, bs, num_local_head, head_dim = input.shape
pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
pre_all2all_permute_idx = None
post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [global_seq_len // seq_world_size, bs, seq_world_size * num_local_head, head_dim]
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
post_all2all_permute_idx = None
post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]
return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
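

# Illustrative sketch (not used by the library): walk through the generated
# layout params for a batch-first, scatter-heads case. With batch_dim_idx=0,
# scatter_idx=2 and seq_world_size=4, a [bs=2, local_seq_len=8,
# num_total_head=12, head_dim=64] tensor is reshaped to [2, 8, 4, 3, 64],
# permuted to [4, 2, 8, 3, 64] for all2all, and lands as [2, 32, 3, 64].
def _debug_layout_params(scatter_idx=2, batch_dim_idx=0, seq_world_size=4):
    sample = torch.empty(2, 8, 12, 64)  # only .shape is consumed below
    pre_permute, pre_shape, post_permute, post_shape = _generate_layout_params(
        scatter_idx, batch_dim_idx, seq_world_size, sample)
    print(f"pre : permute={pre_permute}, reshape to {pre_shape}")
    print(f"post: permute={post_permute}, reshape to {post_shape}")
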
def post_all2all(permute_idx, res_shape):
"""
Post-processing function for `all2all` communication.
"""
def post_func(input):
if permute_idx is not None:
input = input.permute(permute_idx).contiguous()
output = input.reshape(res_shape).contiguous()
return output
return post_func
def pre_all2all_fun(permute_idx, inp_shape, input):
"""
Pre-processing function for `all2all` communication.
"""
input_t = input.reshape(inp_shape).contiguous()
if permute_idx is not None:
input_t = input_t.permute(permute_idx).contiguous()
return input_t
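

# Local sanity sketch: with seq_world_size=1 the all2all exchange is an
# identity, so pre-processing followed by post-processing must reproduce the
# input exactly. A cheap single-process check of the bookkeeping above.
def _check_pre_post_roundtrip():
    x = torch.randn(2, 8, 12, 64)  # [bs, seq_len, heads, head_dim]
    pre_idx, pre_shape, post_idx, post_shape = _generate_layout_params(2, 0, 1, x)
    staged = pre_all2all_fun(pre_idx, pre_shape, x)  # tensor that would enter all2all
    restored = post_all2all(post_idx, post_shape)(staged)
    assert torch.equal(restored, x)
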
def _rotate_half(x):
"""
    rotate the two halves of the last dimension: [x1, x2] -> [-x2, x1]
"""
x = rearrange(x, '... (j d) -> ... j d', j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
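

# Worked example: for x = [1, 2, 3, 4] the halves are x1 = [1, 2] and
# x2 = [3, 4], so _rotate_half returns [-3, -4, 1, 2]:
#   _rotate_half(torch.tensor([1., 2., 3., 4.]))  # tensor([-3., -4., 1., 2.])
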
def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
"""
    input tensor t is of shape [seq_length, ..., dim]
    rotary positional embedding tensors freqs_cos and freqs_sin are of shape [seq_length, ..., dim]
check https://kexue.fm/archives/8265 for detailed formulas
"""
rot_dim = freqs_cos.shape[-1]
    # ideally t_pass is empty, so the rotary embedding is applied to all of t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin)
res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1)
return res
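

# Minimal sketch of how callers typically build the cos/sin inputs (assumed
# standard RoPE convention; the 10000.0 base and the broadcast dims are
# assumptions, not defined in this file). Frequencies are duplicated across
# the two halves to match the half-rotation in `_rotate_half`.
def _build_rotary_freqs(seq_len, rot_dim, base=10000.0):
    inv_freq = 1.0 / (base**(torch.arange(0, rot_dim, 2).float() / rot_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, rot_dim/2]
    angles = torch.cat((angles, angles), dim=-1)  # [seq_len, rot_dim]
    # add broadcast dims for [seq_len, batch, heads, rot_dim] inputs
    return angles.cos()[:, None, None, :], angles.sin()[:, None, None, :]
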
def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
seq_world_size = dist.get_world_size(group)
inp_shape = list(input.shape)
assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1"
if not (scatter_idx < 2):
input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size)
input = input.transpose(0, scatter_idx).contiguous()
local_heads = input_splits[groups._get_sequence_parallel_rank()]
output_splits = [local_heads] * seq_world_size
output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:])
output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype)
        dist.all_to_all_single(output, input, output_split_sizes=output_splits, \
                               input_split_sizes=input_splits, group=group)
###[seq_ws*local_heads, ...] to [seq_ws, local_heads, ...]
output = output.view(seq_world_size, local_heads, *output.shape[1:])
        ### batch_dim_idx=0: [seq_ws, local_heads, seq_len, b, ...] to [b, seq_ws, seq_len, local_heads, ...]
        ### batch_dim_idx=1: [seq_ws, local_heads, b, seq_len, ...] to [seq_ws, seq_len, b, local_heads, ...]
if batch_dim_idx == 0:
order = [3, 0, 2, 1] + list(range(4, len(output.shape)))
output = output.permute(order).contiguous()
###[b, seq_ws*local_seq_len, local_heads,...]
output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size,
*output.shape[3:]).contiguous()
elif batch_dim_idx == 1:
output = output.transpose(1, 3).contiguous()
###[seq_ws*local_seq_len, b, local_heads,...]
output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous()
else:
        # Handle both 4D and 3D tensors by standardizing to 3D.
input = input.reshape(input.shape[0], input.shape[1], -1)
if batch_dim_idx == 0: #b,s,h
input = input.permute(1, 2, 0).contiguous() #s,h,b
elif batch_dim_idx == 1: #s,b,h
input = input.transpose(1, 2).contiguous() #s,h,b
seq_len, h, batch_size = input.shape
num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size)
local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()]
h_dim = h // local_heads
local_seq_len = seq_len // seq_world_size
input = input.view(seq_len * h, batch_size)
local_seq_len_with_heads = int(input.shape[0] / seq_world_size) # dim size of local_seq_len*local_heads*hdim
input_splits = [local_seq_len_with_heads] * seq_world_size
coeff = local_seq_len_with_heads // local_heads #per head: dim size of local_seq_len*hdim
#uneven seq_world_size coeff, total_heads/local_heads.
heads_scale_coeff = get_num_kv_heads() / local_heads
output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list]
output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads)
total_h = int(inp_shape[gather_idx] * heads_scale_coeff)
output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype)
        dist.all_to_all_single(output, input, output_split_sizes=output_splits, \
                               input_split_sizes=input_splits, group=group)
##################
#suppose 7 heads divide into 4 ranks [2,2,2,1]
#chunk_num_heads_small=floor(7/4)=1
#chunk_num_heads_large=ceil(7/4)=2
#num_chunk_heads_large=len([2,2,2])=3, all2all_buffer_counts
#num_chunk_heads_small=len([1])=1, all2all_buffer_counts
        #total_num_large_heads=sum([2,2,2])=6
#total_num_small_heads=sum([1])=1
chunk_num_heads_small = get_num_kv_heads() // seq_world_size # even heads compatible
chunk_num_heads_large = chunk_num_heads_small + 1
num_chunk_heads_large = get_num_kv_heads() % seq_world_size
num_chunk_heads_small = seq_world_size - num_chunk_heads_large
total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large
total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small
heads_large_combine_size = coeff * total_num_large_heads
heads_small_combine_size = coeff * total_num_small_heads
heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size],
dim=0)
heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim,
batch_size)
heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim,
batch_size)
if batch_dim_idx == 0:
#[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[batch,local_seq_len,all2all_buffer_counts*n_heads,dim]
order = [4, 1, 0, 2, 3]
heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
total_num_large_heads, h_dim)
heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
total_num_small_heads, h_dim)
elif batch_dim_idx == 1:
#[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[local_seq_len,batch,all2all_buffer_counts*n_heads,dim]
order = [1, 4, 0, 2, 3]
heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
total_num_large_heads, h_dim)
heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
total_num_small_heads, h_dim)
output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous()
inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
        output_shape = inp_shape[:gather_idx] + [total_h] + inp_shape[gather_idx + 1:]
output = output.view(output_shape)
return output
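

# Pure-Python illustration of the chunk arithmetic above for the commented
# example (7 heads on 4 ranks, shard sizes [2, 2, 2, 1]); no communication.
def _uneven_chunk_example(num_heads=7, seq_world_size=4):
    chunk_small = num_heads // seq_world_size  # 1 head per small chunk
    chunk_large = chunk_small + 1  # 2 heads per large chunk
    n_large = num_heads % seq_world_size  # 3 large chunks
    n_small = seq_world_size - n_large  # 1 small chunk
    assert n_large * chunk_large + n_small * chunk_small == num_heads  # 3*2 + 1*1 == 7
    return (chunk_large, n_large), (chunk_small, n_small)
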
def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
seq_world_size = dist.get_world_size(group)
# we only need num_heads once
num_heads = input.shape[2]
if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and not scatter_idx < 2):
# Assuming here that the number of heads for q is consistent with kv
# If not, additional logic is required for cases like GQA
if get_num_kv_heads() is None:
assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
# set heads at first call by num_total_heads.
# then use ``get_num_kv_heads() is not None`` to re-entry uneven path.
set_num_kv_heads(num_heads)
        assert not async_op, "uneven head sp does not support async op"
return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
scatter_idx, batch_dim_idx, seq_world_size, input)
input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)
post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
output = torch.empty_like(input_t)
work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
if async_op:
if type in ('dq', 'dk'):
handle[type + '_work'] = work
handle[type + '_grad'] = output
handle[type + '_post_all2all_func'] = post_all2all_fun
return output.view(post_all2all_res_shape)
res = post_all2all_fun(output)
return res
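

# Shape summary for the even-heads fast path (batch-first, p = seq_world_size):
#   q/k/v all2all (scatter heads, scatter_idx >= 2): [b, s/p, h, d] -> [b, s, h/p, d]
#   output all2all (scatter seq, scatter_idx < 2):   [b, s, h/p, d] -> [b, s/p, h, d]
# Seq-first layouts (batch_dim_idx=1) follow the same pattern with b and s swapped.
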
class _DimZeroAllToAll(torch.autograd.Function):
"""Differentiable All2All across dimension 0."""
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:
world_size = dist.get_world_size(group)
        assert input.shape[0] == world_size, f"Dim 0 size {input.shape[0]} does not match world size {world_size}"
ctx.group = group
output = torch.empty_like(input).contiguous()
# torch.distributed.nn.functional.all_to_all_single(output, input.contiguous(), group=group)
dist.all_to_all_single(output, input.contiguous(), group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _DimZeroAllToAll.apply(ctx.group, *grad_output))
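

# Note: this all-to-all over equal dim-0 chunks is its own inverse, which is
# why the backward pass can reuse the same collective on the incoming gradient.
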
class _SeqAllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx: Any,
group: dist.ProcessGroup,
input: Tensor,
scatter_idx: int,
gather_idx: int,
batch_dim_idx: int,
stream=None,
handle=None,
type=None,
is_fwd=True) -> Tensor:
ctx.group = group
ctx.scatter_idx = scatter_idx
ctx.gather_idx = gather_idx
ctx.stream = stream
ctx.handle = handle
ctx.type = type
ctx.batch_dim_idx = batch_dim_idx
if ctx.handle is None:
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
else:
# overlap communication path
if not is_fwd and type == 'o':
                assert ctx.stream is not None
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
get_accelerator().current_stream().wait_stream(ctx.stream)
# The computation of d o_weight can overlap with the communication of d o_input
elif not is_fwd and type in ('q', 'k'):
# Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv
type = 'd' + type
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type)
elif is_fwd and type in ('q', 'k'):
# Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v
type = 'fwd_' + type
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type)
else:
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
return res
@staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None, None, None, None, None]:
return (None,
_SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx,
ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None)
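

# Note: backward re-applies _SeqAllToAll with scatter_idx and gather_idx
# swapped, since the adjoint of a scatter-heads/gather-seq exchange is the
# reverse exchange; the trailing Nones line up with the non-tensor inputs to
# forward (scatter_idx, gather_idx, batch_dim_idx, stream, handle, type, is_fwd).
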
class DistributedAttention(torch.nn.Module):
"""Initialization.
Arguments:
local_attention (Module): local attention with q,k,v
sequence_process_group (ProcessGroup): sequence parallel process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
local_attention: Module,
sequence_process_group: dist.ProcessGroup,
scatter_idx: int = 2,
gather_idx: int = 0,
sp_stream=None,
) -> None:
super(DistributedAttention, self).__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
self.sp_overlap_comm = False
self.overlap_handles = None
self.sp_stream = sp_stream
if sp_stream is not None:
self.overlap_handles = {}
self.sp_overlap_comm = True
self.default_stream = get_accelerator().default_stream()
def layer_sync(self, layer):
if self.sp_overlap_comm and hasattr(layer, 'done_event'):
self.default_stream.wait_event(layer.done_event)
def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
batch_dim_idx: int,
rotary_pos_emb=None,
*args: Any,
**kwargs) -> Tensor:
""" forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
batch_dim_idx (int): indicating which dim is batch
args: other args
Returns:
* output (Tensor): context output
"""
# TODO Merge three alltoall calls into one
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
#in shape : e.g., [s/p:h:]
def bwd_hook(layer_type):
def pre_hook_fun(grad):
type = 'd' + layer_type
self.overlap_handles[type + '_work'].wait()
self.sp_stream.wait_stream(self.default_stream)
all2all_output = self.overlap_handles[type + '_grad']
grad = list(grad)
grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output)
                grad = tuple(grad)
                # return the rebuilt tuple so the pre-hook actually replaces the gradient
                return grad
return pre_hook_fun
self.layer_sync(query)
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'q')
self.layer_sync(key)
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'k')
if self.sp_overlap_comm:
self.default_stream.wait_stream(self.sp_stream)
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'v')
if self.sp_overlap_comm:
            # Register pre-hooks so that the dq/dk consumers wait on the async
            # all-to-all and pick up the post-processed gradient.
            # This is done after the q, k, v all-to-all calls so the Python
            # interpreter can launch the forward communication first.
grad_fn_q = query.grad_fn.next_functions[0][0]
grad_fn_q.register_prehook(bwd_hook(layer_type='q'))
grad_fn_k = key.grad_fn.next_functions[0][0]
grad_fn_k.register_prehook(bwd_hook(layer_type='k'))
#out shape : e.g., [s:h/p:]
if rotary_pos_emb is not None:
pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3)
query_layer = apply_rotary_pos_emb(query_layer, pos_emb_cos, pos_emb_sin)
key_layer = apply_rotary_pos_emb(key_layer, pos_emb_cos, pos_emb_sin)
context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx,
self.sp_stream, self.overlap_handles, 'o')
#out e.g., [s/p::h]
return output
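

# Usage sketch (assumptions flagged): `core_attn` is any module with the local
# attention signature used above, and a sequence-parallel process group has
# already been set up (groups._get_sequence_parallel_group() is assumed to
# return it). Layout here is seq-first, i.e. batch_dim_idx=1 with the default
# scatter_idx=2 / gather_idx=0:
#
#   sp_group = groups._get_sequence_parallel_group()
#   dist_attn = DistributedAttention(core_attn, sp_group)
#   # q, k, v: [local_seq_len, batch, num_heads, head_dim]
#   out = dist_attn(q, k, v, batch_dim_idx=1)
#   # out: e.g., [local_seq_len, batch, num_heads, head_dim], heads re-sharded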