from .base import (
MemEstimator,
set_global_config,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
cum_mul,
get_expert_tensor_parallel_world_size,
get_expert_tensor_parallel_rank,
get_pipeline_model_parallel_world_size,
get_pipeline_model_parallel_rank,
get_expert_model_parallel_rank,
get_expert_model_parallel_world_size,
is_pipeline_first_stage,
is_pipeline_last_stage,
_addindent,
colored,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from typing import Dict, Literal, Optional, Union
from megatron.core.transformer.transformer_config import (
TransformerConfig,
MLATransformerConfig,
)
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
)
from megatron.core.models.common.embeddings import (
_yarn_get_mscale,
apply_rotary_pos_emb,
)
from megatron.core.extensions.transformer_engine import (
_get_extra_te_kwargs,
get_expert_parallel_rng_tracker_name,
condition_init_method,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.utils import divide
from megatron.core.transformer.spec_utils import import_module
from megatron.core.transformer import transformer_layer
import types, math
import warnings
from copy import deepcopy
class LanguageModelEmbedding(MemEstimator):
def __init__(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal[
"learned_absolute", "rope", "none"
] = "learned_absolute",
num_tokentypes: int = 0,
):
super().__init__()
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = (
position_embedding_type == "learned_absolute"
)
self.num_tokentypes = num_tokentypes
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
)
# Word embeddings (parallel).
self.word_embeddings = VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
)
# TODO if self.add_position_embedding:
# TODO if self.num_tokentypes > 0:
self.embedding_dropout = Dropout(self.config.hidden_dropout)
def num_parameter(self):
ret = self.word_embeddings.num_parameter()
ret += self.embedding_dropout.num_parameter()
return ret
def num_activation(self, input_shape: list[int]):
ret = self.word_embeddings.num_activation(input_shape)
input_shape = self.word_embeddings.mock_forward(input_shape)
ret += self.embedding_dropout.num_activation(input_shape)
return ret
def mock_forward(self, input_shape: list[int]):
input_shape = self.word_embeddings.mock_forward(input_shape)
return input_shape
class VocabParallelEmbedding(MemEstimator):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
*,
init_method,
reduce_scatter_embeddings: bool = False,
config: ModelParallelConfig,
):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.reduce_scatter_embeddings = reduce_scatter_embeddings
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
(self.vocab_start_index, self.vocab_end_index) = (
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings,
get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size,
)
)
self.num_embeddings_per_partition = (
self.vocab_end_index - self.vocab_start_index
)
self.deterministic_mode = config.deterministic_mode
self.weight = (self.num_embeddings_per_partition, self.embedding_dim)
def num_parameter(self):
return self.weight[0] * self.weight[1]
def num_activation(self, input_shape: list[int]):
return cum_mul(input_shape) * self.weight[1]
def mock_forward(self, input_shape: list[int]):
return input_shape + [self.weight[1]]
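# Illustrative sketch (not used by the estimator above): the arithmetic behind
# VocabParallelEmbedding's counts, written out directly. The sizes passed in are
# hypothetical examples, not values taken from any particular config.
def _example_vocab_parallel_embedding_counts(
    vocab_size: int, hidden_size: int, tp_size: int, seq_len: int, micro_batch_size: int
):
    # Each tensor-parallel rank stores roughly vocab_size / tp_size rows of the table.
    params_per_rank = (vocab_size // tp_size) * hidden_size
    # The embedding output holds one hidden vector per input token.
    activation_elems = seq_len * micro_batch_size * hidden_size
    return params_per_rank, activation_elems
# e.g. _example_vocab_parallel_embedding_counts(129024, 4096, 8, 4096, 1)
# -> (66060288, 16777216)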
class Dropout(MemEstimator):
def __init__(self, p=0, *args, **kwargs):
super().__init__()
self.p = p
def num_parameter(self):
return 0
def num_activation(self, input_shape: list[int]):
if self.p == 0:
return 0
return cum_mul(input_shape[:])
def mock_forward(self, input_shape: list[int]):
return input_shape
class ColumnParallelLinear(MemEstimator):
def __init__(
self,
input_size,
output_size,
*,
config: ModelParallelConfig,
init_method,
bias=True,
gather_output=False,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
embedding_activation_buffer=None,
grad_output_buffer=None,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
is_mla: bool = False,
):
super().__init__()
if is_mla and config.sequence_parallel:
tp_size = get_tensor_model_parallel_world_size()
output_size = divide(output_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
self.skip_bias_add = skip_bias_add
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.embedding_activation_buffer = embedding_activation_buffer
self.grad_output_buffer = grad_output_buffer
self.config = config
self.disable_grad_reduce = disable_grad_reduce
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
self.output_size_per_partition = divide(output_size, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
        # The estimator only needs the weight shape, so record it regardless of
        # whether real parameter allocation would be skipped.
        self.weight = (self.output_size_per_partition, self.input_size)
if bias:
self.bias = [self.output_size_per_partition]
else:
self.bias = None
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and world_size <= 1:
warnings.warn(
"`sequence_parallel` is set to `True`, but tensor model parallel size "
f"is {world_size}. Disabling sequence parallel."
)
self.sequence_parallel = False
self.allreduce_dgrad = (
world_size > 1
and not self.sequence_parallel
and not self.disable_grad_reduce
)
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
def num_parameter(self):
ret = cum_mul(self.weight)
if self.bias is not None:
ret += self.bias[0]
return ret
def num_activation(self, input_shape: list[int]):
return cum_mul(input_shape[:-1]) * self.weight[0]
def mock_forward(self, input_shape: list[int]):
assert self.weight[-1] == input_shape[-1]
return input_shape[:-1] + [self.weight[0]]
class RowParallelLinear(MemEstimator):
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
stride: int = 1,
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.skip_bias_add = skip_bias_add
self.config = config
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and not self.input_is_parallel:
raise RuntimeError(
"To enable `sequence_parallel`, `input_is_parallel` must be `True`"
)
# Divide the weight matrix along the last dimension.
if self.is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
self.input_size_per_partition = divide(input_size, world_size)
self.weight = (self.output_size, self.input_size_per_partition)
if bias:
self.bias = [self.output_size]
else:
self.bias = None
def num_parameter(self):
ret = cum_mul(self.weight)
if self.bias is not None:
ret += self.bias[0]
return ret
def num_activation(self, input_shape: list[int]):
return cum_mul(input_shape[:-1]) * self.weight[1]
def mock_forward(self, input_shape: list[int]):
assert self.weight[0] == input_shape[-1]
return input_shape[:-1] + [self.weight[1]]
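# Illustrative sketch (not part of the estimator): how tensor parallelism shards
# the weight shapes tracked above. ColumnParallelLinear splits the output
# dimension across ranks, RowParallelLinear splits the input dimension. The
# example sizes are hypothetical.
def _example_tp_linear_weight_shapes(input_size: int, output_size: int, tp_size: int):
    column_weight = (output_size // tp_size, input_size)  # as in ColumnParallelLinear
    row_weight = (output_size, input_size // tp_size)     # as in RowParallelLinear
    return column_weight, row_weight
# e.g. _example_tp_linear_weight_shapes(4096, 11008, 8) -> ((1376, 4096), (11008, 512))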
class RMSNorm(MemEstimator):
def __init__(self, hidden_size: int, *args, **kwargs):
super().__init__()
self.weight = hidden_size
def num_parameter(self):
return self.weight
def num_activation(self, input_shape: list[int]):
return cum_mul(input_shape[:])
def mock_forward(self, input_shape: list[int]):
return input_shape
class GetBiasDropoutAdd(MemEstimator):
def __init__(self, *args, **kwargs):
super().__init__()
def num_parameter(self):
return 0
def num_activation(self, input_shape: list[int]):
return cum_mul(input_shape[:])
def mock_forward(self, input_shape: list[int]):
return input_shape
get_bias_dropout_add = GetBiasDropoutAdd()
class MLP(MemEstimator):
def __init__(
self,
config: TransformerConfig,
submodules,
is_expert: bool = False,
input_size: int = None,
):
super().__init__()
self.config: TransformerConfig = config
        self.input_size = input_size if input_size is not None else self.config.hidden_size
# If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.input_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name="fc1",
)
self.activation_func = self.config.activation_func
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.config.ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name="fc2",
)
def num_parameter(self):
return self.linear_fc1.num_parameter() + self.linear_fc2.num_parameter()
def num_activation(self, input_shape: list[int]):
result = 0
result += self.linear_fc1.num_activation(input_shape)
intermediate_shape = self.linear_fc1.mock_forward(input_shape)
        result += cum_mul(intermediate_shape) / 2  # activation function (half width with GLU)
        result += self.linear_fc2.num_activation(intermediate_shape)
return result
def mock_forward(self, input_shape: list[int]):
intermediate_shape = self.linear_fc1.mock_forward(input_shape)
output_shape = self.linear_fc2.mock_forward(intermediate_shape)
return output_shape
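# Illustrative sketch (not used above): the parameter count of a gated (SwiGLU-
# style) MLP as tracked by the MLP estimator. With a gated linear unit, fc1
# projects to 2 * ffn_hidden_size (gate and value), and both projections are
# sharded over tensor-parallel ranks. Example numbers are hypothetical.
def _example_gated_mlp_param_count(hidden_size: int, ffn_hidden_size: int, tp_size: int) -> int:
    fc1_params = hidden_size * (2 * ffn_hidden_size) // tp_size  # column-parallel
    fc2_params = ffn_hidden_size * hidden_size // tp_size        # row-parallel
    return fc1_params + fc2_params
# e.g. _example_gated_mlp_param_count(4096, 11008, 8) == 16908288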
class ModuleList(MemEstimator):
def __init__(self, modules: list[MemEstimator] = None):
super().__init__()
if modules is None:
modules = []
self.modules = modules
def __repr__(self):
"""Return a custom repr for ModuleList that compresses repeated module representations."""
list_of_reprs = [repr(item) for item in self.modules]
if len(list_of_reprs) == 0:
return self._get_name() + "()"
start_end_indices = [[0, 0]]
repeated_blocks = [list_of_reprs[0]]
for i, r in enumerate(list_of_reprs[1:], 1):
if r == repeated_blocks[-1]:
start_end_indices[-1][1] += 1
continue
start_end_indices.append([i, i])
repeated_blocks.append(r)
lines = []
stat = (
"\t/* n_params="
+ colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
+ "\tn_act="
+ colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
+ " */"
)
main_str = self._get_name() + stat + " ("
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
local_repr = f"({start_id}): {b}" # default repr
if start_id != end_id:
n = end_id - start_id + 1
local_repr = f"({start_id}-{end_id}): {n} x {b}"
local_repr = _addindent(local_repr, 2)
lines.append(local_repr)
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
def dump(self):
list_of_reprs = [repr(item) for item in self.modules]
if len(list_of_reprs) == 0:
return self._get_name() + "()"
list_of_dumps = [item.dump() for item in self.modules]
start_end_indices = [[0, 0]]
repeated_blocks = [list_of_reprs[0]]
repeated_blocks_dump = [list_of_dumps[0]]
for i, r in enumerate(list_of_reprs[1:], 1):
if r == repeated_blocks[-1]:
start_end_indices[-1][1] += 1
continue
start_end_indices.append([i, i])
repeated_blocks.append(r)
            repeated_blocks_dump.append(list_of_dumps[i])
modules = {}
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks_dump):
key = f"({start_id})"
if start_id != end_id:
n = end_id - start_id + 1
key = f"({start_id}-{end_id}) {n} layers"
modules[key] = b
ret = {}
ret["name"] = self._get_name()
ret["n_params"] = self.num_parameter()
ret["n_act"] = self.num_activation()
if len(modules) > 0:
ret["modules"] = modules
return ret
def append(self, m: MemEstimator):
self.modules.append(m)
def __len__(
self,
):
return self.modules.__len__()
def num_parameter(self):
return sum([x.num_parameter() for x in self.modules])
def num_activation(self, input_shape: list[int]):
result = 0
for m in self.modules:
result += m.num_activation(input_shape)
input_shape = m.mock_forward(input_shape)
return result
    def mock_forward(self, input_shape: list[int]):
        for m in self.modules:
            input_shape = m.mock_forward(input_shape)
        return input_shape
class SequentialMLP(MemEstimator):
def __init__(self, num_local_experts, config: TransformerConfig, submodules):
super().__init__()
self.config = config
self.add_bias = config.add_bias_linear
self.moe_extended_tp = config.moe_extended_tp
self.num_local_experts = num_local_experts
self.local_experts = ModuleList()
for _ in range(self.num_local_experts):
expert = MLP(self.config, submodules, is_expert=True)
self.local_experts.append(expert)
def num_parameter(self):
return self.local_experts.num_parameter()
def num_activation(self, input_shape: list[int], tokens_per_expert=None):
# assume all the inputs are routed equally
all_tokens = input_shape[1]
result = 0
for m in self.local_experts.modules:
result += m.num_activation(
input_shape[:1]
+ [all_tokens // self.num_local_experts]
+ input_shape[2:]
)
return result
def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
# assume all the inputs are routed to the first expert
input_shape = self.local_experts.modules[0].mock_forward(input_shape)
return input_shape
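# Illustrative sketch: SequentialMLP above assumes tokens are spread uniformly
# over the local experts, so each expert processes all_tokens // num_local_experts
# tokens. The values below are hypothetical.
def _example_tokens_per_local_expert(all_tokens: int, num_local_experts: int) -> int:
    # Uniform-routing assumption used by the estimator.
    return all_tokens // num_local_experts
# e.g. 8192 routed tokens over 4 local experts -> 2048 tokens per expert.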
class TEGroupedMLP(MemEstimator):
"""An efficient implementation of the Experts layer using TE's GroupedLinear.
Executes multiple experts in parallel to maximize computational efficiency.
"""
def __init__(self, num_local_experts, config: TransformerConfig, submodules):
super().__init__()
self.config = config
self.moe_extended_tp = config.moe_extended_tp
self.num_local_experts = num_local_experts
self.input_size = self.config.hidden_size
# Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.moe_ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.num_local_experts,
self.input_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=True,
tp_comm_buffer_name="fc1",
)
self.activation_func = self.config.activation_func
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.num_local_experts,
self.config.moe_ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=True,
tp_comm_buffer_name="fc2",
)
# TODO if self.config.fp8:
def num_parameter(self):
ret = self.linear_fc1.num_parameter()
ret += self.linear_fc2.num_parameter()
return ret
def num_activation(self, input_shape: list[int], tokens_per_expert=None):
ret = 0
ret += self.linear_fc1.num_activation(input_shape)
input_shape = self.linear_fc1.mock_forward(input_shape)
# activation
ret += cum_mul(input_shape) / 2 # swiglu or gelu
input_shape = deepcopy(input_shape)
input_shape[-1] //= 2
        ret += self.linear_fc2.num_activation(input_shape)
return ret
    def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
        # Thread the shape through both grouped linears; the output returns to hidden_size.
        intermediate_shape = self.linear_fc1.mock_forward(input_shape)
        return self.linear_fc2.mock_forward(intermediate_shape)
class TEGroupedLinear(MemEstimator):
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
super().__init__()
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = (
self.config.disable_parameter_transpose_cache
)
extra_kwargs = _get_extra_te_kwargs(config)
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if self.expert_parallel:
extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
        # For MoE models, the comms between the TP and EP groups are explicitly handled by
        # the MoE token dispatcher, so we disable comms by making TE agnostic of model parallelism.
self.explicit_expert_comm = is_expert and (
config.tensor_model_parallel_size > 1 or self.expert_parallel
)
if is_expert:
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_size = get_tensor_model_parallel_world_size()
if self.explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
assert not bias, "bias is not considered for now"
self.num_gemms = num_gemms
self.input_size = input_size
self.output_size = output_size
def num_parameter(self):
ret = self.num_gemms * self.input_size * self.output_size
return ret
def num_activation(self, input_shape: list[int], tokens_per_expert=None):
ret = cum_mul(self.mock_forward(input_shape))
return ret
def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
return input_shape[:-1] + [self.output_size]
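# Illustrative sketch: TEGroupedLinear above keeps one weight per local expert
# (one GEMM each), so its parameter count is num_gemms * input_size * output_size
# after the expert-tensor-parallel shard is applied. Example values are hypothetical.
def _example_grouped_linear_param_count(
    num_local_experts: int, hidden_size: int, moe_ffn_hidden_size: int, etp_size: int
) -> int:
    # Column-parallel grouped fc1 shards the output dimension over ETP ranks.
    return num_local_experts * hidden_size * (moe_ffn_hidden_size // etp_size)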
class TEColumnParallelGroupedLinear(TEGroupedLinear):
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
class TERowParallelGroupedLinear(TEGroupedLinear):
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
class SharedExpertMLP(MLP):
"""
MLP layer for Shared Experts.
"""
def __init__(self, config: TransformerConfig, spec: ModuleSpec):
config = deepcopy(config)
        assert not config.add_bias_linear, (
            "bias is not supported in the shared experts, "
            "please set '--disable-bias-linear' instead."
        )
config.ffn_hidden_size = config.moe_shared_expert_intermediate_size
super().__init__(config=config, submodules=spec.submodules)
self.use_shared_expert_gate = spec.params.get("gate", False)
if self.use_shared_expert_gate:
assert False, "use_shared_expert_gate is not Implemented"
# self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
# if config.perform_initialization:
# if get_cuda_rng_tracker().is_initialized():
# with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
# config.init_method(self.gate_weight)
# else:
# config.init_method(self.gate_weight)
# self.gate_weight.data = self.gate_weight.data.to(dtype=config.params_dtype)
# setattr(self.gate_weight, 'sequence_parallel', self.config.sequence_parallel)
else:
self.gate_weight = None
class TransformerBlock(MemEstimator):
"""Transformer class."""
def __init__(
self,
config: TransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
):
super().__init__()
self.config = config
self.submodules = _get_block_submodules(config, spec)
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
self.cuda_graphs = {}
self.current_microbatch = -1
self.input_tensor = None
self.checkpoint_core_attention = (
self.config.recompute_granularity == "selective"
)
self._build_layers()
self.num_layers_per_pipeline_rank = len(self.layers)
self.tp_only_amax_red = config.tp_only_amax_red
def _build_layers(self):
def build_layer(layer_spec, layer_number):
return build_module(
layer_spec, config=self.config, layer_number=layer_number
)
# offset is implicit in TransformerLayer
self.layers = ModuleList(
[
build_layer(layer_spec, i + 1)
for i, layer_spec in enumerate(self.submodules.layer_specs)
]
)
if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
self.final_layernorm = build_module(
self.submodules.layer_norm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None # Either this or nn.Identity
def num_parameter(self):
ret = self.layers.num_parameter()
if self.final_layernorm is not None:
ret += self.final_layernorm.num_parameter()
return ret
def num_activation(self, input_shape: list[int]):
result = self.layers.num_activation(input_shape)
if self.final_layernorm is not None:
result += self.final_layernorm.num_activation(input_shape)
return result
def mock_forward(self, input_shape: list[int]):
return input_shape
class TopKRouter(MemEstimator):
def __init__(self, config: TransformerConfig) -> None:
super().__init__()
self.config = config
self.topk = self.config.moe_router_topk
self.routing_type = self.config.moe_router_load_balancing_type
self.input_jitter = None
def num_parameter(self):
return 0
def num_activation(self, input_shape: list[int]):
result = cum_mul(input_shape) * 2 # sinkhorn and sinkhorn activation
return result
def mock_forward(self, input_shape: list[int]):
return input_shape[:-1] + [self.topk]
class MoELayer(MemEstimator):
def __init__(
self, config: TransformerConfig, submodules=None, layer_number: int = None
):
super().__init__()
self.config = config
self.submodules = submodules
self.moe_layer_recompute = config.moe_layer_recompute
self.expert_parallel_size = get_expert_model_parallel_world_size()
        assert (
            self.expert_parallel_size > 0
        ), "Expected positive expert parallel size"
assert self.config.num_moe_experts % self.expert_parallel_size == 0
self.num_local_experts = (
self.config.num_moe_experts // self.expert_parallel_size
)
local_expert_indices_offset = (
get_expert_model_parallel_rank() * self.num_local_experts
)
self.router = TopKRouter(config=self.config)
self.use_shared_expert = (
self.config.moe_shared_expert_intermediate_size is not None
)
self.shared_expert_overlap = self.config.moe_shared_expert_overlap
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.num_local_experts)
]
assert all(
map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices)
)
self.experts = None
self.shared_experts = None
self.token_dispatcher = None
self.layer_number = layer_number
# Initialize experts
self.experts = build_module(
self.submodules.experts, self.num_local_experts, self.config
)
# Initialize shared experts
if self.use_shared_expert:
self.shared_experts = SharedExpertMLP(
self.config, self.submodules.shared_experts
)
# if self.shared_expert_overlap:
# self.token_dispatcher.set_shared_experts(self.shared_experts)
def num_parameter(self):
ret = self.experts.num_parameter() + self.router.num_parameter()
if self.use_shared_expert:
ret += self.shared_experts.num_parameter()
return ret
def num_activation(self, input_shape: list[int]):
result = self.router.num_activation(input_shape)
result += cum_mul(input_shape) * self.router.topk # token dispatcher
moe_input_shape_average = deepcopy(input_shape)
moe_input_shape_average[1] = int(moe_input_shape_average[1] * self.router.topk)
result += self.experts.num_activation(moe_input_shape_average)
if self.use_shared_expert:
result += self.shared_experts.num_activation(input_shape)
if self.config.moe_layer_recompute:
result = cum_mul(input_shape) * 2
return result
def mock_forward(self, input_shape: list[int]):
return input_shape
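# Illustrative sketch: with top-k routing, every token is dispatched to k experts,
# so the expert MLPs in MoELayer above see roughly k times the incoming token
# count (the moe_input_shape_average scaling). Example numbers are hypothetical.
def _example_moe_routed_token_count(seq_len: int, micro_batch_size: int, topk: int) -> int:
    return seq_len * micro_batch_size * topk
# e.g. a 4096-token microbatch with top-2 routing produces 8192 expert inputs.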
class IdentityOp(MemEstimator):
def num_parameter(self):
return 0
def num_activation(self, input_shape: list[int]):
return 0
def mock_forward(self, input_shape: list[int]):
return input_shape
IdentityFuncOp = IdentityOp
TERowParallelLinear = RowParallelLinear
TEColumnParallelLinear = ColumnParallelLinear
TELayerNormColumnParallelLinear = ColumnParallelLinear
class TEDotProductAttention(MemEstimator):
def __init__(self, config: TransformerConfig, *args, **kwargs):
super().__init__()
self.config = config
def num_parameter(self):
return 0
def num_activation(
self, q_shape: list[int], k_shape: list[int], v_shape: list[int]
):
bs, seqs, heads, dim = q_shape
        if self.config.multi_latent_attention and False:  # disabled: MLA currently uses the flash-attention estimate below
result = bs * seqs * seqs * heads
else:
bs, seqs, heads, dim = k_shape
result = (
bs * seqs * dim * heads * 2 # * self.config.tensor_model_parallel_size
) # flash attention
if self.config.context_parallel_size > 1:
result *= 2
return result
def mock_forward(
self,
hidden_size: int,
q_shape: list[int],
k_shape: list[int],
v_shape: list[int],
):
seqs, bs, heads, dim = q_shape
return [seqs, bs, hidden_size]
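# Illustrative sketch of the flash-attention activation estimate used above: the
# stored activations scale with b * s * heads * head_dim (the K and V inputs),
# not with s^2, because the attention matrix is never materialized. Example
# values are hypothetical.
def _example_flash_attention_activation_elems(
    seq_len: int, micro_batch_size: int, num_heads: int, head_dim: int, cp_size: int = 1
) -> int:
    elems = micro_batch_size * seq_len * num_heads * head_dim * 2  # K and V
    if cp_size > 1:
        elems *= 2  # extra buffers with context parallelism, mirroring the estimator
    return elems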
class TransformerLayer(MemEstimator):
def __init__(
self,
config: TransformerConfig,
submodules,
layer_number: int = 1,
hidden_dropout: float = None,
):
super().__init__()
self.config = config
        if config.enable_cuda_graph:
            # CUDA graph capture does not change the memory estimate; just reject
            # unsupported combinations early (CudaGraphManager is not available here).
            assert (
                not config.cpu_offloading and config.recompute_granularity is None
            ), "Cudagraphs not supported"
self.submodules_config = submodules
self.layer_number = layer_number + get_transformer_layer_offset(self.config)
self.hidden_dropout = (
config.hidden_dropout if hidden_dropout is None else hidden_dropout
)
# [Module 1: Input Layernorm] Optional Layernorm on the input data
# TODO: add pytorch only layernorm
self.input_layernorm = build_module(
submodules.input_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
# [Module 2: SelfAttention]
self.self_attention = build_module(
submodules.self_attention, config=self.config, layer_number=layer_number
)
# [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)
# [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
# [Module 5: CrossAttention]
self.cross_attention = build_module(
submodules.cross_attention, config=self.config, layer_number=layer_number
)
# [Module 6: BiasDropoutFusion]
self.cross_attn_bda = build_module(
submodules.cross_attn_bda, config=self.config
)
# [Module 7: Pre MLP] Optional Layernorm before MLP
self.pre_mlp_layernorm = build_module(
submodules.pre_mlp_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
# [Module 8: MLP block]
self.mlp = build_module(submodules.mlp, config=self.config)
if hasattr(self.mlp, "set_layer_number"):
self.mlp.set_layer_number(self.layer_number)
# [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
def num_parameter(self):
result = self.input_layernorm.num_parameter()
result += self.self_attention.num_parameter()
result += self.pre_cross_attn_layernorm.num_parameter()
result += self.cross_attention.num_parameter()
result += self.cross_attn_bda.num_parameter()
result += self.pre_mlp_layernorm.num_parameter()
result += self.mlp.num_parameter()
return result
def num_activation(self, input_shape: list[int]):
result = 0
result += self.self_attention.num_activation(input_shape)
result += self.mlp.num_activation(input_shape)
# __import__('ipdb').set_trace()
# sequence parallel
if self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1:
input_shape = deepcopy(input_shape)
            input_shape[1] //= self.config.tensor_model_parallel_size
result += self.input_layernorm.num_activation(input_shape)
result += self.pre_mlp_layernorm.num_activation(input_shape)
result += self.self_attn_bda.num_activation(input_shape)
result += self.mlp_bda.num_activation(input_shape)
return result
def mock_forward(self, input_shape: list[int]):
return input_shape
class SelfAttention(MemEstimator):
def __init__(
self,
config: TransformerConfig,
submodules,
layer_number: int,
attn_mask_type,
):
super().__init__()
self.config = config
self.layer_number = layer_number
self.attn_mask_type = attn_mask_type
self.attention_type = ""
# For normal attention without groups, num_query_groups == num_attention_heads,
# so these two will be the same
self.query_projection_size = (
self.config.kv_channels * self.config.num_attention_heads
)
self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
# Per attention head and per partition values.
world_size = get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = divide(
self.query_projection_size, self.config.num_attention_heads
)
self.num_attention_heads_per_partition = divide(
self.config.num_attention_heads, world_size
)
self.num_query_groups_per_partition = divide(
self.config.num_query_groups, world_size
)
self.core_attention = build_module(
submodules.core_attention,
config=self.config,
layer_number=self.layer_number,
attn_mask_type=self.attn_mask_type,
)
self.linear_qkv = build_module(
submodules.linear_qkv,
self.config.hidden_size,
self.query_projection_size + 2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="qkv",
)
if submodules.q_layernorm is not None:
self.q_layernorm = build_module(
submodules.q_layernorm,
hidden_size=self.hidden_size_per_attention_head,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.q_layernorm = None
if submodules.k_layernorm is not None:
self.k_layernorm = build_module(
submodules.k_layernorm,
hidden_size=self.hidden_size_per_attention_head,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.k_layernorm = None
self.linear_proj = build_module(
submodules.linear_proj,
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name="proj",
)
self.checkpoint_core_attention = (
self.config.recompute_granularity == "selective"
)
def num_parameter(self):
result = 0
result += self.core_attention.num_parameter()
result += self.linear_proj.num_parameter()
result += self.linear_qkv.num_parameter()
if self.q_layernorm is not None:
result += self.q_layernorm.num_parameter()
if self.k_layernorm is not None:
result += self.k_layernorm.num_parameter()
return result
def num_activation(self, input_shape: list[int]):
ret = 0
## in estimator: act(linear) = 1.5*cum_mul(input_shape)
## in reality: act(linear) = cum_mul(input_shape), act(rotary) = cum_mul(input_shape), act(attn_forward_func_with_cp) = cum_mul(input_shape)
# ret += self.linear_qkv.num_activation(input_shape)
mixed_qkv_shape = self.linear_qkv.mock_forward(input_shape)
new_tensor_shape = mixed_qkv_shape[:-1] + [
self.num_query_groups_per_partition,
(
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
+ 2
)
* self.hidden_size_per_attention_head
),
]
split_arg_list = [
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
]
# [sq, b, ng, (np/ng + 2) * hn]
# --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
q_shape = new_tensor_shape[:-1] + [split_arg_list[0]]
k_shape = new_tensor_shape[:-1] + [split_arg_list[1]]
v_shape = new_tensor_shape[:-1] + [split_arg_list[2]]
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
q_shape = (
q_shape[:2]
+ [cum_mul(q_shape[-2:]) // self.hidden_size_per_attention_head]
+ [self.hidden_size_per_attention_head]
)
if not self.checkpoint_core_attention:
ret += self.core_attention.num_activation(q_shape, k_shape, v_shape)
ret += self.linear_proj.num_activation(input_shape)
## in reality: act(linear) = cum_mul(input_shape), act(rotary) = cum_mul(input_shape), act(attn_forward_func_with_cp) = cum_mul(input_shape)
ret += self.linear_proj.num_activation(input_shape) * 3
return ret
def mock_forward(self, input_shape: list[int]):
return input_shape
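# Illustrative sketch: with grouped-query attention, the fused QKV projection in
# SelfAttention above produces query_projection_size + 2 * kv_projection_size
# features, i.e. kv_channels * (num_attention_heads + 2 * num_query_groups).
# Example sizes are hypothetical.
def _example_fused_qkv_output_size(
    kv_channels: int, num_attention_heads: int, num_query_groups: int
) -> int:
    return kv_channels * (num_attention_heads + 2 * num_query_groups)
# e.g. _example_fused_qkv_output_size(128, 32, 8) == 6144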
class Linear(MemEstimator):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.weight = (in_features, out_features)
def num_parameter(self):
return self.weight[0] * self.weight[1]
def num_activation(self, input_shape: list[int]):
return cum_mul(input_shape[:-1]) * self.weight[1]
def mock_forward(self, input_shape: list[int]):
return input_shape[:-1] + [self.weight[1]]
class MLASelfAttention(MemEstimator):
"""MLA Self-attention layer class
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(
self,
config: MLATransformerConfig,
submodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
) -> None:
super().__init__()
self.config = config
self.layer_number = layer_number
self.attn_mask_type = attn_mask_type
self.attention_type = "self"
self.world_size = get_tensor_model_parallel_world_size()
# assert (
# world_size == 1
# ), "MLA is not supported with Tensor Parallelism yet, \
# use Expert Parallelism and Pipeline Parallelism for better performance."
self.query_projection_size = (
self.config.v_head_dim * self.config.num_attention_heads
)
self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim
mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)
# Per attention head and per partition values.
world_size = get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = divide(
self.query_projection_size, self.config.num_attention_heads
)
self.num_attention_heads_per_partition = divide(
self.config.num_attention_heads, world_size
)
self.num_query_groups_per_partition = divide(
self.config.num_query_groups, world_size
)
# TODO Rotary Embedding
# self.rotary_pos_emb = YarnRotaryEmbedding(
# self.config.qk_pos_emb_head_dim,
# rotary_base=self.config.rotary_base,
# scaling_factor=self.config.rotary_scaling_factor,
# original_max_position_embeddings=self.config.max_position_embeddings,
# beta_fast=self.config.beta_fast,
# beta_slow=self.config.beta_slow,
# mscale=self.config.mscale,
# mscale_all_dim=self.config.mscale_all_dim,
# )
self.core_attention = build_module(
submodules.core_attention,
config=self.config,
layer_number=self.layer_number,
attn_mask_type=self.attn_mask_type,
attention_type=self.attention_type,
softmax_scale=self.softmax_scale,
k_channels=self.q_head_dim,
v_channels=self.config.v_head_dim,
)
if self.config.q_lora_rank is None:
            # Not projecting query
self.linear_q_proj = build_module(
submodules.linear_q_proj,
self.config.hidden_size,
self.config.num_attention_heads * self.q_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
is_mla=True,
)
else:
self.linear_q_down_proj = Linear(
self.config.hidden_size, self.config.q_lora_rank, bias=False
)
self.linear_q_up_proj = build_module(
submodules.linear_q_up_proj,
self.config.q_lora_rank,
self.config.num_attention_heads * self.q_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
is_mla=True,
)
self.linear_kv_down_proj = Linear(
self.config.hidden_size,
self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim,
bias=False,
)
self.linear_kv_up_proj = build_module(
submodules.linear_kv_up_proj,
self.config.kv_lora_rank,
self.config.num_attention_heads
* (self.config.qk_head_dim + self.config.v_head_dim),
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
is_mla=True,
)
if self.config.q_lora_rank is not None:
self.q_layernorm = build_module(
submodules.q_layernorm,
hidden_size=self.config.q_lora_rank,
config=self.config,
eps=self.config.layernorm_epsilon,
)
self.kv_layernorm = build_module(
submodules.kv_layernorm,
hidden_size=self.config.kv_lora_rank,
config=self.config,
eps=self.config.layernorm_epsilon,
)
# Output.
self.linear_proj = build_module(
submodules.linear_proj,
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name="proj",
)
self.checkpoint_core_attention = (
self.config.recompute_granularity == "selective"
)
def num_parameter(self):
result = 0
result += self.core_attention.num_parameter()
result += self.linear_proj.num_parameter()
if self.config.q_lora_rank is None:
result += self.linear_q_proj.num_parameter()
else:
result += self.linear_q_down_proj.num_parameter()
result += self.linear_q_up_proj.num_parameter()
result += self.linear_kv_down_proj.num_parameter()
result += self.linear_kv_up_proj.num_parameter()
result += self.kv_layernorm.num_parameter()
if self.config.q_lora_rank is not None:
result += self.q_layernorm.num_parameter()
return result
def num_activation(self, input_shape: list[int]):
q_len, bsz, _ = input_shape
ret = 0
if self.config.q_lora_rank is not None:
ret += self.linear_q_down_proj.num_activation(input_shape)
q_compressed_shape = self.linear_q_down_proj.mock_forward(input_shape)
ret += self.q_layernorm.num_activation(q_compressed_shape)
ret += self.linear_q_up_proj.num_activation(q_compressed_shape)
q_shape = self.linear_q_up_proj.mock_forward(q_compressed_shape)
else:
# hidden_states:[s, b, 2048], q: [s, b, n * 192]
ret += self.linear_q_proj.num_activation(input_shape)
q_shape = self.linear_q_proj.mock_forward(input_shape)
# kv_combined: [s, b, 576]
ret += self.linear_kv_down_proj.num_activation(input_shape)
kv_combined_shape = self.linear_kv_down_proj.mock_forward(input_shape)
# kv_compressed:[s, b, 512], k_pos_emb: [s, b, 64]
kv_compressed_shape = kv_combined_shape[:-1] + [self.config.kv_lora_rank]
# kv: [s, b, 2048]
ret += self.kv_layernorm.num_activation(kv_compressed_shape)
ret += self.linear_kv_up_proj.num_activation(kv_compressed_shape)
q_shape = [q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim]
k_shape = [q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim]
v_shape = [
q_len,
bsz,
self.num_attention_heads_per_partition,
self.config.v_head_dim,
]
if not self.checkpoint_core_attention:
ret += self.core_attention.num_activation(q_shape, k_shape, v_shape)
ret += self.linear_proj.num_activation(input_shape)
return ret
def mock_forward(self, input_shape: list[int]):
return input_shape
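# Illustrative sketch: in MLA the per-head query width is the no-PE part plus the
# rotary part, and the KV down-projection compresses the hidden state to
# kv_lora_rank + qk_pos_emb_head_dim features, matching the shape comments in
# num_activation above. Example values are hypothetical.
def _example_mla_projection_dims(
    qk_head_dim: int, qk_pos_emb_head_dim: int, kv_lora_rank: int
):
    q_head_dim = qk_head_dim + qk_pos_emb_head_dim
    kv_down_proj_features = kv_lora_rank + qk_pos_emb_head_dim
    return q_head_dim, kv_down_proj_features
# e.g. _example_mla_projection_dims(128, 64, 512) -> (192, 576)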
class TENorm:
def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
from megatron.core.extensions.transformer_engine import _get_extra_te_kwargs, te
if config.normalization == "LayerNorm":
            # TODO: add a dedicated LayerNorm estimator
            raise NotImplementedError("LayerNorm estimator is not implemented yet")
elif config.normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
else:
raise Exception("Only LayerNorm and RMSNorm are curently supported")
return instance
def build_module(
spec_or_module: Union[ModuleSpec, type], *args, **kwargs
) -> MemEstimator:
"""replace module with MemEstimators"""
if isinstance(spec_or_module, types.FunctionType):
return globals()[spec_or_module.__name__]
if isinstance(spec_or_module, ModuleSpec) and isinstance(
spec_or_module.module, types.FunctionType
):
assert False
return spec_or_module.module
if isinstance(spec_or_module, type):
module = spec_or_module
elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
module = spec_or_module.module
else:
module = import_module(spec_or_module.module)
if isinstance(module, types.FunctionType):
assert False
return module
if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
kwargs["submodules"] = spec_or_module.submodules
try:
module = globals()[module.__name__]
return module(
*args,
**spec_or_module.params if hasattr(spec_or_module, "params") else {},
**kwargs,
)
except Exception as e:
# import ipdb
# ipdb.set_trace()
# improve the error message since we hide the module name in the line above
import sys
raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
sys.exc_info()[2]
)
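# Illustrative usage sketch (hypothetical spec and config): build_module resolves
# a Megatron ModuleSpec to the estimator class of the same name defined in this
# file, so querying memory numbers only requires the spec and a config.
def _example_build_module_usage(config: TransformerConfig, mlp_spec: ModuleSpec) -> int:
    mlp_estimator = build_module(mlp_spec, config=config)
    return mlp_estimator.num_parameter()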
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
BaseTransformerLayer,
LayerNormImpl,
)
def _get_block_submodules(
config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]
) -> TransformerBlockSubmodules:
"""
Retrieve or construct TransformerBlockSubmodules based on the provided specification.
Args:
config (TransformerConfig): Configuration object for the transformer model.
spec (Union[TransformerBlockSubmodules, ModuleSpec]): Specification for the
transformer block submodules. Can be either a TransformerBlockSubmodules
instance or a ModuleSpec.
Returns:
TransformerBlockSubmodules: The submodules for the transformer block.
"""
# Transformer block submodules.
if isinstance(spec, TransformerBlockSubmodules):
return spec
# ModuleSpec here is generally assumed to be for a transformer layer that
# is implemented in `transformer_layer.py` or if it subclasses
# `BaseTransformerLayer` from the `transformer_layer.py` file.
elif isinstance(spec, ModuleSpec):
if issubclass(spec.module, TransformerBlock):
return spec.submodules
elif issubclass(spec.module, BaseTransformerLayer):
num_layers = get_num_layers_to_build(config)
return TransformerBlockSubmodules(
layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl
)
else:
raise Exception(f"specialize for {spec.module.__name__}.")
else:
raise Exception(f"specialize for {type(spec).__name__}.")
def get_num_layers_to_build(config: TransformerConfig) -> int:
"""
Determine the number of transformer layers to build for the current pipeline stage.
Args:
config (TransformerConfig): Configuration object containing transformer model parameters.
Returns:
int: The number of layers to be built for the current pipeline stage.
"""
if (
config.num_layers_in_first_pipeline_stage is not None
or config.num_layers_in_last_pipeline_stage is not None
):
assert not (
config.account_for_embedding_in_pipeline_split
or config.account_for_loss_in_pipeline_split
), " \
Does not support standalone embedding stage and standalone loss stage with uneven pp"
# Number of layers to distribute over rest of pipeline stages
layers_to_distribute = config.num_layers
# Number of pipeline stages left for distributing transformer layers
pipeline_stages_left = get_pipeline_model_parallel_world_size()
# If the uneven first (last) pipeline stage is enabled, remove the specified number
# of layers to calculate the number of layers on each middle pipeline stage.
if config.num_layers_in_first_pipeline_stage is not None:
layers_to_distribute -= config.num_layers_in_first_pipeline_stage
pipeline_stages_left -= 1
if config.num_layers_in_last_pipeline_stage is not None:
layers_to_distribute -= config.num_layers_in_last_pipeline_stage
pipeline_stages_left -= 1
assert (
layers_to_distribute % pipeline_stages_left == 0
), "With uneven pipelineing the left over layers must be divisible by left over stages"
num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left
# If the uneven first (last) pipeline stage is enabled, return the specified number
# of layers for all virtual pipeline parallel stages within the first (last) pipeline
# parallel stage.
if (
is_pipeline_first_stage(ignore_virtual=True)
and config.num_layers_in_first_pipeline_stage is not None
):
num_layers_per_pipeline_rank = config.num_layers_in_first_pipeline_stage
if (
is_pipeline_last_stage(ignore_virtual=True)
and config.num_layers_in_last_pipeline_stage is not None
):
num_layers_per_pipeline_rank = config.num_layers_in_last_pipeline_stage
else:
# Include the embedding layer and loss layer into pipeline parallelism partition
num_layers = config.num_layers
if config.account_for_embedding_in_pipeline_split:
num_layers += 1
if config.account_for_loss_in_pipeline_split:
num_layers += 1
assert (
num_layers % config.pipeline_model_parallel_size == 0
), "num_layers should be divisible by pipeline_model_parallel_size"
num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size
# if get_virtual_pipeline_model_parallel_world_size() is not None:
# # Interleaved pipeline parallelism:
# # Number of layers in each model chunk is the number of layers in the stage,
# # divided by the number of model chunks in a stage.
# # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# # layers to stages like (each list is a model chunk):
# # Stage 0: [0] [2] [4] [6]
# # Stage 1: [1] [3] [5] [7]
# # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# # layers to stages like (each list is a model chunk):
# # Stage 0: [0, 1] [4, 5]
# # Stage 1: [2, 3] [6, 7]
# vp_size = get_virtual_pipeline_model_parallel_world_size()
# assert (
# num_layers_per_pipeline_rank % vp_size == 0
# ), "num_layers_per_pipeline_rank should be divisible by vp_size"
# num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
# num_layers_to_build = num_layers_per_virtual_rank
# else:
# # Non-interleaved pipeline parallelism:
# # Each stage gets a contiguous set of layers.
# num_layers_to_build = num_layers_per_pipeline_rank
num_layers_to_build = num_layers_per_pipeline_rank
# The embedding (or loss) layer cannot function as a standalone transformer layer
# Reduce the number of layers to construct by 1 on the first (or last) stage if the
# embedding (or loss) layer is included in the pipeline parallelism partition and placement.
if is_pipeline_first_stage() and config.account_for_embedding_in_pipeline_split:
num_layers_to_build -= 1
assert (
num_layers_to_build >= 0
), "Not enough layers in the first virtual pipeline stage"
if is_pipeline_last_stage() and config.account_for_loss_in_pipeline_split:
num_layers_to_build -= 1
assert (
num_layers_to_build >= 0
), "Not enough layers in the last virtual pipeline stage"
return num_layers_to_build
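# Illustrative sketch of the uneven-pipeline arithmetic above with hypothetical
# numbers: 60 layers, 8 pipeline stages, 6 layers pinned to the first stage and
# 6 to the last leave 48 layers for the 6 middle stages, i.e. 8 layers each.
def _example_uneven_pipeline_layers_per_middle_stage(
    num_layers: int = 60, pp_size: int = 8, first_stage_layers: int = 6, last_stage_layers: int = 6
) -> int:
    middle_stages = pp_size - 2
    remaining_layers = num_layers - first_stage_layers - last_stage_layers
    assert remaining_layers % middle_stages == 0
    return remaining_layers // middle_stages  # -> 8 for the defaults above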
def get_transformer_layer_offset(config: TransformerConfig):
"""Get the index offset of current pipeline stage, given the level of pipelining."""
pipeline_rank = get_pipeline_model_parallel_rank()
# if not is_inside_encoder():
if True:
pp_decoder_start = 0
if pp_decoder_start is not None:
pipeline_rank = pipeline_rank - pp_decoder_start
if config.pipeline_model_parallel_size > 1:
if (
config.num_layers_in_first_pipeline_stage is not None
or config.num_layers_in_last_pipeline_stage is not None
):
# Calculate number of pipeline stages to distribute the remaining Transformer
# layers after deducting the Transformer layers in the first or the last stages
middle_pipeline_stages = config.pipeline_model_parallel_size
middle_pipeline_stages -= sum(
[
1 if x is not None else 0
for x in (
config.num_layers_in_first_pipeline_stage,
config.num_layers_in_last_pipeline_stage,
)
]
)
# Calculate layers to distribute in each pipeline stage. If the
# num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage
# are not set, we will not enable uneven pipeline. All layers will be treated
# as middle layers.
num_layers_in_first_pipeline_stage = (
0
if config.num_layers_in_first_pipeline_stage is None
else config.num_layers_in_first_pipeline_stage
)
num_layers_in_last_pipeline_stage = (
0
if config.num_layers_in_last_pipeline_stage is None
else config.num_layers_in_last_pipeline_stage
)
middle_num_layers = (
config.num_layers
- num_layers_in_first_pipeline_stage
- num_layers_in_last_pipeline_stage
)
if middle_pipeline_stages > 0:
num_layers_per_pipeline_rank = (
middle_num_layers // middle_pipeline_stages
)
else:
num_layers_per_pipeline_rank = 0
middle_pipeline_rank = (
pipeline_rank
if config.num_layers_in_first_pipeline_stage is None
else pipeline_rank - 1
)
if pipeline_rank == 0:
offset = 0
else:
offset = (
middle_pipeline_rank * num_layers_per_pipeline_rank
) + num_layers_in_first_pipeline_stage
else:
num_layers = config.num_layers
# Increase the number of layers by one if we include the embedding (loss)
# layer into pipeline parallelism partition and placement
if config.account_for_embedding_in_pipeline_split:
num_layers += 1
if config.account_for_loss_in_pipeline_split:
num_layers += 1
num_layers_per_pipeline_rank = (
num_layers // config.pipeline_model_parallel_size
)
offset = pipeline_rank * num_layers_per_pipeline_rank
# Reduce the offset of embedding layer from the total layer number
if (
config.account_for_embedding_in_pipeline_split
and not is_pipeline_first_stage()
):
offset -= 1
else:
offset = 0
return offset
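# Illustrative sketch of the offset arithmetic above with hypothetical numbers:
# with 6 layers pinned to the first stage and 8 layers per middle stage, pipeline
# rank 3 starts at layer index 6 + (3 - 1) * 8 = 22.
def _example_layer_offset_uneven_pp(
    pipeline_rank: int, first_stage_layers: int = 6, layers_per_middle_rank: int = 8
) -> int:
    if pipeline_rank == 0:
        return 0
    return first_stage_layers + (pipeline_rank - 1) * layers_per_middle_rank
# e.g. _example_layer_offset_uneven_pp(3) == 22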