File size: 3,819 Bytes
7e1725c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# prepare_donor_v6.py
import torch
import os
import argparse
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights

def _align_tensor(name, target_tensor, donor_tensor):
    """Return a fresh copy of ``donor_tensor`` shaped like ``target_tensor``.

    Equal shapes are cloned as-is. The only tolerated mismatch is a donor
    whose leading (vocabulary) dimension is larger than the target's, with
    all trailing dimensions equal; it is truncated to the target's length.
    Anything else raises here instead of surfacing later as a less specific
    size-mismatch error from ``load_state_dict``.

    Raises:
        ValueError: if the shapes cannot be reconciled by truncating dim 0.
    """
    target_shape = tuple(target_tensor.shape)
    donor_shape = tuple(donor_tensor.shape)
    if donor_shape == target_shape:
        # clone() so every tensor handed to the model owns its own storage.
        return donor_tensor.clone()
    if (
        len(donor_shape) == len(target_shape)
        and len(target_shape) >= 1
        and donor_shape[1:] == target_shape[1:]
        and donor_shape[0] > target_shape[0]
    ):
        # Vocab mismatch (e.g. embedding / output head): keep the first rows.
        # clone() materializes the slice instead of keeping a view into the
        # larger donor storage.
        return donor_tensor[: target_shape[0]].clone()
    raise ValueError(
        f"Cannot align tensor '{name}': donor shape {donor_shape} is "
        f"incompatible with target shape {target_shape}."
    )


def _new_norm_weight(name, target_tensor):
    """Create a fresh all-ones bfloat16 tensor for a norm weight the donor lacks.

    Only target-only norm weights (the q_norm / k_norm weights) are expected
    to be missing from the donor. Filling any other kind of tensor with ones
    would silently corrupt the model, so non-norm names raise instead.

    Raises:
        KeyError: if ``name`` does not end in ``norm.weight``.
    """
    if not name.endswith("norm.weight"):
        raise KeyError(
            f"Target tensor '{name}' is missing from the donor and is not a "
            f"norm weight; refusing to fill it with ones."
        )
    # A new tensor per name (rather than reusing the shell's tensor) so no
    # two saved weights share storage.
    return torch.ones(target_tensor.shape, dtype=torch.bfloat16)


def main(foundation_model_id, donor_model_id, output_path):
    """
    Build the 'aligned' donor: a Qwen3-architecture model carrying donor weights.

    The target architecture is the foundation model's config with its geometry
    overridden (80 layers, hidden 8192, intermediate 29568, vocab 151936,
    bfloat16). Every tensor of that target is then filled from the donor:
    same-shape tensors are copied, donor tensors with a larger leading
    (vocabulary) dimension are truncated, and norm weights absent from the
    donor are created as ones. Each tensor placed in the model is a freshly
    allocated copy, so no two saved tensors share storage.

    Args:
        foundation_model_id: Hub id or local path whose config and tokenizer
            define the target architecture.
        donor_model_id: Hub id or local path of the model supplying weights.
        output_path: Directory the aligned model and tokenizer are saved to.

    Raises:
        ValueError: if a donor tensor's shape cannot be reconciled with the
            corresponding target tensor.
        KeyError: if a non-norm target tensor is missing from the donor.
    """
    print("--- Phase 1: Building the target Qwen3 80-Layer Architecture ---")

    foundation_config = AutoConfig.from_pretrained(foundation_model_id, trust_remote_code=True)

    # NOTE: this aliases (and therefore mutates) foundation_config, which is
    # not used again below.
    target_config = foundation_config
    # Geometry is hard-coded; presumably it matches the default donor
    # (Qwen2.5-72B-Instruct) -- re-check these values for any other donor.
    target_config.num_hidden_layers = 80
    target_config.hidden_size = 8192
    target_config.intermediate_size = 29568
    target_config.vocab_size = 151936
    target_config.torch_dtype = torch.bfloat16

    print("Creating empty Qwen3 80-layer model shell...")
    # The shell's parameters carry no real weights; they are replaced
    # wholesale by load_state_dict(..., assign=True) below.
    with init_empty_weights():
        aligned_model = AutoModelForCausalLM.from_config(target_config, trust_remote_code=True)
    print("Empty shell created successfully.")

    print("\n--- Phase 2: Loading and Manually Aligning Donor Weights ---")
    print(f"Loading weights from donor: {donor_model_id}")

    donor_state_dict = AutoModelForCausalLM.from_pretrained(
        donor_model_id, torch_dtype=torch.bfloat16, device_map="cpu", trust_remote_code=True
    ).state_dict()

    # Only the names and shapes of the shell's tensors are used here.
    target_state_dict = aligned_model.state_dict()
    new_state_dict = {}

    print("Copying and aligning tensors one-by-one...")
    for name, target_tensor in tqdm(target_state_dict.items(), desc="Aligning Tensors"):
        # pop() drops this dict's reference to each donor tensor once it has
        # been copied, so its storage can be released during the loop rather
        # than after it (provided nothing else still references it).
        donor_tensor = donor_state_dict.pop(name, None)
        if donor_tensor is None:
            new_state_dict[name] = _new_norm_weight(name, target_tensor)
        else:
            new_state_dict[name] = _align_tensor(name, target_tensor, donor_tensor)

    # Whatever remains had no counterpart in the target architecture and is
    # discarded; report it instead of dropping it silently.
    if donor_state_dict:
        unused = sorted(donor_state_dict)
        print(
            f"WARNING: {len(unused)} donor tensors have no counterpart in the "
            f"target architecture and were dropped (first few: {unused[:5]})"
        )
    del donor_state_dict

    print("Loading the fully aligned state_dict into the Qwen3 shell...")
    # assign=True installs the new tensors as the parameters themselves
    # instead of copying into the empty shell parameters.
    aligned_model.load_state_dict(new_state_dict, strict=True, assign=True)

    # Tie weights *after* all unique tensors are loaded.
    aligned_model.tie_weights()

    print("\n--- Phase 3: Saving the Aligned Donor ---")
    # NOTE(review): the tokenizer comes from the foundation model while the
    # embedding rows come from the donor; this assumes both vocabularies agree
    # on the first vocab_size ids -- verify.
    tokenizer = AutoTokenizer.from_pretrained(foundation_model_id, trust_remote_code=True)

    print(f"Saving the architecturally aligned model to: {output_path}")
    os.makedirs(output_path, exist_ok=True)
    aligned_model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)

    print("\nDonor preparation complete! This is the definitive donor model.")

if __name__ == "__main__":
    # Command-line entry point: collect the three model locations, then
    # delegate all the work to main().
    cli = argparse.ArgumentParser(
        description="Prepare a Qwen2.5 donor model for merging with Qwen3."
    )
    cli.add_argument(
        "--foundation_model",
        type=str,
        default="Qwen/Qwen3-32B",
        help="Model to use for the Qwen3 architecture blueprint.",
    )
    cli.add_argument(
        "--donor_model",
        type=str,
        default="Qwen/Qwen2.5-72B-Instruct",
        help="The donor model providing the weights.",
    )
    cli.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="The local directory path to save the prepared donor model.",
    )
    parsed = cli.parse_args()

    main(
        foundation_model_id=parsed.foundation_model,
        donor_model_id=parsed.donor_model,
        output_path=parsed.output_path,
    )