# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import textwrap
from copy import deepcopy
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from ..models import PreTrainedModelWrapper
from ..models.utils import unwrap_model_for_generation
from .gkd_config import GKDConfig
from .sft_trainer import SFTTrainer
from .utils import (
    DataCollatorForChatML,
    disable_dropout_in_model,
    empty_cache,
    generate_model_card,
    get_comet_experiment_url,
)


if is_deepspeed_available():
    import deepspeed

if is_peft_available():
    from peft import PeftConfig

if is_wandb_available():
    import wandb


class GKDTrainer(SFTTrainer):
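    """
    Trainer for Generalized Knowledge Distillation (GKD), as described in the paper [On-Policy Distillation of
    Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649).

    The student `model` is distilled towards `teacher_model` on a conversational dataset (one with a `"messages"`
    column), optionally mixing in completions generated on-policy by the student or the teacher (see
    `training_step`).

    Example (a minimal sketch; the checkpoints are illustrative and the dataset is left as a placeholder):

    ```python
    from datasets import load_dataset
    from transformers import AutoTokenizer
    from trl import GKDConfig, GKDTrainer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    train_dataset = load_dataset(...)  # any dataset with a "messages" column

    trainer = GKDTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        teacher_model="Qwen/Qwen2-1.5B-Instruct",
        args=GKDConfig(output_dir="gkd-model"),
        processing_class=tokenizer,
        train_dataset=train_dataset,
    )
    trainer.train()
    ```
    """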

    _tag_names = ["trl", "gkd"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        teacher_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[GKDConfig] = None,
        data_collator: Optional[DataCollator] = None,  # type: ignore
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        formatting_func: Optional[Callable] = None,
    ):
        # add remove_unused_columns=False to the dataclass args
        args.remove_unused_columns = False
        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)

        super().__init__(
            model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            formatting_func=formatting_func,
        )

        if args.teacher_model_init_kwargs is None:
            teacher_model_init_kwargs = {}
        elif not isinstance(teacher_model, str):
            raise ValueError(
                "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
            )
        else:
            teacher_model_init_kwargs = args.teacher_model_init_kwargs
            teacher_model_init_kwargs["torch_dtype"] = (
                teacher_model_init_kwargs["torch_dtype"]
                if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
                else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
            )

        if isinstance(teacher_model, str):
            teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(self.model)

        if self.is_deepspeed_enabled:
            self.teacher_model = self._prepare_deepspeed(teacher_model)
        else:
            self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

        self.lmbda = args.lmbda
        self.beta = args.beta
        self.temperature = args.temperature
        self.seq_kd = args.seq_kd

        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            do_sample=True,
            top_k=0,
            use_cache=False if args.gradient_checkpointing else True,
            pad_token_id=self.processing_class.pad_token_id,
        )
        # Set custom EOS tokens if they are specified by the model's generation config. This is important for models
        # with the Llama 3 chat template, which use the special tokens <|eot_id|> and <|eom_id|> to mark the end of
        # turns or messages.
        if (
            hasattr(self.model.generation_config, "eos_token_id")
            and self.model.generation_config.eos_token_id is not None
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

    def _prepare_dataset(self, dataset, *args):
        # SFTTrainer._prepare_dataset() applies the chat template and renames the "messages" column to "text".
        # However, we need to keep the "messages" column as it is, so we stash it in a temporary column and restore
        # it after the parent class has processed the dataset.
        dataset = dataset.add_column("_messages", dataset["messages"])
        dataset = super()._prepare_dataset(dataset, *args)
        dataset = dataset.rename_column("_messages", "messages")
        return dataset

    @staticmethod
    def generalized_jsd_loss(
        student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
    ):
        """
        Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
        of https://huggingface.co/papers/2306.13649 for the definition.
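
        For an interpolation coefficient `beta` strictly between 0 and 1, the implementation below forms the mixture
        `M = beta * P_teacher + (1 - beta) * P_student` and returns
        `beta * KL(P_teacher || M) + (1 - beta) * KL(P_student || M)`. The edge cases `beta == 0` and `beta == 1` are
        handled directly as the plain divergences `KL(P_teacher || P_student)` and `KL(P_student || P_teacher)`,
        respectively.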

        Args:
            student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
            teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
            labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when
                computing loss
            beta: Interpolation coefficient between 0 and 1 (default: 0.5)
            temperature: Softmax temperature (default: 1.0)
            reduction: Specifies the reduction to apply to the output (default: 'batchmean')

        Returns:
            loss: Scalar tensor with the generalized JSD loss
        """

        # Apply temperature scaling
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        # Compute log probabilities for both the student and the teacher
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

        if beta == 0:
            jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
        elif beta == 1:
            jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
        else:
            # Compute the log of the mixture distribution
            # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
            beta = torch.tensor(beta, dtype=student_log_probs.dtype)
            mixture_log_probs = torch.logsumexp(
                torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
                dim=0,
            )

            # Compute KL divergences using F.kl_div
            # F.kl_div(input, target, log_target=True) computes KL(target || input), so the argument order is swapped
            # compared to the mathematical notation used in the paper.
            kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
            kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

            # Compute the Generalized Jensen-Shannon Divergence
            jsd = beta * kl_teacher + (1 - beta) * kl_student

        # Masking
        if labels is not None:
            mask = labels != -100
            jsd = jsd[mask]

        # Apply reduction
        if reduction == "batchmean":
            return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
        elif reduction == "sum":
            return jsd.sum()
        elif reduction == "mean":
            return jsd.mean()
        else:
            return jsd

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # compute student output
        outputs_student = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

        # compute teacher output in eval mode
        self.teacher_model.eval()
        with torch.no_grad():
            outputs_teacher = self.teacher_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        # slice the logits for the generated tokens using the inputs["prompts"] lengths
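        # (each logit at position t predicts the token at position t + 1, so the logits from index
        # (prompt_lengths - 1) up to, but excluding, the last position line up with the completion tokens in
        # labels[:, prompt_lengths:])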
        prompt_lengths = inputs["prompts"].shape[1]
        shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
        shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
        shifted_labels = inputs["labels"][:, prompt_lengths:]

        # compute loss
        loss = self.generalized_jsd_loss(
            student_logits=shifted_student_logits,
            teacher_logits=shifted_teacher_logits,
            labels=shifted_labels,
            beta=self.beta,
        )

        # empty cache
        empty_cache()

        # Return loss
        return (loss, outputs_student) if return_outputs else loss

    @staticmethod
    def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
        # Generate output with respect to the prompt only
        generated_outputs = model.generate(
            input_ids=inputs["prompts"],
            attention_mask=inputs.get("prompt_attention_mask", None),
            generation_config=generation_config,
            return_dict_in_generate=True,
        )

        # Get the generated token IDs
        generated_tokens = generated_outputs.sequences
        # Calculate new attention mask
        new_attention_mask = torch.ones_like(generated_tokens)
        new_labels = generated_tokens.clone()

        # If there's pad_token_id, set attention mask to 0 for padding tokens
        if pad_token_id is not None:
            new_labels[new_labels == pad_token_id] = -100
            new_attention_mask[generated_tokens == pad_token_id] = 0

        return generated_tokens, new_attention_mask, new_labels

    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        """
        Perform a training step for the Generalized Knowledge Distillation (GKD) model.

        This method implements the on-policy learning approach described in the GKD paper. With probability
        `self.lmbda`, it generates new responses using the student model, which are then used for training instead of
        the original inputs.
        """
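        # Following the GKD setup, the training batch can come from several sources: when `seq_kd` is enabled, the
        # teacher generates the target completions (sequence-level KD); with probability `lmbda`, the student
        # generates its own completions (on-policy data); otherwise, the original dataset completions are used.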
        if self.seq_kd:
            with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels
        if random.random() <= self.lmbda:
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels

        loss = super().training_step(model, inputs, num_items_in_batch)
        return loss

    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like:
                    # `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active (student) model and the teacher model.
        # Otherwise, we assume the teacher model fits in memory and is initialized on each device with ZeRO disabled
        # (stage 0).
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @inproceedings{agarwal2024on-policy,
            title        = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
            author       = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
            year         = 2024,
            booktitle    = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
            publisher    = {OpenReview.net},
            url          = {https://openreview.net/forum?id=3zKtaqxLhW},
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GKD",
            trainer_citation=citation,
            paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
            paper_id="2306.13649",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))