import os

import torch
from torch.nn import ModuleDict
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers.modeling_outputs import CausalLMOutput
from typing import Optional, Tuple, Union, Dict

from .qwen3moe_configuration import Qwen3MoEConfig


class Qwen3MoEForCausalLM(PreTrainedModel):
    """Routes each prompt to one of several expert causal LMs via a sequence-classification router."""

    config_class = Qwen3MoEConfig

    def __init__(self, config: Qwen3MoEConfig):
        super().__init__(config)
        # Sequence-classification model that decides which expert handles a prompt.
        self.router = AutoModelForSequenceClassification.from_pretrained(
            config.router_model_path,
            torch_dtype=config.torch_dtype,
            trust_remote_code=True,
            local_files_only=True,
        )
        self.router_tokenizer = AutoTokenizer.from_pretrained(
            config.router_model_path,
            trust_remote_code=True,
            local_files_only=True,
        )
        # One causal LM per routing label, keyed by label name.
        self.experts = ModuleDict({
            label: AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=config.torch_dtype,
                trust_remote_code=True,
                local_files_only=True,
            )
            for label, path in config.expert_model_paths.items()
        })
        # All experts share a single tokenizer.
        self.expert_tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_path,
            trust_remote_code=True,
            local_files_only=True,
        )

    @classmethod
    def from_pretrained(cls, pretrained_dir: str, config: Optional[Qwen3MoEConfig] = None, **kwargs):
        if config is None:
            config = Qwen3MoEConfig.from_pretrained(pretrained_dir)
        # Sub-model paths are stored relative to the checkpoint directory;
        # resolve them against `pretrained_dir` before loading.
        base = pretrained_dir
        config.router_model_path = os.path.join(base, config.router_model_path)
        config.expert_model_paths = {
            label: os.path.join(base, path)
            for label, path in config.expert_model_paths.items()
        }
        config.tokenizer_path = os.path.join(base, config.tokenizer_path)
        return cls(config)

    def get_tokenizer(self):
        return self.expert_tokenizer

    def route(self, plain_text: str) -> str:
        """Classify the prompt and return the label of the expert that should handle it."""
        with torch.no_grad():
            inputs = self.router_tokenizer(plain_text, return_tensors="pt").to(self.router.device)
            logits = self.router(**inputs).logits
            if logits.dim() == 2:
                # Class ids are assumed to follow the order of `config.labels`.
                class_id = torch.argmax(logits, dim=-1).item()
                return self.config.labels[class_id]
            # Fall back to the first label if the router returns unexpected logits.
            return self.config.labels[0]

    def generate(
        self,
        text: str,
        max_new_tokens: int = 50,
        **kwargs
    ) -> torch.LongTensor:
        # 1. Route using the router tokenizer. If the prompt is ChatML-formatted,
        #    pull out the last user turn and strip its "<|im_start|>user ... <|im_end|>"
        #    markers so the router sees plain text.
        plain_text = text
        if "<|im_start|>" in plain_text:
            temp = plain_text.split("<|im_start|>")[-2]
            plain_text = temp[:temp.find("<|im_end|>")][4:]
        label = self.route(plain_text)
        expert = self.experts[label]

        # 2. Tokenize once with the expert tokenizer
        inputs = self.expert_tokenizer(text, return_tensors="pt").to(expert.device)

        # 3. Generate using the selected expert
        return expert.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            **kwargs
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        raise NotImplementedError("Use `generate(text=...)` instead for inference.")
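

# --- Usage sketch (not part of the model definition) ---
# A minimal, hypothetical example of loading the composite checkpoint and running
# generation. "./qwen3moe-checkpoint" is a placeholder: it stands for a local directory
# holding the Qwen3MoEConfig plus the router, expert, and tokenizer sub-folders it
# references. Because of the relative import above, run this as a module within its
# package (python -m <package>.<module>) rather than as a standalone script.
if __name__ == "__main__":
    model = Qwen3MoEForCausalLM.from_pretrained("./qwen3moe-checkpoint")
    model.eval()

    # ChatML-formatted prompt: generate() extracts the user turn for routing,
    # then feeds the full prompt to the selected expert.
    prompt = (
        "<|im_start|>user\n"
        "Explain the difference between a list and a tuple in Python.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    output_ids = model.generate(prompt, max_new_tokens=64, do_sample=False)

    tokenizer = model.get_tokenizer()
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))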