# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 0. imports
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed

from trl import DPOConfig, DPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the location of the SFT model name or path"},
    )
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )
    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
    load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
    model_dtype: Optional[str] = field(
        default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
    )
    num_proc: Optional[int] = field(
        default=24, metadata={"help": "the number of processes to use for dataset filtering and mapping"}
    )

    # instrumentation
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, '
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See"
            " https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
    )


def get_stack_exchange_paired(
    data_dir: str = "data/rl",
    cache_dir: Optional[str] = None,
    num_proc=24,
) -> Dataset:
    """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': list[str],
        'chosen': list[str],
        'rejected': list[str],
    }

    Prompts are structured as follows:
      "Question: " + <prompt> + "\n\nAnswer: "
    """
    dataset = load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        cache_dir=cache_dir,
        data_dir=data_dir,
        verification_mode="no_checks",
    )
    original_columns = dataset.column_names

    def return_prompt_and_responses(samples) -> dict[str, list[str]]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    set_seed(script_args.seed)

    # 1. load a pretrained model
    torch_dtype = torch.float
    if script_args.model_dtype == "float16":
        torch_dtype = torch.float16
    elif script_args.model_dtype == "bfloat16":
        torch_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        load_in_4bit=script_args.load_in_4bit,
        device_map={"": Accelerator().local_process_index},
    )
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the Stack-exchange paired dataset
    # The filters below use character lengths as a cheap proxy; token-level truncation to
    # max_prompt_length/max_length is applied by the trainer via the DPOConfig.
    train_dataset = get_stack_exchange_paired(data_dir="data/rl")
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
        num_proc=script_args.num_proc,
    )

    # 3. Load evaluation dataset
    eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation")
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
        num_proc=script_args.num_proc,
    )
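
    # Optional sanity check (illustrative, not part of the original pipeline): uncomment to
    # print one formatted example and confirm the prompt/chosen/rejected fields look right.
    # print(train_dataset[0]["prompt"][:200])
    # print(train_dataset[0]["chosen"][:200])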

    # 4. initialize training arguments:
    training_args = DPOConfig(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        eval_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        optim=script_args.optimizer_type,
        bf16=True,
        remove_unused_columns=False,
        run_name="dpo_llama2",
        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
        seed=script_args.seed,
        # DPO-specific settings belong on DPOConfig rather than on the trainer constructor
        beta=script_args.beta,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,  # with a PEFT config, the base model with adapters disabled serves as the reference
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    # 6. train
    dpo_trainer.train()
    dpo_trainer.save_model(script_args.output_dir)

    # 7. save
    output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)
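
# Example launch (illustrative; the script filename and paths are placeholders, and any of the
# ScriptArguments fields above can be passed as CLI flags via HfArgumentParser):
#   accelerate launch dpo_llama2.py \
#       --model_name_or_path="./sft/results/final_checkpoint" \
#       --output_dir="./dpo_results" \
#       --report_to="tensorboard"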