File size: 2,216 Bytes
7e1725c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# remove_biases.py
import torch
import os
import argparse
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
def main(source_model_id, output_path):
    """
    Load a causal-LM checkpoint, drop every state-dict tensor whose name ends
    in '.bias', and save the filtered checkpoint together with the tokenizer.

    Args:
        source_model_id: Hugging Face model ID (or local path) of the donor model.
        output_path: Directory to write the bias-free checkpoint and tokenizer to.
            Created if it does not exist.

    Note:
        Only the saved weights are filtered; the model config is saved unchanged.
        If the model class still declares bias parameters, loading the result
        with ``from_pretrained`` will presumably report those keys as missing
        and initialize them fresh — confirm that is what downstream code expects.
    """
    print(f"Loading source donor model: {source_model_id}")
    # Load on CPU to avoid using VRAM.
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # model repository; only point this script at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        source_model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(source_model_id, trust_remote_code=True)

    # state_dict() returns references to the live tensors, so building the
    # filtered dict below does not duplicate the weights in memory.
    source_state_dict = model.state_dict()
    new_state_dict = {}
    print("Removing all '.bias' tensors...")
    removed_count = 0
    for name, tensor in tqdm(source_state_dict.items(), desc="Processing Tensors"):
        if name.endswith(".bias"):
            removed_count += 1
            continue  # Skip this tensor
        new_state_dict[name] = tensor
    print(f"Removed {removed_count} bias tensors.")
    if removed_count == 0:
        print("Warning: no '.bias' tensors were found; the output will match the source weights.")

    # The in-memory model still owns its bias parameters. Loading the filtered
    # dict back with strict=False would only skip the missing keys, leaving the
    # biases in place, and a plain save_pretrained() would write them out again.
    # Handing the filtered dict straight to save_pretrained() is what keeps the
    # '.bias' tensors out of the saved checkpoint.
    print(f"Saving the no-bias model and tokenizer to: {output_path}")
    os.makedirs(output_path, exist_ok=True)
    model.save_pretrained(output_path, state_dict=new_state_dict)
    tokenizer.save_pretrained(output_path)
    print("\nPhase 1b (No-Bias Donor Creation) Complete!")
    print(f"The no-bias donor is ready at '{output_path}'.")
if __name__ == "__main__":
    # Example: python remove_biases.py --output_path ./Qwen2.5-72B-Instruct-NoBias
    # Command-line entry point: parse the donor model ID and the destination
    # directory, then hand both to main().
    cli = argparse.ArgumentParser(description="Remove bias tensors from a model.")
    cli.add_argument(
        "--source_model",
        type=str,
        default="Qwen/Qwen2.5-72B-Instruct",
        help="The Hugging Face model ID of the source model.",
    )
    cli.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="The local directory path to save the no-bias model.",
    )
    parsed = cli.parse_args()
    main(source_model_id=parsed.source_model, output_path=parsed.output_path)
|