# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads
def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
    qkv_split_list = [torch.split(mat, split_size, dim=split_dim) for mat in qkv_list]
    tp_fusedqkv_list = [
        torch.cat([qkv_s[i] for qkv_s in qkv_split_list], dim=cat_dim) for i in range(len(qkv_split_list[0]))
    ]
    return tp_fusedqkv_list
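
# Illustrative sketch (hypothetical shapes, shown as comments rather than executable code):
#     q = k = v = torch.randn(6, 16)
#     shards = split_by_qkvlist_and_refuse([q, k, v], 3)
#     # len(shards) == 2 and shards[0].shape == (9, 16): the i-th entry re-fuses the
#     # i-th q/k/v chunks, i.e. the fused qkv slice kept by tensor-parallel rank i.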


def require_tp_fused_qkvw(name, mp_size):
    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn']

    if mp_size == 1:
        return False

    for fused_name in fused_qkvw_name_list:
        if fused_name in name:
            return True
    return False
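
# Usage sketch (hypothetical parameter name, not from the original file):
#     require_tp_fused_qkvw("transformer.h.0.attn.c_attn.weight", 2)  # True  ('c_attn' matches)
#     require_tp_fused_qkvw("transformer.h.0.attn.c_attn.weight", 1)  # False (no TP split needed)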


def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):

    module_str = str(module).strip()
    if src is None:
        return
    fused_type_dict = {
        'CodeGenBlock': 'codegentype',
        'BloomBlock': 'bloomtype',
        'GLMBlock': 'glmtype',
        "MPTBlock": 'glmtype',
        "MptBlock": 'glmtype',
        "BaichuanLayer": 'glmtype',
        "QWenBlock": 'qwentype',
        "FalconDecoderLayer": 'bloomtype',
        "GPTBigCodeBlock": 'bigcodetype',
        "DecoderLayer": 'glmtype',
        "Phi3DecoderLayer": "phi3type"
    }

    def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
        # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
        assert get_num_kv_heads() % (
            mp_size * codegen_mp_num) == 0, "codegen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0"
        # input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)
        shape = input.shape
        dst_shape = get_shard_size(shape[0], mp_size)
        num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

        # num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
        src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
        src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
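
    # Roughly: CodeGen stores the fused qkv weight as codegen_mp_num interleaved blocks, each
    # holding q, k and v for a subset of heads; the transform above regroups those blocks so
    # that every rank keeps one contiguous slice of size get_shard_size(shape[0], mp_size)
    # containing its own q, k and v rows.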

    def _glm_type_transpose(input, mp_size):
        # input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

        # chatglm2 & chatglm3 (kv_heads=2) need special handling.
        if get_num_kv_heads() == 2:
            shape = input.shape
            hidden_dim = get_n_embd()
            kv_dim = (shape[0] - hidden_dim) // get_num_kv_heads()
            q = input[:hidden_dim]
            k = input[hidden_dim:hidden_dim + kv_dim]
            v = input[hidden_dim + kv_dim:]
            q_split = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
            k_split = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
            v_split = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)

            return torch.cat((q_split[gpu_index], k_split[gpu_index], v_split[gpu_index]), dim=0)
        else:
            shape = input.shape
            src_split = torch.split(input, shape[0] // 3, dim=0)

            split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size))
            return split_fusedqkv[gpu_index]
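
    # Example (hypothetical sizes): with a glm-type weight stacked as [Wq; Wk; Wv] and
    # mp_size=2, each of Wq, Wk and Wv is split in half along dim 0 and rank 0 keeps
    # [Wq_half0; Wk_half0; Wv_half0], so the fused layout is preserved per rank.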

    def _bloom_type_transpose(input, mp_size):
        shape = input.shape

        split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
        return split_fusedqkv[gpu_index]

    def _qwen_type_transpose(input, mp_size, module):
        if not hasattr(module, "_ds_fusedqkv_entered"):
            # Adjust the absolute split size once so it matches the TP shard size.
            setattr(module, "_ds_fusedqkv_entered", True)
            module.attn.split_size = get_shard_size(module.attn.split_size, mp_size)
        return _glm_type_transpose(input, mp_size)

    def _bigcode_type_transpose(input, mp_size):
        n_embd = get_n_embd()
        q = input[:n_embd]
        kv = input[n_embd:]
        shape = q.shape
        split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
        return torch.cat((split_q[gpu_index], kv), dim=0)
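
    # GPTBigCode uses multi-query attention: only the query block (the first n_embd rows)
    # is sharded across ranks, while the single shared k/v block is replicated on every
    # rank and concatenated back after the query shard.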

    def _phi3_type_transpose(input, mp_size):
        num_kv_heads = get_num_kv_heads()
        num_heads = get_num_attention_heads()
        hidden_size = input.shape[1]
        head_dim = hidden_size // num_heads
        q_pos = input.shape[0] - 2 * num_kv_heads * head_dim
        q = input[:q_pos]
        k = input[q_pos:q_pos + num_kv_heads * head_dim]
        v = input[q_pos + num_kv_heads * head_dim:]
        split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
        split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
        split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
        return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0)
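
    # Phi-3 packs qkv_proj as [Wq; Wk; Wv], where Wq has num_heads * head_dim rows and Wk, Wv
    # each have num_kv_heads * head_dim rows (grouped-query attention); q, k and v are located
    # by size and sharded independently before being re-fused for this rank.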

    def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):

        # Suppose num_heads=n and q(i)_w is the i-th q head's linear weight; the weight layouts are as follows:
        # bloomtype: [q(1)_w,k(1)_w,v(1)_w,q(2)_w,k(2)_w,v(2)_w,...,q(n)_w,k(n)_w,v(n)_w]
        # glmtype: [q(1)_w,q(2)_w,...,q(n)_w,k(1)_w,k(2)_w,...,k(n)_w,v(1)_w,v(2)_w,...,v(n)_w]
        # codegentype: [q(1)_w,q(2)_w,...,q(n/t)_w,k(1)_w,k(2)_w,...,k(n/t)_w,v(1)_w,v(2)_w,...,v(n/t)_w,q(n/t+1)_w,...], where t is a constant defined in the model file.
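
        # Concrete sketch (n=2 heads, head weights q1,q2,k1,k2,v1,v2): bloomtype stacks them as
        # [q1,k1,v1,q2,k2,v2], so a contiguous split already yields a valid fused shard per rank,
        # while glmtype stacks [q1,q2,k1,k2,v1,v2] and must first be split into q/k/v before each
        # rank's pieces are re-fused.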

        if fused_qkv_type == 'bloomtype':
            return _bloom_type_transpose(src, mp_size)
        elif fused_qkv_type == 'codegentype':
            return _codegen_type_transpose(src, mp_size)
        elif fused_qkv_type == 'glmtype':
            return _glm_type_transpose(src, mp_size)
        elif fused_qkv_type == 'qwentype':
            return _qwen_type_transpose(src, mp_size, module)
        elif fused_qkv_type == 'bigcodetype':
            return _bigcode_type_transpose(src, mp_size)
        elif fused_qkv_type == 'phi3type':
            return _phi3_type_transpose(src, mp_size)

        raise ValueError("unknown fused_qkv_type")

    module_name_matches = [k for k in fused_type_dict.keys() if k in module_str]
    if module_name_matches:
        # Matches can overlap (e.g., "DecoderLayer" and "FalconDecoderLayer"),
        # so we take the longest matching module_name.
        module_name = max(module_name_matches, key=len)
        fused_type = fused_type_dict[module_name]
        return _transpose_fused_qkvw(src, mp_size, fused_type, module)
    warning_once("Unrecognized fused qkv weight type, defaulting to bloom type; "
                 "please check prepare_tp_fused_qkvw() to avoid potential calculation errors")
    return _bloom_type_transpose(src, mp_size)


# For the share-qk layout:
# q = [q1,...,q_{n/4}, q_{n/2+1},...,q_{3n/4}, k1,...,k_{n/4}, k_{n/2+1},...,k_{3n/4}]
# k = [q_{n/4+1},...,q_{n/2}, q_{3n/4+1},...,q_n, k_{n/4+1},...,k_{n/2}, k_{3n/4+1},...,k_n]
# To avoid modifying the modeling code, we adjust the value and o_proj weights to fit this qk layout.
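# In shard_value_with_share_qk() below, each pair of neighbouring q heads maps onto one value
# head, and the matching k heads map onto the value heads offset by num_heads // 2, so every
# rank re-gathers exactly the value (or o_proj) head slices its shared q/k heads need.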
def shard_value_with_share_qk(
        weight,
        bias,
        rank,
        world_size,
        shard_value=True  # True -> shard the value weight; False -> shard the o_proj weight
):
    if shard_value:
        total_size = weight.shape[0]
        weight_cat_dim = 0
    else:
        total_size = weight.shape[1]
        weight_cat_dim = 1
    num_heads = get_num_kv_heads()
    head_dim = total_size // num_heads
    assert num_heads % world_size == 0
    if world_size > num_heads // 2:
        raise RuntimeError(f"world_size {world_size} is larger than half of num_heads {num_heads}")
    head_per_rank = num_heads // world_size
    q_head_start = rank * head_per_rank

    # mapping q_head to v_head
    v_head_ids = []
    i = 0
    # mapping neighboring q_heads to a shared v_head
    while i < head_per_rank:
        v_head_ids.append(q_head_start // 2)
        q_head_start += 2
        i = i + 2

    # mapping the corresponding k_heads to v_heads
    v_head_ids.extend([i + num_heads // 2 for i in v_head_ids])
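
    # Concrete example (hypothetical sizes): num_heads = 8, world_size = 2 gives
    # head_per_rank = 4, so rank 0 builds v_head_ids = [0, 1] from its q heads and
    # extends it to [0, 1, 4, 5]; rank 1 gets [2, 3, 6, 7].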

    sharded_weight = []
    sharded_bias = []
    for head_id in v_head_ids:
        if shard_value:
            sharded_weight.append(weight[head_id * head_dim:(head_id + 1) * head_dim])
            if bias is not None:
                sharded_bias.append(bias.data[head_id * head_dim:(head_id + 1) * head_dim])
        else:
            sharded_weight.append(weight[:, head_id * head_dim:(head_id + 1) * head_dim])

    sharded_weight = torch.cat(sharded_weight, dim=weight_cat_dim)
    if bias is not None:
        if shard_value:
            sharded_bias = torch.cat(sharded_bias, dim=0)
        else:
            # The o_proj output is all-reduced across ranks, so its bias is scaled rather than sliced.
            sharded_bias = bias / float(world_size)
        return torch.nn.Parameter(sharded_weight), torch.nn.Parameter(sharded_bias)
    else:
        return torch.nn.Parameter(sharded_weight), None


# For phi3 with a chunked MLP (fused gate_up_proj), adjust the weight order.
def shard_chunk_mlp(
    weight,
    bias,
    rank,
    world_size,
):
    weight_gate, weight_states = weight.chunk(2, dim=0)
    total_size = weight_gate.shape[0]
    split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
    split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
    shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0)

    if bias is not None:
        bias_gate, bias_states = bias.chunk(2, dim=0)
        split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
        split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
        return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0)

    return shard_weight, None
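
# Usage sketch (hypothetical shapes, not executed): for a Phi-3 style fused gate_up_proj weight
# of shape [2 * intermediate_size, hidden_size], world_size=2 and rank=0 returns a tensor of
# shape [intermediate_size, hidden_size] holding the first half of the gate rows stacked on top
# of the first half of the up-projection rows, keeping the fused gate/up layout per rank.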