# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads


def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
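    """Split each tensor in `qkv_list` into chunks of `split_size` along `split_dim`,
    then re-fuse the i-th chunk of every tensor along `cat_dim`.

    Illustrative example: with qkv_list = [q, k, v], each of shape [4, h], and
    split_size = 2, the result is [cat(q[:2], k[:2], v[:2]), cat(q[2:], k[2:], v[2:])],
    i.e. one re-fused qkv slice per tensor-parallel shard.
    """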
    qkv_split_list = [torch.split(mat, split_size, dim=split_dim) for mat in qkv_list]
    tp_fusedqkv_list = [
        torch.cat([qkv_s[i] for qkv_s in qkv_split_list], dim=cat_dim) for i in range(len(qkv_split_list[0]))
    ]
    return tp_fusedqkv_list


def require_tp_fused_qkvw(name, mp_size):
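    """Return True if the parameter `name` refers to a fused QKV projection
    (e.g. 'query_key_value', 'qkv_proj') that must be split for tensor
    parallelism, i.e. when mp_size > 1."""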
    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn']

    if mp_size == 1:
        return False
    for fused_name in fused_qkvw_name_list:
        if fused_name in name:
            return True
    return False


def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):
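    """Return the tensor-parallel shard of the fused QKV weight/bias `src` that
    belongs to `gpu_index`, splitting according to the fused layout inferred from
    the module's class name (see fused_type_dict below). Unrecognized module types
    fall back to the bloom layout with a warning. Returns None when `src` is None."""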

    module_str = str(module).strip()
    if src is None:
        return
    fused_type_dict = {
        'CodeGenBlock': 'codegentype',
        'BloomBlock': 'bloomtype',
        'GLMBlock': 'glmtype',
        "MPTBlock": 'glmtype',
        "MptBlock": 'glmtype',
        "BaichuanLayer": 'glmtype',
        "QWenBlock": 'qwentype',
        "FalconDecoderLayer": 'bloomtype',
        "GPTBigCodeBlock": 'bigcodetype',
        "DecoderLayer": 'glmtype',
        "Phi3DecoderLayer": "phi3type"
    }

    def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
        # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
        assert get_num_kv_heads() % (
            mp_size * codegen_mp_num) == 0, "codegen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0"
        #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

        shape = input.shape
        dst_shape = get_shard_size(shape[0], mp_size)
        num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

        #num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
        src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
        src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

    def _glm_type_transpose(input, mp_size):
        #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

        # chatglm2 & chatglm3 (kv_heads=2) need special handling.
        if get_num_kv_heads() == 2:
            shape = input.shape
            hidden_dim = get_n_embd()
            kv_dim = (shape[0] - hidden_dim) // get_num_kv_heads()
            q = input[:hidden_dim]
            k = input[hidden_dim:hidden_dim + kv_dim]
            v = input[hidden_dim + kv_dim:]
            q_split = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
            k_split = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
            v_split = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
            return torch.cat((q_split[gpu_index], k_split[gpu_index], v_split[gpu_index]), dim=0)
        else:
            shape = input.shape
            src_split = torch.split(input, shape[0] // 3, dim=0)

            split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size))
            return split_fusedqkv[gpu_index]

    def _bloom_type_transpose(input, mp_size):
        shape = input.shape

        split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
        return split_fusedqkv[gpu_index]

    def _qwen_type_transpose(input, mp_size, module):
        if not hasattr(module, "_ds_fusedqkv_entered"):
            # On first entry, adjust the module's attn.split_size to the per-rank shard size
            setattr(module, "_ds_fusedqkv_entered", True)
            module.attn.split_size = get_shard_size(module.attn.split_size, mp_size)
        return _glm_type_transpose(input, mp_size)

    def _bigcode_type_transpose(input, mp_size):
        n_embd = get_n_embd()
        q = input[:n_embd]
        kv = input[n_embd:]
        shape = q.shape
        split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
        return torch.cat((split_q[gpu_index], kv), dim=0)

    def _phi3_type_transpose(input, mp_size):
        num_kv_heads = get_num_kv_heads()
        num_heads = get_num_attention_heads()
        hidden_size = input.shape[1]
        head_dim = hidden_size // num_heads
        q_pos = input.shape[0] - 2 * num_kv_heads * head_dim
        q = input[:q_pos]
        k = input[q_pos:q_pos + num_kv_heads * head_dim]
        v = input[q_pos + num_kv_heads * head_dim:]
        split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
        split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
        split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
        return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0)

    def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):

        # Suppose num_heads = n and q(i)_w is the i-th q head's linear weight; the fused weight formats are as follows:
        # bloomtype: [q(1)_w,k(1)_w,v(1)_w,q(2)_w,k(2)_w,v(2)_w,...,q(n)_w,k(n)_w,v(n)_w]
        # glmtype:  [q(1)_w, q(2)_w,...,q(n)_w,k(1)_w,k(2)_w,...,k(n)_w,v(1)_w,v(2)_w,...,v(n)_w]
        # codegentype: [q(1)_w,q(2)_w,...,q(n/t)_w,k(1)_w,k(2)_w,...,k(n/t)_w,v(1)_w,v(2)_w,...,v(n/t)_w,q(n/t+1)_w,...], where t is a constant defined in the model file.

        if fused_qkv_type == 'bloomtype':
            return _bloom_type_transpose(src, mp_size)
        elif fused_qkv_type == 'codegentype':
            return _codegen_type_transpose(src, mp_size)
        elif fused_qkv_type == 'glmtype':
            return _glm_type_transpose(src, mp_size)
        elif fused_qkv_type == 'qwentype':
            return _qwen_type_transpose(src, mp_size, module)
        elif fused_qkv_type == 'bigcodetype':
            return _bigcode_type_transpose(src, mp_size)
        elif fused_qkv_type == 'phi3type':
            return _phi3_type_transpose(src, mp_size)

        raise ValueError("unknown fused_qkv_type")

    module_name_matches = [k for k in fused_type_dict.keys() if k in module_str]
    if module_name_matches:
        # There can be overlap with matches (e.g., "DecoderLayer" and "FalconDecoderLayer").
        # We take the longest matching module_name
        module_name = max(module_name_matches, key=len)
        fused_type = fused_type_dict[module_name]
        return _transpose_fused_qkvw(src, mp_size, fused_type, module)
    warning_once("Unrecognized fused qkv weight type, defaulting to bloom type; "
                 "please check prepare_tp_fused_qkvw() to avoid potential calculation errors")
    return _bloom_type_transpose(src, mp_size)

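# Illustrative usage sketch (variable names below are placeholders, not part of this API):
#   if require_tp_fused_qkvw(param_name, mp_size):
#       local_shard = prepare_tp_fused_qkvw(child_module, full_qkv_weight, mp_size, rank)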

# For the shared qk type:
# q = [q_1,...,q_{n/4}, q_{n/2+1},...,q_{3n/4}, k_1,...,k_{n/4}, k_{n/2+1},...,k_{3n/4}]
# k = [q_{n/4+1},...,q_{n/2}, q_{3n/4+1},...,q_n, k_{n/4+1},...,k_{n/2}, k_{3n/4+1},...,k_n]
# To avoid modifying the modeling code, we adjust the value and o_proj weights to fit this qk layout.
def shard_value_with_share_qk(
        weight,
        bias,
        rank,
        world_size,
        shard_value=True  # True -> shard_value; False -> shard_oproj
):
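    """Shard the value projection (shard_value=True) or the attention output
    projection (shard_value=False) so that each rank keeps the value/output heads
    matching the q/k heads it owns under the shared-qk layout described above.

    For the value projection, the weight (and bias) rows of the selected heads are
    gathered; for the output projection, the matching weight columns are gathered
    and the bias is pre-divided by world_size so the tensor-parallel all-reduce
    adds it back exactly once.

    Illustrative example: with num_heads=8 and world_size=2, rank 0 selects heads
    [0, 1, 4, 5] and rank 1 selects [2, 3, 6, 7].
    """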
    if shard_value:
        total_size = weight.shape[0]
        weight_cat_dim = 0
    else:
        total_size = weight.shape[1]
        weight_cat_dim = 1
    num_heads = get_num_kv_heads()
    head_dim = total_size // num_heads
    assert (num_heads % world_size == 0)
    if world_size > num_heads // 2:
        raise RuntimeError(f"world_size {world_size} is larger than half of num_heads {num_heads}")
    head_per_rank = num_heads // world_size
    q_head_start = rank * head_per_rank
    # mapping q_head to v_head
    v_head_ids = []
    i = 0
    # mapping neighbor q_head to v_head
    while i < head_per_rank:
        v_head_ids.append(q_head_start // 2)
        q_head_start += 2
        i = i + 2

    # mapping neighbor k_head to v_head
    v_head_ids.extend([i + num_heads // 2 for i in v_head_ids])
    sharded_weight = []
    sharded_bias = []
    for head_id in v_head_ids:
        if shard_value:
            sharded_weight.append(weight[head_id * head_dim:(head_id + 1) * head_dim])
            if bias is not None:
                sharded_bias.append(bias.data[head_id * head_dim:(head_id + 1) * head_dim])
        else:
            sharded_weight.append(weight[:, head_id * head_dim:(head_id + 1) * head_dim])
    sharded_weight = torch.cat(sharded_weight, dim=weight_cat_dim)
    if bias is not None:
        if shard_value:
            bias = torch.cat(sharded_bias, dim=0)
        else:
            bias = bias / float(world_size)
        return torch.nn.Parameter(sharded_weight), torch.nn.Parameter(bias)
    else:
        return torch.nn.Parameter(sharded_weight), None


# For phi3 with a chunked (fused gate/up) MLP, adjust the weight order.
def shard_chunk_mlp(
    weight,
    bias,
    rank,
    world_size,
):
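    """Shard phi3's fused gate/up MLP projection: split the weight (and bias) into
    its gate and up halves, shard each half across ranks, and re-fuse the two local
    shards so every rank still sees a [gate; up] layout for the chunked MLP."""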
    weight_gate, weight_states = weight.chunk(2, dim=0)
    total_size = weight_gate.shape[0]
    split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
    split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
    shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0)
    if bias is not None:
        bias_gate, bias_states = bias.chunk(2, dim=0)
        split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
        split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
        return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0)

    return shard_weight, None