# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium """PyTorch implementation of the MiniMax M2 architecture for Hugging Face Transformers.""" from __future__ import annotations import copy import time from typing import Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, repeat_kv, rotate_half from .configuration_minimax_m2 import MiniMaxM2Config logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MiniMaxM2Config" _CHECKPOINT_FOR_DOC = "MiniMaxAI/MiniMax-M2" def load_balancing_loss_func( gate_logits: Union[torch.Tensor, Tuple[torch.Tensor, ...]], num_experts: int, top_k: int, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: if gate_logits is None: return torch.tensor(0.0) if isinstance(gate_logits, torch.Tensor): logits = gate_logits else: logits = torch.cat([layer_gate.to(gate_logits[0].device) for layer_gate in gate_logits], dim=0) routing_weights = torch.softmax(logits, dim=-1, dtype=torch.float32) _, selected = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected, num_experts) if attention_mask is None: tokens_per_expert = torch.mean(expert_mask.float(), dim=0) router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, seq_len = attention_mask.shape num_layers = logits.shape[0] // (batch_size * seq_len) expanded_mask = ( attention_mask[None, :, :, None, None] .expand(num_layers, batch_size, seq_len, top_k, num_experts) .reshape(-1, top_k, num_experts) .to(logits.device) ) tokens_per_expert = torch.sum(expert_mask.float() * expanded_mask, dim=0) / torch.sum(expanded_mask, dim=0) router_mask = ( attention_mask[None, :, :, None] .expand(num_layers, batch_size, seq_len, num_experts) .reshape(-1, num_experts) .to(logits.device) ) router_prob_per_expert = torch.sum(routing_weights * router_mask, dim=0) / torch.sum(router_mask, dim=0) loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return loss * num_experts def apply_rotary_pos_emb_partial( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rotary_dim: int, unsqueeze_dim: int = 2, ) -> Tuple[torch.Tensor, torch.Tensor]: cos = cos.unsqueeze(unsqueeze_dim)[..., :rotary_dim] sin = sin.unsqueeze(unsqueeze_dim)[..., :rotary_dim] q_rot = q[..., :rotary_dim] k_rot = k[..., :rotary_dim] q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin) k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin) q = torch.cat((q_rot, q[..., rotary_dim:]), dim=-1) k = torch.cat((k_rot, k[..., rotary_dim:]), dim=-1) return q, k class MiniMaxM2RMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return (self.weight * hidden_states).to(input_dtype) class MiniMaxM2MLP(nn.Module): def __init__(self, config: MiniMaxM2Config) -> None: super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: gate = self.act_fn(self.w1(hidden_states)) up = self.w3(hidden_states) hidden_states = gate * up hidden_states = self.w2(hidden_states) return hidden_states class MiniMaxM2SparseMoeBlock(nn.Module): def __init__(self, config: MiniMaxM2Config) -> None: super().__init__() self.hidden_dim = config.hidden_size self.experts = nn.ModuleList([MiniMaxM2MLP(config) for _ in range(config.num_local_experts)]) self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok self.jitter_noise = config.router_jitter_noise self.use_routing_bias = config.use_routing_bias self.scoring_func = getattr(config, "scoring_func", "softmax") self.use_grouped_topk = getattr(config, "use_grouped_topk", False) self.num_expert_group = getattr(config, "num_expert_group", None) self.topk_group = getattr(config, "topk_group", None) self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) if self.use_grouped_topk: if self.num_expert_group is None or self.num_expert_group <= 0: self.num_expert_group = 1 if self.topk_group is None or self.topk_group <= 0: self.topk_group = min(self.num_expert_group, self.top_k) else: self.num_expert_group = 1 self.topk_group = 1 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) if self.use_routing_bias: self.e_score_correction_bias = nn.Parameter(torch.zeros(self.num_experts, dtype=torch.float32)) else: self.register_parameter("e_score_correction_bias", None) def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, seq_len, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: noise = torch.empty_like(hidden_states).uniform_( 1.0 - self.jitter_noise, 1.0 + self.jitter_noise, ) hidden_states = hidden_states * noise hidden_states = hidden_states.view(-1, hidden_dim) gate_dtype = self.gate.weight.dtype router_logits = self.gate(hidden_states.to(gate_dtype)).to(torch.float32) if self.e_score_correction_bias is not None: # Bias is applied after scoring (see vLLM/SGLang implementations). correction_bias = self.e_score_correction_bias.to(router_logits.device, router_logits.dtype) else: correction_bias = None if self.scoring_func == "sigmoid": scores = torch.sigmoid(router_logits) elif self.scoring_func == "softmax": scores = torch.softmax(router_logits, dim=-1) else: raise ValueError(f"Unsupported scoring function: {self.scoring_func}") if correction_bias is not None: original_scores = scores scores = scores + correction_bias else: original_scores = scores topk_scores: torch.Tensor if self.use_grouped_topk and self.num_expert_group > 1: experts_per_group = scores.size(-1) // self.num_expert_group scores_grouped = scores.view(scores.size(0), self.num_expert_group, experts_per_group) if correction_bias is not None: topk_in_group = min(2, experts_per_group) if topk_in_group > 0: group_scores = scores_grouped.topk(topk_in_group, dim=-1)[0].sum(dim=-1) else: group_scores = torch.zeros_like(scores_grouped[..., 0]) else: group_scores = scores_grouped.max(dim=-1).values group_mask = torch.zeros_like(group_scores) selected_groups = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True).indices group_mask.scatter_(1, selected_groups, 1.0) mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(scores.size()) masked_scores = scores.masked_fill(mask == 0, float("-inf")) topk_scores, selected_experts = torch.topk(masked_scores, self.top_k, dim=-1, sorted=True) else: topk_scores, selected_experts = torch.topk(scores, self.top_k, dim=-1, sorted=True) if correction_bias is not None: routing_weights = original_scores.gather(1, selected_experts) else: routing_weights = topk_scores routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12) if self.routed_scaling_factor != 1.0: routing_weights = routing_weights * self.routed_scaling_factor routing_weights = routing_weights.to(hidden_states.dtype) selected_experts = selected_experts.to(torch.long) final_hidden_states = torch.zeros_like(hidden_states) expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten() for expert_idx in expert_hit.tolist(): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) token_states = hidden_states.index_select(0, top_x) expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1) final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype)) final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim) return final_hidden_states, router_logits class MiniMaxM2Attention(nn.Module): def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None: super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = config.head_dim self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // max(1, self.num_key_value_heads) self.rotary_dim = config.rotary_dim self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_model_len = getattr(config, "max_model_len", None) if max_model_len is not None: max_position_embeddings = max(max_position_embeddings, max_model_len) attn_window_size = getattr(config, "attn_window_size", None) if isinstance(attn_window_size, list): sliding_window = attn_window_size[layer_idx] else: sliding_window = attn_window_size if sliding_window is not None and sliding_window <= 0: sliding_window = None self.sliding_window = sliding_window swa_rope_theta = getattr(config, "swa_rope_theta", -1.0) rope_theta = config.rope_theta if self.sliding_window is not None and swa_rope_theta > 0: rope_theta = swa_rope_theta self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: self.q_norm = MiniMaxM2RMSNorm(self.num_heads * self.head_dim, eps=config.rms_norm_eps) self.k_norm = MiniMaxM2RMSNorm(self.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps) rope_config = copy.deepcopy(config) rope_config.hidden_size = config.hidden_size rope_config.num_attention_heads = config.num_attention_heads rope_config.partial_rotary_factor = float(config.rotary_dim) / float(self.head_dim) rope_config.rope_theta = rope_theta rope_config.max_position_embeddings = max_position_embeddings self.rotary_emb = LlamaRotaryEmbedding(rope_config) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if self.use_qk_norm: q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1) k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1) q_flat = self.q_norm(q_flat) k_flat = self.k_norm(k_flat) query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if position_embeddings is None: cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb_partial( query_states.transpose(1, 2), key_states.transpose(1, 2), cos, sin, self.rotary_dim ) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask if self.sliding_window is not None and past_key_values is None: query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1) key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1) window_mask = key_positions < (query_positions - self.sliding_window) if window_mask.any(): attn_weights = attn_weights.masked_fill(window_mask, float("-inf")) attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) if self.training and self.attention_dropout > 0: attn_weights = F.dropout(attn_weights, p=self.attention_dropout) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights class MiniMaxM2LogitsProcessor(nn.Module): def __init__(self, config: MiniMaxM2Config) -> None: super().__init__() self.scale = getattr(config, "logits_scale", 1.0) def forward(self, lm_head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor: logits = lm_head(hidden_states) if self.scale != 1.0: logits = logits * self.scale return logits class MiniMaxM2DecoderLayer(nn.Module): def __init__(self, config: MiniMaxM2Config, layer_idx: int) -> None: super().__init__() self.hidden_size = config.hidden_size self.self_attn = MiniMaxM2Attention(config, layer_idx) self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config) self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, residual: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: residual_input = hidden_states if residual is None else residual hidden_states = self.input_layernorm(hidden_states) attn_output, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, output_attentions=output_attentions, ) hidden_states = residual_input + attn_output residual_post_attn = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) moe_output, router_logits = self.block_sparse_moe(hidden_states) hidden_states = residual_post_attn + moe_output return hidden_states, hidden_states, router_logits, attn_weights class MiniMaxM2PreTrainedModel(PreTrainedModel): config_class = MiniMaxM2Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MiniMaxM2DecoderLayer"] _supports_flash_attn = False _supports_sdpa = False _supports_attention_backend = False def _init_weights(self, module: nn.Module) -> None: if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def _remap_qkv_weights(self, state_dict): num_q = self.config.num_attention_heads * self.config.head_dim num_kv = self.config.num_key_value_heads * self.config.head_dim for layer_idx in range(self.config.num_hidden_layers): prefix = f"model.layers.{layer_idx}.self_attn" weight_key = f"{prefix}.qkv_proj.weight" if weight_key in state_dict: qkv_weight = state_dict.pop(weight_key) q_weight, k_weight, v_weight = qkv_weight.split([num_q, num_kv, num_kv], dim=0) state_dict.setdefault(f"{prefix}.q_proj.weight", q_weight) state_dict.setdefault(f"{prefix}.k_proj.weight", k_weight) state_dict.setdefault(f"{prefix}.v_proj.weight", v_weight) def load_state_dict(self, state_dict, strict: bool = True): if not isinstance(state_dict, dict): raise TypeError(f"Expected state_dict to be dict, got {type(state_dict)}") filtered_state_dict = {} drop_suffixes = ("weight_scale_inv", "weight_scale", "input_scale", "scales", "amax") for key, value in state_dict.items(): if key.endswith(drop_suffixes) or "fp8" in key: continue filtered_state_dict[key] = value self._remap_qkv_weights(filtered_state_dict) if logger.isEnabledFor(logging.INFO): logger.info( "MiniMaxM2: loading %d tensors (filtered from %d original).", len(filtered_state_dict), len(state_dict), ) load_start = time.perf_counter() result = super().load_state_dict(filtered_state_dict, strict=strict) load_elapsed = time.perf_counter() - load_start if logger.isEnabledFor(logging.INFO): logger.info("MiniMaxM2: state_dict load finished in %.2f seconds.", load_elapsed) return result class MiniMaxM2Model(MiniMaxM2PreTrainedModel): def __init__(self, config: MiniMaxM2Config) -> None: super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False self.post_init() def get_input_embeddings(self) -> nn.Module: return self.embed_tokens def set_input_embeddings(self, value: nn.Module) -> None: self.embed_tokens = value def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: bool = False, output_hidden_states: bool = False, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[MoeModelOutputWithPast, Tuple]: if (input_ids is None) == (inputs_embeds is None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds.") return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) if self.config.sliding_window is not None: causal_mask = create_sliding_window_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) else: causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_router_logits = () if output_router_logits else None residual = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=None, output_attentions=output_attentions, residual=residual, ) hidden_states, residual, router_logits, attn_weights = layer_outputs if output_router_logits: all_router_logits = all_router_logits + (router_logits,) if output_attentions: all_attentions = all_attentions + (attn_weights,) hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: outputs = (hidden_states, past_key_values) if output_hidden_states: outputs += (all_hidden_states,) if output_attentions: outputs += (all_attentions,) if output_router_logits: outputs += (all_router_logits,) return outputs return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, router_logits=all_router_logits, ) class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin): def __init__(self, config: MiniMaxM2Config) -> None: super().__init__(config) self.model = MiniMaxM2Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok self.logits_processor = MiniMaxM2LogitsProcessor(config) self.post_init() def get_input_embeddings(self) -> nn.Module: return self.model.embed_tokens def set_input_embeddings(self, value: nn.Module) -> None: self.model.embed_tokens = value def get_output_embeddings(self) -> nn.Module: return self.lm_head def set_output_embeddings(self, new_embeddings: nn.Module) -> None: self.lm_head = new_embeddings def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): if past_key_values is not None: input_ids = input_ids[:, -1:] if attention_mask is not None: attention_mask = attention_mask[:, -past_key_values.get_seq_length() - 1 :] return { "input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values, "inputs_embeds": inputs_embeds, } def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: bool = False, output_hidden_states: bool = False, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[MoeCausalLMOutputWithPast, Tuple]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) model_outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, cache_position=cache_position, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=True, ) hidden_states = model_outputs.last_hidden_state slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None) logits = self.logits_processor(self.lm_head, hidden_states[:, slice_indices, :]) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1)) aux_loss = None if output_router_logits and model_outputs.router_logits is not None: aux_loss = load_balancing_loss_func( model_outputs.router_logits, num_experts=self.num_experts, top_k=self.num_experts_per_tok, attention_mask=attention_mask, ) if loss is not None: loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device) if not return_dict: output = (logits,) + (model_outputs.past_key_values,) if output_hidden_states: output += (model_outputs.hidden_states,) if output_attentions: output += (model_outputs.attentions,) if output_router_logits: output += (model_outputs.router_logits,) return ((loss,) + output) if loss is not None else output return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=model_outputs.past_key_values, hidden_states=model_outputs.hidden_states, attentions=model_outputs.attentions, router_logits=model_outputs.router_logits, ) # ----------------------------------------------------------------------------- # Backward compatibility aliases # ----------------------------------------------------------------------------- MiniMaxRMSNorm = MiniMaxM2RMSNorm MiniMaxSparseMoeBlock = MiniMaxM2SparseMoeBlock MiniMaxAttention = MiniMaxM2Attention MiniMaxDecoderLayer = MiniMaxM2DecoderLayer MiniMaxMLP = MiniMaxM2MLP MiniMaxPreTrainedModel = MiniMaxM2PreTrainedModel MiniMaxModel = MiniMaxM2Model class MiniMaxForCausalLM(MiniMaxM2ForCausalLM): """Alias for compatibility with checkpoints exporting MiniMaxForCausalLM.""" __all__ = [ "MiniMaxM2RMSNorm", "MiniMaxM2SparseMoeBlock", "MiniMaxM2Attention", "MiniMaxM2DecoderLayer", "MiniMaxM2Model", "MiniMaxM2ForCausalLM", "MiniMaxM2PreTrainedModel", "MiniMaxRMSNorm", "MiniMaxSparseMoeBlock", "MiniMaxAttention", "MiniMaxDecoderLayer", "MiniMaxPreTrainedModel", "MiniMaxModel", "MiniMaxMLP", "MiniMaxForCausalLM", ]