from deepspeed import comm as dist

# Module-level TP sharding configuration, populated by the setters below.
global num_kv_heads
|
|
def set_num_kv_heads(num):
    global num_kv_heads
    num_kv_heads = num
|
|
def set_num_attention_heads(num):
    global num_attention_heads
    num_attention_heads = num
|
|
def set_n_embd(num):
    global n_embd
    n_embd = num
|
|
def set_tp_grain_size(num):
    global tp_grain_size
    tp_grain_size = num
|
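# Typical initialization (illustrative values, not from this file): the model
# injection code is expected to call these setters once before any
# get_shard_size() query, e.g.
#
#   set_num_kv_heads(8)
#   set_num_attention_heads(32)
#   set_n_embd(4096)
#   set_tp_grain_size(64)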
|
def get_num_kv_heads():
    global num_kv_heads
    # num_kv_heads only exists once set_num_kv_heads() has been called.
    if 'num_kv_heads' in globals():
        return num_kv_heads
    return None
|
|
def get_num_attention_heads():
    global num_attention_heads
    return num_attention_heads
|
|
def get_shard_size(total_size, mp_size, name=None, rank=None):
    global num_kv_heads
    last_linear = ["lm_head", "embed_out"]

    # MoE MLP layers are sharded along the grain-aligned path below, never by KV heads.
    moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"]
    not_moe_mlp_layer = True
    if name is not None and any(s in str(name) for s in moe_mlp_layer):
        not_moe_mlp_layer = False

    if rank is None:
        rank = dist.get_rank()
    # With num_kv_heads set, attention weights are sharded by whole KV heads, so an
    # uneven division across ranks is possible; otherwise enforce a near-even,
    # tp_grain_size-aligned division.
    if num_kv_heads is not None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
            name) not in last_linear and not_moe_mlp_layer:
        my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
        return total_size * my_slices // num_kv_heads
    else:
        if total_size >= tp_grain_size:
            grain_size = total_size // tp_grain_size
            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
        else:
            return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
|
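# Worked example for the KV-head path (illustrative numbers, not from the
# source): with num_kv_heads = 8 and mp_size = 3, an attention weight of
# total_size = 1024 gives ranks 0, 1 and 2 my_slices of 3, 3 and 2 heads,
# i.e. shard sizes 384, 384 and 256; the shards always sum to total_size.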
|
def get_n_embd():
    global n_embd
    return n_embd
|
|
def get_shard_size_list(total_size, mp_size, name=None):
    shard_sizes = []
    for i in range(mp_size):
        shard_sizes.append(get_shard_size(total_size, mp_size, name, i))
    return shard_sizes
|
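# Usage sketch for the grain-aligned path (hypothetical values): after
# set_tp_grain_size(64), an MLP weight with total_size = 11008 on mp_size = 4
# skips the KV-head path because "mlp" appears in the name:
#
#   get_shard_size_list(11008, 4, name="mlp.gate_proj")
#   # -> [2752, 2752, 2752, 2752]   (11008 // 64 = 172 grains, 43 per rank)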
|