import torch
import os
import json
import re # <-- Import the regular expression module
from datetime import datetime, timezone
from tqdm import tqdm
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import Qwen3Config, Qwen3ForCausalLM
# --- Helper Functions (Definitive Version) ---
def create_vocab_mapping(s_tok, t_tok):
    s_vocab, t_vocab = s_tok.get_vocab(), t_tok.get_vocab()
    mapping = {t_id: s_vocab.get(t, -1) for t, t_id in t_vocab.items()}
    matches = sum(1 for v in mapping.values() if v != -1)
    print(f"Vocabulary overlap: {matches}/{len(t_vocab)} tokens ({matches/len(t_vocab)*100:.1f}%) will be transferred.")
    return mapping
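# The returned dict maps each *target* (Qwen3) token id to the matching *source* (Qwen2.5)
# token id, with -1 marking tokens that do not exist in the source vocabulary
# (illustrative shape only): {0: 0, 1: 1, ..., <new_token_id>: -1}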
def verify_special_tokens(s_tok, t_tok, mapping):
    print("\nVerifying special token mappings...")

    def _process_token(token_str):
        if token_str and token_str in t_tok.get_vocab():
            t_id = t_tok.convert_tokens_to_ids(token_str)
            s_id = mapping.get(t_id, -1)
            status = f"Mapped (T: {t_id} -> S: {s_id})" if s_id != -1 else "NOT FOUND in source (initialized with mean)"
            print(f" ✓ ('{token_str}'): {status}")

    for name, token_value in t_tok.special_tokens_map.items():
        if isinstance(token_value, str):
            _process_token(token_value)
        elif isinstance(token_value, list):
            for token_str_in_list in token_value:
                _process_token(token_str_in_list)
def create_hybrid_matrix(s_matrix, mapping, shape):
    print(" -> Calculating mean embedding from source model for new token initialization...")
    mean_embedding = s_matrix.mean(dim=0)  # shape: (hidden_size,)
    hybrid = torch.zeros(shape, dtype=s_matrix.dtype, device='cpu')
    for t_id, s_id in mapping.items():
        # Copy the source row when the token exists there; otherwise fall back to the mean embedding.
        hybrid[t_id] = s_matrix[s_id] if s_id != -1 else mean_embedding
    return hybrid.to(s_matrix.device)
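# Note: the per-row Python loop above is simple but slow for ~150K-token vocabularies.
# A vectorized equivalent (sketch only, not used below) would be:
#   t_ids = torch.tensor(list(mapping.keys()))
#   s_ids = torch.tensor(list(mapping.values()))
#   found = s_ids != -1
#   hybrid[t_ids[found]] = s_matrix[s_ids[found]]
#   hybrid[t_ids[~found]] = mean_embedding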
def save_config_diff(s_conf, t_conf, path):
    s_dict, t_dict = s_conf.to_dict(), t_conf.to_dict()
    diff = {'changed': {}, 'added': {}, 'removed': {}}
    for k in set(s_dict.keys()) | set(t_dict.keys()):
        if s_dict.get(k) != t_dict.get(k):
            if k in s_dict and k in t_dict:
                diff['changed'][k] = {'from': s_dict[k], 'to': t_dict[k]}
            elif k in t_dict:
                diff['added'][k] = t_dict[k]
            else:
                diff['removed'][k] = s_dict[k]
    with open(os.path.join(path, "config_diff.json"), "w") as f:
        json.dump(diff, f, indent=2)
def validate_model(path):
    print("\n[Step 6/6] Validating final model (smoke test)...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.bfloat16)
        model.eval()
        prompt = "The theory of relativity states that"
        print(f"\nValidation Prompt: '{prompt}'")
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=25, do_sample=False, pad_token_id=tokenizer.eos_token_id)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated Response: '{response}'")
        assert len(response) > len(prompt), "Model did not generate new tokens."
        print("\n ✓ Validation successful: Model loads and generates coherent text using standard transformers.")
    except Exception as e:
        print(f"\n ✗ Validation FAILED: {e}")
# --- Main Conversion Logic ---
def convert_qwen2_to_qwen3_decoupled():
    source_model_id, donor_model_id = "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen3-32B"
    target_model_path = "./Qwen3-72B-Instruct"
    print("Starting DECOUPLED conversion process (v5.3)...")

    # --- 1. Pre-flight Checks ---
    print("\n[Step 1/6] Running pre-flight architectural checks...")
    s_config = AutoConfig.from_pretrained(source_model_id)
    d_config = AutoConfig.from_pretrained(donor_model_id)
    assert s_config.hidden_act == d_config.hidden_act, f"FATAL: Hidden activation mismatch! Source: {s_config.hidden_act}, Donor: {d_config.hidden_act}."
    print(" ✓ Hidden activation functions match.")
    if s_config.rope_theta != d_config.rope_theta:
        print(f" ✓ RoPE Theta: Using donor value {d_config.rope_theta} (source was {s_config.rope_theta})")
    # --- 2. Load Models & Tokenizers using AutoModel ---
    print("\n[Step 2/6] Loading models & tokenizers using standard AutoClasses...")
    dtype = torch.bfloat16
    s_model = AutoModelForCausalLM.from_pretrained(source_model_id, torch_dtype=dtype, device_map="auto")
    d_model = AutoModelForCausalLM.from_pretrained(donor_model_id, torch_dtype=dtype, device_map="auto")
    s_tokenizer = AutoTokenizer.from_pretrained(source_model_id)
    t_tokenizer = AutoTokenizer.from_pretrained(donor_model_id)

    # --- 3. Create Target Config & Initialize ---
    print("\n[Step 3/6] Creating target Qwen3 72B config & initializing model shell...")
    t_config = Qwen3Config(
        # Geometry inherited from the Qwen2.5 72B source:
        hidden_size=s_config.hidden_size, intermediate_size=s_config.intermediate_size,
        num_hidden_layers=s_config.num_hidden_layers, num_attention_heads=s_config.num_attention_heads,
        num_key_value_heads=s_config.num_key_value_heads, max_position_embeddings=s_config.max_position_embeddings,
        max_window_layers=s_config.max_window_layers, sliding_window=s_config.sliding_window,
        # Qwen3-specific settings taken from the 32B donor:
        attention_bias=d_config.attention_bias, hidden_act=d_config.hidden_act,
        initializer_range=d_config.initializer_range, rms_norm_eps=d_config.rms_norm_eps,
        rope_theta=d_config.rope_theta, vocab_size=d_config.vocab_size,
        tie_word_embeddings=True)
    with torch.device("meta"):
        t_model = Qwen3ForCausalLM(t_config)
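    # Note: building the model under torch.device("meta") allocates no real memory;
    # the parameters are materialized later by load_state_dict(..., assign=True),
    # which swaps in the converted CPU tensors directly.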
    # --- 4. Convert and Transfer Weights ---
    print("\n[Step 4/6] Converting weights (memory-safe)...")
    s_state_dict = {k: v.to('cpu', dtype=dtype) for k, v in tqdm(s_model.state_dict().items(), desc="Source state dict to CPU")}
    d_state_dict = {k: v.to('cpu', dtype=dtype) for k, v in tqdm(d_model.state_dict().items(), desc="Donor state dict to CPU")}
    vocab_mapping = create_vocab_mapping(s_tokenizer, t_tokenizer)
    verify_special_tokens(s_tokenizer, t_tokenizer, vocab_mapping)
    new_state_dict = {}
    num_donor_layers = d_config.num_hidden_layers
    for key in tqdm(t_model.state_dict().keys(), desc="Transferring weights"):
        if "q_norm" in key or "k_norm" in key:
            # --- FIX: Cyclical grafting for the QK-norm layers that Qwen3 adds but Qwen2.5 lacks.
            # The donor has fewer layers than the target, so donor layer indices are reused cyclically.
            match = re.search(r'layers\.(\d+)\.', key)
            if match:
                target_layer_idx = int(match.group(1))
                donor_layer_idx = target_layer_idx % num_donor_layers
                donor_key = key.replace(f'layers.{target_layer_idx}.', f'layers.{donor_layer_idx}.')
                new_state_dict[key] = d_state_dict[donor_key].clone()
            else:
                print(f" ⚠️ Could not parse layer index for norm key: {key}. Skipping.")
        elif "model.embed_tokens.weight" in key or "lm_head.weight" in key:
            # Embedding and LM-head rows are remapped into the Qwen3 vocabulary layout.
            new_state_dict[key] = create_hybrid_matrix(s_state_dict[key], vocab_mapping, (t_config.vocab_size, t_config.hidden_size))
        elif key in s_state_dict:
            new_state_dict[key] = s_state_dict[key].clone()
        else:
            print(f" ⚠️ Unhandled key: {key} (not in source, skipping)")
    t_model.load_state_dict(new_state_dict, strict=True, assign=True)
    t_model = t_model.to(dtype)
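    # Optional sanity check (sketch, not part of the original flow): strict=True already fails
    # on missing keys, but a parameter left on the meta device would indicate a tensor that was
    # never materialized by assign=True:
    #   leftover = [n for n, p in t_model.named_parameters() if p.device.type == "meta"]
    #   assert not leftover, f"Un-materialized parameters: {leftover}"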
    # --- 5. Save Final Model & Metadata ---
    print("\n[Step 5/6] Saving final model and supporting files...")
    os.makedirs(target_model_path, exist_ok=True)
    t_model.save_pretrained(target_model_path, safe_serialization=True)
    t_tokenizer.save_pretrained(target_model_path)
    save_config_diff(s_config, t_config, target_model_path)
    metadata = {
        "conversion_date_utc": datetime.now(timezone.utc).isoformat(),
        "source_model": source_model_id,
        "donor_model": donor_model_id,
        "warnings": [
            "This is a community-created model merge. Its behavior may be unpredictable.",
            "Sliding window config inherited from Qwen2.5 with Qwen3 RoPE theta - long context behavior MUST be validated.",
            "Post-conversion evaluation is highly recommended for numerical stability, quantization, and safety alignment.",
        ],
    }
    with open(os.path.join(target_model_path, "conversion_metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)
    print(f"✅ Model saved to: {target_model_path}")
    # --- 6. Final Validation ---
    # Free the in-memory source/donor copies before reloading the saved model for the smoke test.
    del s_model, d_model, s_state_dict, d_state_dict, new_state_dict, t_model
    torch.cuda.empty_cache()
    validate_model(target_model_path)

if __name__ == "__main__":
    convert_qwen2_to_qwen3_decoupled()
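# Usage note (assumption: run as a standalone script; the two Hub model IDs above are
# downloaded automatically, so enough disk and GPU/CPU memory to hold both the 72B source
# and the 32B donor in bfloat16 is required):
#   python path/to/this_script.py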