File size: 20,516 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch

from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module

from einops import rearrange

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
from deepspeed.utils import groups


def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
    """
    This function generates the parameters required for `permute` and `reshape` operations,
    which are used to process data before and after `all2all` communication.
    """
    if batch_dim_idx == 0:
        if scatter_idx < 2:
            bs, global_seq_len, num_local_head, head_dim = input.shape
            pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
            pre_all2all_permute_idx = (1, 0, 2, 3, 4)

            post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
        else:
            bs, local_seq_len, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
            pre_all2all_permute_idx = (2, 0, 1, 3, 4)

            post_all2all_permute_idx = (1, 0, 2, 3, 4)
            post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
    else:
        if scatter_idx < 2:
            global_seq_len, bs, num_local_head, head_dim = input.shape
            pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
            pre_all2all_permute_idx = None

            post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
        else:
            local_seq_len, bs, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
            post_all2all_permute_idx = None
            post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]

    return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape


def post_all2all(permute_idx, res_shape):
    """
    Build the post-processing callable for `all2all` communication.

    The returned function optionally permutes its argument with `permute_idx`
    (skipped when `permute_idx` is None) and then reshapes it to `res_shape`,
    returning a contiguous tensor.
    """
    needs_permute = permute_idx is not None

    def post_func(input):
        # Restore the dimension order first, then collapse to the final shape.
        reordered = input.permute(permute_idx).contiguous() if needs_permute else input
        return reordered.reshape(res_shape).contiguous()

    return post_func


def pre_all2all_fun(permute_idx, inp_shape, input):
    """
    Prepare `input` for `all2all` communication.

    The tensor is reshaped to `inp_shape` and, unless `permute_idx` is None,
    permuted with `permute_idx`. The result is contiguous.
    """
    reshaped = input.reshape(inp_shape).contiguous()
    if permute_idx is None:
        return reshaped
    return reshaped.permute(permute_idx).contiguous()


def _rotate_half(x):
    """
    Split the last dimension into two equal halves (first, second) and
    return them concatenated as [-second, first] along the last dimension.
    """
    halves = rearrange(x, '... (j d) -> ... j d', j=2)
    first_half, second_half = halves.unbind(dim=-2)
    rotated = (-second_half, first_half)
    return torch.cat(rotated, dim=-1)


def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
    """
    Apply rotary positional embedding to the leading channels of `t`.

    input tensor t is of shape [seq_length, ..., dim]
    freqs_cos / freqs_sin are the cosine / sine tables of shape [seq_length, ..., rot_dim]
    check https://kexue.fm/archives/8265 for detailed formulas

    Only the first `rot_dim = freqs_cos.shape[-1]` channels of the last
    dimension are rotated; any remaining channels are passed through unchanged.
    """
    rot_dim = freqs_cos.shape[-1]
    rotated_part = t[..., :rot_dim]
    # ideally this slice is empty, i.e. the embedding covers the whole last dimension
    passthrough = t[..., rot_dim:]

    # cosine term on the tensor itself, sine term on its half-rotated counterpart
    cos_term = rotated_part * freqs_cos
    sin_term = _rotate_half(rotated_part) * freqs_sin
    rotated_part = cos_term + sin_term

    if passthrough.shape[-1] == 0:
        return rotated_part
    return torch.cat((rotated_part, passthrough), dim=-1)


def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
    """
    All-to-all for head counts that are not split evenly across the sequence
    parallel group.

    Per-rank head counts come from `get_shard_size_list`, and explicit
    input/output split sizes are handed to `dist.all_to_all_single`.

    Arguments:
        input (Tensor): tensor to redistribute
        scatter_idx (int): dimension that is scattered; `>= 2` takes the
            head-scatter branch, `< 2` takes the sequence-scatter branch
        gather_idx (int): dimension that is gathered
        batch_dim_idx (int): position of the batch dimension, must be 0 or 1
        group: process group used for the collective

    Returns:
        Tensor: the redistributed tensor.

    Raises:
        AssertionError: if `batch_dim_idx` is neither 0 nor 1.
    """
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1"

    if not (scatter_idx < 2):
        # Scatter heads / gather sequence.
        # presumably get_shard_size_list returns one head count per rank -- verify against tp_shard
        input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size)
        # Move the scattered (head) dimension to the front so the collective splits along dim 0.
        input = input.transpose(0, scatter_idx).contiguous()
        # Every peer sends this rank the same number of heads: the ones this rank owns.
        local_heads = input_splits[groups._get_sequence_parallel_rank()]
        output_splits = [local_heads] * seq_world_size

        output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:])
        output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype)
        dist.all_to_all_single(output,input,output_split_sizes=output_splits,\
            input_split_sizes=input_splits,group=group)
        # [seq_ws*local_heads, ...] to [seq_ws, local_heads, ...]
        output = output.view(seq_world_size, local_heads, *output.shape[1:])

        # batch_dim_idx=0 [seq_ws,local_heads,seq_len,b,...] to [b, seq_ws, seq_len, local_heads ...]
        # batch_dim_idx=1 [seq_ws,local_heads,b,seq_len,...] to [seq_ws,seq_len,b,local_heads,...]
        if batch_dim_idx == 0:
            order = [3, 0, 2, 1] + list(range(4, len(output.shape)))
            output = output.permute(order).contiguous()
            # merge the rank and local-sequence dimensions: [b, seq_ws*local_seq_len, local_heads,...]
            output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size,
                                 *output.shape[3:]).contiguous()
        elif batch_dim_idx == 1:
            output = output.transpose(1, 3).contiguous()
            # merge the rank and local-sequence dimensions: [seq_ws*local_seq_len, b, local_heads,...]
            output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous()
    else:
        # Scatter sequence / gather heads.
        # The compatibility handling of 4D and 3D tensors, standardizing to 3D.
        input = input.reshape(input.shape[0], input.shape[1], -1)

        if batch_dim_idx == 0:  #b,s,h
            input = input.permute(1, 2, 0).contiguous()  #s,h,b
        elif batch_dim_idx == 1:  #s,b,h
            input = input.transpose(1, 2).contiguous()  #s,h,b
        seq_len, h, batch_size = input.shape
        # The total head count is read from get_num_kv_heads(); presumably the caller
        # has already recorded it via set_num_kv_heads -- verify against caller.
        num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size)
        local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()]
        h_dim = h // local_heads
        local_seq_len = seq_len // seq_world_size

        # Flatten sequence and hidden dimensions so the collective can split along dim 0.
        input = input.view(seq_len * h, batch_size)
        local_seq_len_with_heads = int(input.shape[0] / seq_world_size)  # dim size of local_seq_len*local_heads*hdim
        # Each peer receives an equally sized slice of the sequence.
        input_splits = [local_seq_len_with_heads] * seq_world_size
        coeff = local_seq_len_with_heads // local_heads  #per head: dim size of local_seq_len*hdim

        # ratio total_heads / local_heads; not an integer when the heads are split unevenly
        # NOTE(review): this is a float, and the two `int(...)` products below truncate;
        # an integer form such as `x * total_heads // local_heads` would avoid any
        # dependence on floating-point rounding -- confirm and consider switching.
        heads_scale_coeff = get_num_kv_heads() / local_heads

        # Each peer sends its own number of heads, so the received chunks differ in size.
        output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list]
        output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads)
        total_h = int(inp_shape[gather_idx] * heads_scale_coeff)
        output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype)
        dist.all_to_all_single(output,input,output_split_sizes=output_splits, \
            input_split_sizes=input_splits,group=group)
        ##################
        # Worked example: 7 heads divided over 4 ranks -> [2,2,2,1]
        #chunk_num_heads_small=floor(7/4)=1
        #chunk_num_heads_large=ceil(7/4)=2
        #num_chunk_heads_large=len([2,2,2])=3, all2all_buffer_counts
        #num_chunk_heads_small=len([1])=1, all2all_buffer_counts
        #total_num_large_heads=sum([2,2,2])=6
        #total_num_small_heads=sum([1])=1

        chunk_num_heads_small = get_num_kv_heads() // seq_world_size  # even heads compatible
        chunk_num_heads_large = chunk_num_heads_small + 1
        num_chunk_heads_large = get_num_kv_heads() % seq_world_size
        num_chunk_heads_small = seq_world_size - num_chunk_heads_large
        total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large
        total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small

        # The receive buffer holds all "large" chunks first, followed by all "small" chunks.
        heads_large_combine_size = coeff * total_num_large_heads
        heads_small_combine_size = coeff * total_num_small_heads
        heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size],
                                                            dim=0)
        heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim,
                                                   batch_size)
        heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim,
                                                   batch_size)
        if batch_dim_idx == 0:
            #[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[batch,local_seq_len,all2all_buffer_counts*n_heads,dim]
            order = [4, 1, 0, 2, 3]
            heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
                                                                                   total_num_large_heads, h_dim)
            heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
                                                                                   total_num_small_heads, h_dim)
        elif batch_dim_idx == 1:
            #[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[local_seq_len,batch,all2all_buffer_counts*n_heads,dim]
            order = [1, 4, 0, 2, 3]
            heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
                                                                                   total_num_large_heads, h_dim)
            heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
                                                                                   total_num_small_heads, h_dim)

        # Join the large and small groups along the head dimension.
        output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous()

        # Final shape: the original shape with the scattered dimension divided by the
        # group size and the gathered dimension replaced by total_h.
        inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
        output_shape=  inp_shape[: gather_idx] + \
            [total_h,] + \
            inp_shape[gather_idx + 1:]

        output = output.view(output_shape)

    return output


def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    """
    Run one sequence-parallel all-to-all on `input`.

    Arguments:
        input (Tensor): tensor to redistribute; `input.shape[2]` is taken as the number of heads
        scatter_idx (int): dimension that is scattered across the group
        gather_idx (int): dimension that is gathered from the group
        batch_dim_idx (int): position of the batch dimension
        group: process group used for the collective
        async_op (bool): launch the collective asynchronously
        handle (dict): storage for the deferred state of an asynchronous
            'dq'/'dk' exchange; entries `<type>_work`, `<type>_grad` and
            `<type>_post_all2all_func` are written
        type (str): tag of the tensor being exchanged

    Returns:
        Tensor: the redistributed tensor. For an asynchronous 'dq'/'dk'
        exchange the return value is only a view of the receive buffer with
        the final shape; the communication may still be in flight and the
        consumer is expected to wait on `handle[type + '_work']` and apply
        `handle[type + '_post_all2all_func']` to `handle[type + '_grad']`.

    Raises:
        AssertionError: if the uneven-head path is taken with `async_op`, or
        if the number of heads is not larger than the group size on that path.
    """
    seq_world_size = dist.get_world_size(group)
    # we only need num_heads once
    num_heads = input.shape[2]

    if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and not scatter_idx < 2):
        # Assuming here that the number of heads for q is consistent with kv
        # If not, additional logic is required for cases like GQA
        if get_num_kv_heads() is None:
            assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
            # set heads at first call by num_total_heads.
            # then use ``get_num_kv_heads() is not None`` to re-entry uneven path.
            set_num_kv_heads(num_heads)
        assert not async_op, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
        scatter_idx, batch_dim_idx, seq_world_size, input)

    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
    output = torch.empty_like(input_t)
    work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

    if async_op:
        if type in ('dq', 'dk'):
            # Defer completion: stash everything the consumer needs to finish
            # the exchange later and hand back a correctly shaped placeholder.
            handle[type + '_work'] = work
            handle[type + '_grad'] = output
            handle[type + '_post_all2all_func'] = post_all2all_fun
            return output.view(post_all2all_res_shape)
        # No deferred consumer exists for any other tag, so the receive buffer
        # must be complete before it is read by the post-processing below.
        if work is not None:
            work.wait()

    return post_all2all_fun(output)


class _DimZeroAllToAll(torch.autograd.Function):
    """Autograd-aware all-to-all that exchanges slices along dimension 0."""

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:
        """Exchange `input` across `group`; dim 0 must equal the group's world size."""
        expected_slices = dist.get_world_size(group)
        assert input.shape[0] == expected_slices, f"Dim 0 {input.shape[0]} is not world size"

        # Kept so that backward can run the same exchange on the gradient.
        ctx.group = group

        received = torch.empty_like(input).contiguous()
        dist.all_to_all_single(received, input.contiguous(), group=group)
        return received

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        """Route the gradient through the same all-to-all; `group` gets no gradient."""
        grad_input = _DimZeroAllToAll.apply(ctx.group, *grad_output)
        return (None, grad_input)


class _SeqAllToAll(torch.autograd.Function):
    """Differentiable sequence-parallel all-to-all built on `single_all_to_all`.

    The backward pass runs the same exchange on the incoming gradient with
    `scatter_idx` and `gather_idx` swapped.
    """

    @staticmethod
    def forward(ctx: Any,
                group: dist.ProcessGroup,
                input: Tensor,
                scatter_idx: int,
                gather_idx: int,
                batch_dim_idx: int,
                stream=None,
                handle=None,
                type=None,
                is_fwd=True) -> Tensor:
        """
        Arguments:
            group (ProcessGroup): process group used for the collective
            input (Tensor): tensor to redistribute
            scatter_idx (int): dimension that is scattered
            gather_idx (int): dimension that is gathered
            batch_dim_idx (int): position of the batch dimension
            stream: stream the current stream waits on after the backward
                exchange of the 'o' tensor (overlap path only)
            handle (dict): overlap bookkeeping; `None` disables the overlap path
            type (str): tag of the tensor ('q', 'k', 'v' or 'o' in this file)
            is_fwd (bool): False when invoked from `backward`

        Returns:
            Tensor: the redistributed tensor (a placeholder view for the
            asynchronous backward exchange of 'q'/'k', see `single_all_to_all`).
        """
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        ctx.stream = stream
        ctx.handle = handle
        ctx.type = type
        ctx.batch_dim_idx = batch_dim_idx

        if handle is None:
            # No overlap requested: plain blocking exchange.
            return single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)

        # overlap communication path
        if not is_fwd and type == 'o':
            assert ctx.stream is not None, "a stream is required for the backward exchange of 'o'"
            res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
            get_accelerator().current_stream().wait_stream(ctx.stream)
            # The computation of d o_weight can overlap with the communication of d o_input
            return res

        if type in ('q', 'k'):
            if is_fwd:
                # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v
                return single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle,
                                         'fwd_' + type)
            # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv
            return single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, 'd' + type)

        return single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Any, ...]:
        """Return one entry per forward argument; only `input` receives a gradient."""
        grad_input = _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx,
                                        ctx.stream, ctx.handle, ctx.type, False)
        # forward takes 9 arguments after ctx: group, input and 7 non-tensor settings.
        return (None, grad_input, None, None, None, None, None, None, None)


class DistributedAttention(torch.nn.Module):
    """Attention wrapper that redistributes q, k and v with all-to-all
    communication over a sequence parallel group, runs a local attention
    module, and redistributes the result back.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
        sp_stream: optional stream; when it is not None the communication
            overlap path is enabled
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 0,
        sp_stream=None,
    ) -> None:

        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.sp_overlap_comm = False
        self.overlap_handles = None
        self.sp_stream = sp_stream
        if sp_stream is not None:
            # Overlap mode: the dict carries the state of in-flight dq/dk exchanges.
            self.overlap_handles = {}
            self.sp_overlap_comm = True
            self.default_stream = get_accelerator().default_stream()

    def layer_sync(self, layer):
        """In overlap mode, make the default stream wait for `layer.done_event` if `layer` has one."""
        if self.sp_overlap_comm and hasattr(layer, 'done_event'):
            self.default_stream.wait_event(layer.done_event)

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                batch_dim_idx: int,
                rotary_pos_emb=None,
                *args: Any,
                **kwargs) -> Tensor:
        """ forward

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            batch_dim_idx (int): indicating which dim is batch
            rotary_pos_emb: optional pair (cos, sin); each entry is permuted with
                (1, 0, 2, 3) and applied to the redistributed query and key
            args: other args
            kwargs: other keyword args, forwarded to the local attention

        Returns:
            * output (Tensor): context output
        """

        # TODO Merge three alltoall calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
        #in shape : e.g.,  [s/p:h:]

        def bwd_hook(layer_type):

            def pre_hook_fun(grad):
                handle_key = 'd' + layer_type
                # Block until the asynchronous dq/dk all-to-all has finished.
                self.overlap_handles[handle_key + '_work'].wait()
                self.sp_stream.wait_stream(self.default_stream)
                all2all_output = self.overlap_handles[handle_key + '_grad']
                # NOTE(review): the post-processed gradient below is only bound to a
                # local name and the hook returns None, so autograd keeps the original
                # gradient. Returning the new tuple would make it take effect, but
                # autograd rejects a replacement with a different size -- confirm the
                # intended behavior before changing this.
                grad = list(grad)
                grad[0] = self.overlap_handles[handle_key + '_post_all2all_func'](all2all_output)
                grad = tuple(grad)

            return pre_hook_fun

        self.layer_sync(query)
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
                                         self.overlap_handles, 'q')
        self.layer_sync(key)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
                                       self.overlap_handles, 'k')
        if self.sp_overlap_comm:
            self.default_stream.wait_stream(self.sp_stream)

        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
                                         self.overlap_handles, 'v')

        if self.sp_overlap_comm:
            # Register a hook to synchronize dq and dk after the all-to-all
            # operation when the gradient data is used.
            # This is done after the q, k, v all-to-all calls so that the
            # forward communication is launched as early as possible.
            for tensor, layer_type in ((query, 'q'), (key, 'k')):
                # Without an autograd graph (e.g. under torch.no_grad(), or for
                # an input that does not require grad) no backward exchange will
                # run for this tensor, so there is nothing to hook; `grad_fn`
                # would be None in that case.
                if not (torch.is_grad_enabled() and tensor.requires_grad):
                    continue
                upstream_fn = tensor.grad_fn.next_functions[0][0]
                upstream_fn.register_prehook(bwd_hook(layer_type=layer_type))

        #out shape : e.g., [s:h/p:]
        if rotary_pos_emb is not None:
            pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3)
            query_layer = apply_rotary_pos_emb(query_layer, pos_emb_cos, pos_emb_sin)
            key_layer = apply_rotary_pos_emb(key_layer, pos_emb_cos, pos_emb_sin)

        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

        # Reverse exchange: scatter and gather indices are swapped.
        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx,
                                    self.sp_stream, self.overlap_handles, 'o')

        #out e.g., [s/p::h]
        return output