from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast

import torch
from torch import nn

from .layer import MoE


def has_moe_layers(m: nn.Module) -> Tuple[bool, int]:
    """Return whether ``m`` contains an MoE layer and, if so, the number of experts
    of the first MoE layer found."""
    has_moe = False
    num_experts = 0

    for module in m.modules():
        if isinstance(module, MoE):
            has_moe = True
            num_experts = module.num_experts
            break
    return has_moe, num_experts
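

# Illustrative sketch (not part of the public API): a plain dense model reports
# ``(False, 0)``; a model containing a ``deepspeed.moe.layer.MoE`` module would report
# ``(True, num_experts)``. No MoE layer is constructed here, so this runs with plain torch.
def _example_has_moe_layers() -> None:
    dense_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    assert has_moe_layers(dense_model) == (False, 0)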


def is_moe_param(param: torch.Tensor) -> bool:
    """Return True if ``param`` is an expert (MoE) parameter, i.e. it has been
    tagged with ``allreduce = False``."""
    if hasattr(param, "allreduce") and not param.allreduce:
        return True
    return False
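

# Illustrative sketch: expert parameters are recognised purely by the ``allreduce`` /
# ``group_name`` attributes that DeepSpeed's expert layers attach to them. The tagging
# below is done by hand only for demonstration; the group name "ep_size_1" is a
# made-up placeholder.
def _example_is_moe_param() -> None:
    shared = nn.Parameter(torch.zeros(4))
    expert = nn.Parameter(torch.zeros(4))
    expert.allreduce = False  # expert params are excluded from the dense all-reduce
    expert.group_name = "ep_size_1"
    assert not is_moe_param(shared)
    assert is_moe_param(expert)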


def split_params_into_shared_and_expert_params(
        params: List[torch.nn.Parameter]) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
    """Split ``params`` into non-expert (shared) parameters and expert parameters."""
    shared_params: List[nn.Parameter] = []
    expert_params: List[nn.Parameter] = []

    for p in params:
        if is_moe_param(p):
            expert_params.append(p)
        else:
            shared_params.append(p)
    return shared_params, expert_params
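

# Illustrative sketch: partition a mixed parameter list. The hand-tagged expert
# parameter only mimics the tagging performed by DeepSpeed's expert layers; the group
# name is a placeholder.
def _example_split_shared_and_expert_params() -> None:
    shared_p = nn.Parameter(torch.zeros(2))
    expert_p = nn.Parameter(torch.zeros(2))
    expert_p.allreduce = False
    expert_p.group_name = "ep_size_1"

    shared, expert = split_params_into_shared_and_expert_params([shared_p, expert_p])
    assert len(shared) == 1 and shared[0] is shared_p
    assert len(expert) == 1 and expert[0] is expert_p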


def split_params_grads_into_shared_and_expert_params(
        group: List[torch.nn.Parameter]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Split the gradients of parameters into gradients of non-expert params
    and gradients of expert params. This is useful when computing
    grad-norms for clipping and overflow detection.

    Args:
        group (List[torch.nn.Parameter]):
            The group of parameters to split.

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]:
            List of gradients for non-MoE params, list of gradients of MoE params.
    """
    expert_grads: List[torch.Tensor] = []
    shared_grads: List[torch.Tensor] = []

    for p in group:
        if p.grad is not None:
            if is_moe_param(p):
                expert_grads.append(p.grad.to(p.dtype))
            else:
                shared_grads.append(p.grad.to(p.dtype))
    return shared_grads, expert_grads
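

# Illustrative sketch: parameters without a gradient are skipped, and each returned
# gradient is cast to its parameter's dtype. The hand-tagged expert parameter below
# only simulates DeepSpeed's tagging.
def _example_split_grads() -> None:
    shared = nn.Parameter(torch.ones(3))
    expert = nn.Parameter(torch.ones(3))
    expert.allreduce = False
    expert.group_name = "ep_size_1"
    unused = nn.Parameter(torch.ones(3))  # never used in the loss, so its grad stays None

    loss = (shared * expert).sum()
    loss.backward()

    shared_grads, expert_grads = split_params_grads_into_shared_and_expert_params(
        [shared, expert, unused])
    assert len(shared_grads) == 1 and len(expert_grads) == 1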


def split_params_into_different_moe_groups_for_optimizer(
        param_groups: Union[Dict[str, Any], Tuple[Dict[str, Any], ...], List[Dict[str, Any]]],
        max_group_size: Optional[Union[int, float]] = 178956971) -> List[Dict[str, Any]]:
    """Split parameters into different MoE groups for the optimizer.

    Args:
        param_groups (Union[Dict[str, Any], Tuple[Dict[str, Any], ...], List[Dict[str, Any]]]):
            The parameter group(s) to split.
        max_group_size (Optional[Union[int, float]]):
            Maximum number of elements per MoE group; larger groups are split further.
            Pass ``None`` to disable splitting by size.

    Returns:
        List[Dict[str, Any]]:
            List of MoE/non-MoE groups for the optimizer.
    """
    if isinstance(param_groups, tuple):
        param_groups = list(param_groups)
    elif isinstance(param_groups, dict):
        param_groups = [param_groups]
    elif not isinstance(param_groups, list):
        raise ValueError(f"Unknown param group type of {type(param_groups)}")

    # Collect the names of the expert data-parallel groups referenced by MoE params.
    data_parallel_group_names: Set[str] = set()
    for param_group in param_groups:
        for param in cast(List[nn.Parameter], param_group["params"]):
            if is_moe_param(param):
                data_parallel_group_names.add(param.group_name)

    # Create one empty MoE param group per (original group, expert group name) pair.
    group_moe: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
    for param_group in param_groups:
        for key in data_parallel_group_names:
            group_moe[param_group['name']][key] = {
                **param_group,
                'name': key,
                'moe': True,
                'params': [],
            }

    # Move MoE params out of the original groups and into their MoE groups.
    for param_group in param_groups:
        new_params: List[nn.Parameter] = []

        for param in cast(List[nn.Parameter], param_group['params']):
            if is_moe_param(param):
                group_moe[param_group['name']][param.group_name]['params'].append(param)
            else:
                new_params.append(param)
        param_group['params'] = new_params

    # Append the MoE groups, optionally splitting them so no group exceeds max_group_size.
    if max_group_size is not None:
        for moe_group in group_moe.values():
            for param_group in moe_group.values():
                cur_group: List[nn.Parameter] = []
                all_groups: List[List[nn.Parameter]] = []
                size_of_cur_group = 0

                for param in cast(List[nn.Parameter], param_group['params']):
                    if size_of_cur_group + param.numel() <= max_group_size:
                        cur_group.append(param)
                        size_of_cur_group += param.numel()
                    else:
                        all_groups.append(cur_group)
                        cur_group = [param]
                        size_of_cur_group = param.numel()

                if cur_group:
                    all_groups.append(cur_group)

                for group in all_groups:
                    param_groups.append({**param_group, 'params': group})
    else:
        for moe_group in group_moe.values():
            for param_group in moe_group.values():
                param_groups.append(param_group)

    return param_groups
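

# Illustrative sketch: a single named parameter group containing one hand-tagged
# expert parameter is split into a dense group plus one MoE group carrying the
# expert-parallel group name. The tagging and the group name are placeholders.
def _example_split_param_groups_for_optimizer() -> None:
    dense_w = nn.Parameter(torch.zeros(4, 4))
    expert_w = nn.Parameter(torch.zeros(4, 4))
    expert_w.allreduce = False
    expert_w.group_name = "ep_size_1"

    groups = split_params_into_different_moe_groups_for_optimizer({
        'params': [dense_w, expert_w],
        'name': 'dense-params',
    })
    moe_groups = [g for g in groups if g.get('moe', False)]
    assert len(moe_groups) == 1 and moe_groups[0]['name'] == "ep_size_1"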


def is_moe_param_group(param_group: Dict[str, Any]) -> bool:
    """Return True if ``param_group`` was created for MoE (expert) parameters."""
    return param_group.get('moe', False)


def configure_moe_param_groups(model_parameters: List):
    """Prepare ``model_parameters`` (raw parameters or optimizer param groups) so that
    expert parameters end up in their own MoE-aware optimizer groups."""
    assert isinstance(model_parameters, list), "model_parameters must be a list"

    for p in model_parameters:
        # Match torch.optim.Optimizer expectations: an iterable of Tensors or of dicts.
        if not isinstance(p, (torch.Tensor, dict)):
            raise TypeError("param argument that would be given to the optimizer should be "
                            f"an iterable of Tensors or dicts, but got {type(p)}")

    # Peek at the first element to decide how to proceed.
    first = model_parameters[0]

    # Case 1: a flat list of parameters -> wrap it in a single dense group and split
    # out the MoE-compatible param groups.
    if isinstance(first, torch.nn.Parameter):
        param_group = {'params': model_parameters, 'name': 'dense-params'}
        return split_params_into_different_moe_groups_for_optimizer(param_group)

    # Case 2: a list of param groups -> MoE-compatible groups may already exist;
    # if not, create them.
    elif isinstance(first, dict):
        if not any('moe' in param_group for param_group in model_parameters):
            return split_params_into_different_moe_groups_for_optimizer(model_parameters)
        else:
            # MoE groups already exist, nothing to do.
            return model_parameters
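

# Illustrative sketch: feed the returned groups straight to a torch optimizer. The
# hand-tagged expert parameter and the group name again only stand in for the tagging
# that DeepSpeed's expert layers perform.
def _example_configure_moe_param_groups() -> None:
    model = nn.Linear(4, 4)
    expert_w = nn.Parameter(torch.zeros(4, 4))
    expert_w.allreduce = False
    expert_w.group_name = "ep_size_1"

    groups = configure_moe_param_groups(list(model.parameters()) + [expert_w])
    optimizer = torch.optim.SGD(groups, lr=0.1)
    assert any(is_moe_param_group(g) for g in optimizer.param_groups)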