import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads


def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
    # Split each tensor in `qkv_list` into chunks of `split_size` along `split_dim`,
    # then concatenate the i-th chunk of every tensor along `cat_dim`. This regroups
    # fused q/k/v blocks so that each tensor-parallel rank gets a contiguous shard.
    qkv_split_list = [torch.split(mat, split_size, dim=split_dim) for mat in qkv_list]
    tp_fusedqkv_list = [
        torch.cat([qkv_s[i] for qkv_s in qkv_split_list], dim=cat_dim) for i in range(len(qkv_split_list[0]))
    ]
    return tp_fusedqkv_list
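

# Example (illustrative sketch): re-fusing separate q, k and v blocks of 4 rows
# each for two ranks, so every rank ends up with a contiguous [q_i; k_i; v_i] slab.
#
#   q, k, v = torch.ones(4, 2), 2 * torch.ones(4, 2), 3 * torch.ones(4, 2)
#   per_rank = split_by_qkvlist_and_refuse([q, k, v], 2)  # 2 rows of each per rank
#   assert [t.shape for t in per_rank] == [torch.Size([6, 2])] * 2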


def require_tp_fused_qkvw(name, mp_size):
    # Return True if the parameter `name` looks like a fused qkv weight that needs
    # special handling before tensor-parallel sharding (only relevant for mp_size > 1).
    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn']

    if mp_size == 1:
        return False

    for fused_name in fused_qkvw_name_list:
        if fused_name in name:
            return True
    return False
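

# Example (illustrative sketch, parameter names are only representative): with
# mp_size=2, a name containing 'c_attn' matches the fused-qkv list, while an
# ordinary MLP weight does not.
#
#   require_tp_fused_qkvw("transformer.h.0.attn.c_attn.weight", 2)        # True
#   require_tp_fused_qkvw("transformer.h.0.mlp.dense_h_to_4h.weight", 2)  # False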


def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):
    # Shard the fused qkv tensor `src` (weight or bias) for tensor parallelism,
    # choosing the split strategy from the type of `module`, and return the shard
    # owned by `gpu_index`.
    module_str = str(module).strip()
    if src is None:
        return
    fused_type_dict = {
        'CodeGenBlock': 'codegentype',
        'BloomBlock': 'bloomtype',
        'GLMBlock': 'glmtype',
        "MPTBlock": 'glmtype',
        "MptBlock": 'glmtype',
        "BaichuanLayer": 'glmtype',
        "QWenBlock": 'qwentype',
        "FalconDecoderLayer": 'bloomtype',
        "GPTBigCodeBlock": 'bigcodetype',
        "DecoderLayer": 'glmtype',
        "Phi3DecoderLayer": "phi3type"
    }

    def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
        # CodeGen packs the fused qkv weight into `codegen_mp_num` row blocks, each
        # holding its own q/k/v slices, so regroup them before taking this rank's shard.
        assert get_num_kv_heads() % (
            mp_size * codegen_mp_num) == 0, "codegen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0"

        shape = input.shape
        dst_shape = get_shard_size(shape[0], mp_size)
        num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

        # Split every block into its q/k/v thirds, then regroup the thirds per rank.
        src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
        src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

        split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1)
        tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

        return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]

    def _glm_type_transpose(input, mp_size):
        # `input` is the fused qkv tensor laid out as [q; k; v] along dim 0.

        # Special case (e.g. ChatGLM2/3-style checkpoints with num_kv_heads == 2):
        # q, k and v have different sizes, so shard each of them separately.
        if get_num_kv_heads() == 2:
            shape = input.shape
            hidden_dim = get_n_embd()
            kv_dim = (shape[0] - hidden_dim) // get_num_kv_heads()
            q = input[:hidden_dim]
            k = input[hidden_dim:hidden_dim + kv_dim]
            v = input[hidden_dim + kv_dim:]
            q_split = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
            k_split = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
            v_split = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
            return torch.cat((q_split[gpu_index], k_split[gpu_index], v_split[gpu_index]), dim=0)
        else:
            shape = input.shape
            src_split = torch.split(input, shape[0] // 3, dim=0)

            split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size))
            return split_fusedqkv[gpu_index]

    def _bloom_type_transpose(input, mp_size):
        # The fused qkv rows are already grouped per head, so a plain row-wise
        # split gives each rank its heads directly.
        shape = input.shape

        split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
        return split_fusedqkv[gpu_index]

    def _qwen_type_transpose(input, mp_size, module):
        if not hasattr(module, "_ds_fusedqkv_entered"):
            # Adjust the attention split size once so the module splits its fused
            # qkv output according to the per-rank shard size.
            setattr(module, "_ds_fusedqkv_entered", True)
            module.attn.split_size = get_shard_size(module.attn.split_size, mp_size)
        return _glm_type_transpose(input, mp_size)

    def _bigcode_type_transpose(input, mp_size):
        # GPTBigCode uses multi-query attention: only the q rows are sharded,
        # while the single k/v block is replicated on every rank.
        n_embd = get_n_embd()
        q = input[:n_embd]
        kv = input[n_embd:]
        shape = q.shape
        split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
        return torch.cat((split_q[gpu_index], kv), dim=0)

    def _phi3_type_transpose(input, mp_size):
        # Fused layout is [q; k; v] along dim 0, where k and v each occupy
        # num_kv_heads * head_dim rows (grouped-query attention).
        num_kv_heads = get_num_kv_heads()
        num_heads = get_num_attention_heads()
        hidden_size = input.shape[1]
        head_dim = hidden_size // num_heads
        q_pos = input.shape[0] - 2 * num_kv_heads * head_dim
        q = input[:q_pos]
        k = input[q_pos:q_pos + num_kv_heads * head_dim]
        v = input[q_pos + num_kv_heads * head_dim:]
        split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
        split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
        split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
        return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0)

    def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
        # Dispatch on the fused qkv layout. With num_heads = n and q(i)_w the
        # weight of the i-th q head, the handled layouts are roughly:
        #   bloomtype:   [q(1)_w, k(1)_w, v(1)_w, q(2)_w, k(2)_w, v(2)_w, ..., q(n)_w, k(n)_w, v(n)_w]
        #   glmtype:     [q(1)_w, ..., q(n)_w, k(1)_w, ..., k(n)_w, v(1)_w, ..., v(n)_w]
        #   codegentype: the glmtype pattern repeated inside each of the
        #                codegen_mp_num row blocks
        if fused_qkv_type == 'bloomtype':
            return _bloom_type_transpose(src, mp_size)
        elif fused_qkv_type == 'codegentype':
            return _codegen_type_transpose(src, mp_size)
        elif fused_qkv_type == 'glmtype':
            return _glm_type_transpose(src, mp_size)
        elif fused_qkv_type == 'qwentype':
            return _qwen_type_transpose(src, mp_size, module)
        elif fused_qkv_type == 'bigcodetype':
            return _bigcode_type_transpose(src, mp_size)
        elif fused_qkv_type == 'phi3type':
            return _phi3_type_transpose(src, mp_size)

        raise ValueError("unknown fused_qkv_type")

    module_name_matches = [k for k in fused_type_dict.keys() if k in module_str]
    if module_name_matches:
        # If several block names match, prefer the longest (most specific) one.
        module_name = max(module_name_matches, key=len)
        fused_type = fused_type_dict[module_name]
        return _transpose_fused_qkvw(src, mp_size, fused_type, module)
    warning_once("Unrecognized fused qkv weight type, defaulting to the bloom type split; "
                 "please check prepare_tp_fused_qkvw() to avoid potential calculation errors")
    return _bloom_type_transpose(src, mp_size)
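

# Example (illustrative sketch, names are hypothetical): during autoTP weight
# loading, the fused-qkv path would typically be guarded by require_tp_fused_qkvw()
# and then sharded per rank, e.g.
#
#   if require_tp_fused_qkvw(param_name, mp_size):
#       local_qkv = prepare_tp_fused_qkvw(decoder_block, full_qkv_weight, mp_size, rank)
#
# where `decoder_block` is the transformer block owning the fused projection and
# `rank` is this process's tensor-parallel index.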


# Some checkpoints interleave/share q and k heads inside the fused projection.
# To avoid modifying the modeling code, the value weight (and the output
# projection) are re-sharded so that each rank keeps the value heads matching
# the q/k heads it owns.
def shard_value_with_share_qk(
        weight,
        bias,
        rank,
        world_size,
        shard_value=True  # True -> shard the value weight; False -> shard the output projection
):
    if shard_value:
        total_size = weight.shape[0]
        weight_cat_dim = 0
    else:
        total_size = weight.shape[1]
        weight_cat_dim = 1
    num_heads = get_num_kv_heads()
    head_dim = total_size // num_heads
    assert (num_heads % world_size == 0)
    if world_size > num_heads // 2:
        raise RuntimeError(f"world_size {world_size} is larger than half of num_heads {num_heads}")
    head_per_rank = num_heads // world_size
    q_head_start = rank * head_per_rank

    # Map this rank's q heads to the value heads they share.
    v_head_ids = []
    i = 0
    # Each pair of neighboring q heads maps onto one v head.
    while i < head_per_rank:
        v_head_ids.append(q_head_start // 2)
        q_head_start += 2
        i = i + 2

    # Map the corresponding k heads onto the second half of the v heads.
    v_head_ids.extend([i + num_heads // 2 for i in v_head_ids])
    sharded_weight = []
    sharded_bias = []
    for head_id in v_head_ids:
        if shard_value:
            sharded_weight.append(weight[head_id * head_dim:(head_id + 1) * head_dim])
            if bias is not None:
                sharded_bias.append(bias.data[head_id * head_dim:(head_id + 1) * head_dim])
        else:
            sharded_weight.append(weight[:, head_id * head_dim:(head_id + 1) * head_dim])
    sharded_weight = torch.cat(sharded_weight, dim=weight_cat_dim)
    if bias is not None:
        if shard_value:
            sharded_bias = torch.cat(sharded_bias, dim=0)
        else:
            # The output-projection bias is applied after the all-reduce on every
            # rank, so scale it by world_size instead of slicing it.
            sharded_bias = bias / float(world_size)
        return torch.nn.Parameter(sharded_weight), torch.nn.Parameter(sharded_bias)
    else:
        return torch.nn.Parameter(sharded_weight), None
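

# Example (illustrative sketch): with num_heads=8 and world_size=2 the q-to-v
# head mapping works out to
#   rank 0: q heads 0..3 -> v_head_ids [0, 1, 4, 5]
#   rank 1: q heads 4..7 -> v_head_ids [2, 3, 6, 7]
# so each rank keeps the value rows (or output-projection columns) that
# correspond to the q/k heads it owns.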


# Shard a fused gate/up ("chunked") MLP projection, e.g. a phi3-style
# gate_up_proj: each rank keeps its slice of the gate half followed by its
# slice of the up half.
def shard_chunk_mlp(
        weight,
        bias,
        rank,
        world_size,
):
    weight_gate, weight_states = weight.chunk(2, dim=0)
    total_size = weight_gate.shape[0]
    split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
    split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
    shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0)
    if bias is not None:
        bias_gate, bias_states = bias.chunk(2, dim=0)
        split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
        split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
        return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0)

    return shard_weight, None
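

# Example (illustrative sketch, assuming the intermediate size divides evenly):
# for a fused gate/up weight of shape [2*I, H] and world_size=2, each rank
# receives an [I, H] tensor made of its I/2 gate rows followed by its I/2 up
# rows, so chunking the projection output in two on each rank still separates
# the gate and up halves.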