|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import random |
|
import textwrap |
|
from copy import deepcopy |
|
from typing import Any, Callable, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from accelerate.utils import is_deepspeed_available |
|
from datasets import Dataset |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
BaseImageProcessor, |
|
DataCollator, |
|
FeatureExtractionMixin, |
|
GenerationConfig, |
|
PreTrainedModel, |
|
PreTrainedTokenizerBase, |
|
ProcessorMixin, |
|
is_wandb_available, |
|
) |
|
from transformers.trainer_callback import TrainerCallback |
|
from transformers.trainer_utils import EvalPrediction |
|
from transformers.utils import is_peft_available |
|
|
|
from ..models import PreTrainedModelWrapper |
|
from ..models.utils import unwrap_model_for_generation |
|
from .gkd_config import GKDConfig |
|
from .sft_trainer import SFTTrainer |
|
from .utils import ( |
|
DataCollatorForChatML, |
|
disable_dropout_in_model, |
|
empty_cache, |
|
generate_model_card, |
|
get_comet_experiment_url, |
|
) |
|
|
|
|
|
if is_deepspeed_available(): |
|
import deepspeed |
|
|
|
|
|
if is_peft_available(): |
|
from peft import PeftConfig |
|
|
|
if is_wandb_available(): |
|
import wandb |
|
|
|
|
|
class GKDTrainer(SFTTrainer):
    """
    Trainer for Generalized Knowledge Distillation (GKD), see https://huggingface.co/papers/2306.13649.

    The student (`model`) is trained to match a frozen `teacher_model` by minimizing the generalized Jensen-Shannon
    divergence between their next-token distributions over the completion part of each sequence.

    Where the completions come from:
    - If `args.seq_kd` is set, the teacher first samples completions for the prompts.
    - Then, with probability `args.lmbda`, the student samples its own completions (on-policy), which replace
      whatever was there.
    - Otherwise the dataset completions are used.
    """

    _tag_names = ["trl", "gkd"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
        args: Optional[GKDConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        formatting_func: Optional[Callable] = None,
    ):
        """
        Build the student trainer, then load and prepare the frozen teacher.

        Args:
            model: Student model (instance or hub id), forwarded to `SFTTrainer`.
            teacher_model: Teacher model instance, or a hub id/path loaded with `AutoModelForCausalLM`.
            args: GKD configuration. It is read unconditionally, so it must not be `None` in practice.
            data_collator: Ignored. It is always replaced by `DataCollatorForChatML` because `compute_loss` and
                `training_step` read the `prompts` / `prompt_attention_mask` keys it is expected to produce.
            train_dataset, eval_dataset, processing_class, compute_metrics, callbacks, optimizers,
            preprocess_logits_for_metrics, peft_config, formatting_func: Forwarded to `SFTTrainer`.

        Raises:
            ValueError: If `args.teacher_model_init_kwargs` is set while `teacher_model` is already instantiated.
        """
        # Keep the raw columns (e.g. `messages`) so the ChatML collator can build prompts from them.
        args.remove_unused_columns = False
        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)

        super().__init__(
            model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            formatting_func=formatting_func,
        )

        if args.teacher_model_init_kwargs is None:
            teacher_model_init_kwargs = {}
        elif not isinstance(teacher_model, str):
            raise ValueError(
                "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
            )
        else:
            # Work on a copy: normalizing the dtype in place would corrupt `args` for any later reuse
            # (a second conversion would call `getattr(torch, <torch.dtype>)` and fail).
            teacher_model_init_kwargs = dict(args.teacher_model_init_kwargs)
            torch_dtype = teacher_model_init_kwargs.get("torch_dtype")
            # Map string names such as "bfloat16" to `torch.bfloat16`.
            # Leave "auto", None, a missing key, or an existing `torch.dtype` untouched.
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                teacher_model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)

        if isinstance(teacher_model, str):
            teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

        if args.disable_dropout:
            disable_dropout_in_model(self.model)

        # The teacher is only used for inference: wrap it accordingly.
        if self.is_deepspeed_enabled:
            self.teacher_model = self._prepare_deepspeed(teacher_model)
        else:
            self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

        self.lmbda = args.lmbda
        self.beta = args.beta
        self.temperature = args.temperature
        self.seq_kd = args.seq_kd

        # Sampling config shared by teacher (seq_kd) and student (on-policy) generation.
        # `top_k=0` disables top-k filtering, so sampling uses the full temperature-scaled distribution.
        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            do_sample=True,
            top_k=0,
            # The KV cache is turned off when gradient checkpointing is enabled.
            use_cache=False if args.gradient_checkpointing else True,
            pad_token_id=self.processing_class.pad_token_id,
        )

        # Reuse the student's own stop token(s) when its generation config defines them.
        if (
            hasattr(self.model.generation_config, "eos_token_id")
            and self.model.generation_config.eos_token_id is not None
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

    def _prepare_dataset(self, dataset, *args):
        """
        Run `SFTTrainer._prepare_dataset` while preserving the original `messages` column.

        A copy of `messages` is stashed as `_messages` before the parent processing and renamed back afterwards.
        NOTE(review): this presumably exists because the parent's processing removes or rewrites `messages`,
        which the ChatML collator still needs. Confirm against `SFTTrainer`.
        """
        dataset = dataset.add_column("_messages", dataset["messages"])
        dataset = super()._prepare_dataset(dataset, *args)
        dataset = dataset.rename_column("_messages", "messages")
        return dataset

    @staticmethod
    def generalized_jsd_loss(
        student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
    ):
        """
        Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
        of https://huggingface.co/papers/2306.13649 for the definition.

        Args:
            student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
            teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
            labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
            beta: Interpolation coefficient between 0 and 1 (default: 0.5). `beta == 0` reduces to KL(teacher || student)
                and `beta == 1` to KL(student || teacher).
            temperature: Softmax temperature (default: 1.0)
            reduction: Specifies the reduction to apply to the output (default: 'batchmean'). One of
                'batchmean', 'sum', 'mean'. Any other value returns the unreduced (masked) divergence.

        Returns:
            loss: Scalar tensor with the generalized JSD loss (or the unreduced tensor, see `reduction`).
                With 'batchmean' and `labels`, a batch without any supervised token yields 0 rather than NaN.
        """
        # Soften both distributions with the same temperature.
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

        # NOTE: F.kl_div(input, target, log_target=True) computes KL(target || input). The argument order is
        # therefore swapped with respect to the paper's notation.
        if beta == 0:
            jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
        elif beta == 1:
            jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
        else:
            # Create beta on the logits' device/dtype so the mixture arithmetic stays on one device.
            beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
            # log((1 - beta) * p_student + beta * p_teacher), computed stably in log space via logsumexp.
            mixture_log_probs = torch.logsumexp(
                torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
                dim=0,
            )

            kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
            kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

            jsd = beta * kl_teacher + (1 - beta) * kl_student

        # Keep only supervised positions. This flattens jsd to (num_supervised_tokens, vocab_size).
        if labels is not None:
            mask = labels != -100
            jsd = jsd[mask]

        if reduction == "batchmean":
            if labels is not None:
                # Clamp the token count so an all-ignored batch gives 0 instead of 0/0 = NaN.
                return jsd.sum() / mask.sum().clamp(min=1)
            return jsd.sum() / (jsd.size(0) * jsd.size(1))
        elif reduction == "sum":
            return jsd.sum()
        elif reduction == "mean":
            return jsd.mean()
        else:
            return jsd

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Distillation loss between the student (`model`) and the frozen teacher on the completion tokens.

        Both models see the same `input_ids` / `attention_mask`. Logits are sliced from `prompt_length - 1` to `-1`,
        so each kept position predicts the next token. Labels are taken from `prompt_length` onwards.
        `num_items_in_batch` is accepted for interface compatibility but not used.

        NOTE(review): the slicing assumes every row of `input_ids` starts with a prompt whose padded length equals
        `inputs["prompts"].shape[1]`. This is presumably guaranteed by `DataCollatorForChatML` and by
        `generate_on_policy_outputs`, neither of which is visible here. Confirm.
        """
        outputs_student = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

        # The teacher is frozen: keep it in eval mode and do not track gradients.
        self.teacher_model.eval()
        with torch.no_grad():
            outputs_teacher = self.teacher_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        # Align predictions with targets on the completion part only.
        prompt_lengths = inputs["prompts"].shape[1]
        shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
        shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
        shifted_labels = inputs["labels"][:, prompt_lengths:]

        # NOTE: `self.temperature` is only used for sampling (generation_config). The loss is computed at
        # temperature 1.0.
        loss = self.generalized_jsd_loss(
            student_logits=shifted_student_logits,
            teacher_logits=shifted_teacher_logits,
            labels=shifted_labels,
            beta=self.beta,
        )

        empty_cache()

        return (loss, outputs_student) if return_outputs else loss

    @staticmethod
    def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
        """
        Sample completions for `inputs["prompts"]` and build matching attention mask and labels.

        Returns:
            A tuple `(generated_tokens, attention_mask, labels)`, where `generated_tokens` is `sequences` from
            `model.generate`. Presumably this is prompt + completion for decoder-only models, which `compute_loss`
            relies on. When `pad_token_id` is given, positions equal to it get attention 0 and label -100. If pad
            and EOS share an id, EOS positions are masked too.
        """
        generated_outputs = model.generate(
            input_ids=inputs["prompts"],
            attention_mask=inputs.get("prompt_attention_mask", None),
            generation_config=generation_config,
            return_dict_in_generate=True,
        )

        generated_tokens = generated_outputs.sequences
        new_attention_mask = torch.ones_like(generated_tokens)
        new_labels = generated_tokens.clone()

        if pad_token_id is not None:
            new_labels[new_labels == pad_token_id] = -100
            new_attention_mask[generated_tokens == pad_token_id] = 0

        return generated_tokens, new_attention_mask, new_labels

    def _replace_with_generations(self, generator_model, inputs):
        """Overwrite `input_ids`, `attention_mask` and `labels` in `inputs` with samples from `generator_model`."""
        with unwrap_model_for_generation(generator_model, self.accelerator) as unwrapped_model:
            new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
            )
        inputs["input_ids"] = new_input_ids
        inputs["attention_mask"] = new_attention_mask
        inputs["labels"] = new_labels

    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        """
        Perform a training step for the Generalized Knowledge Distillation (GKD) model.

        This method implements the on-policy learning approach described in the GKD paper.
        If `self.seq_kd` is set, the completions are first replaced by teacher samples. Then, with probability
        `self.lmbda`, they are replaced by samples from the student model. The (possibly modified) inputs are
        handed to the parent `training_step`.
        """
        if self.seq_kd:
            self._replace_with_generations(self.teacher_model, inputs)
        # random.random() is in [0, 1), so a strict `<` fires with probability exactly `lmbda`
        # (never for lmbda == 0, always for lmbda == 1).
        if random.random() < self.lmbda:
            self._replace_with_generations(model, inputs)

        loss = super().training_step(model, inputs, num_items_in_batch)
        return loss

    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        """
        Wrap the (frozen) teacher in a DeepSpeed engine based on the accelerator's DeepSpeed config, in eval mode.

        For ZeRO stage 3 with a known hidden size, bucket and persistence sizes are derived from that hidden size.
        Any other stage is forced to 0, presumably because an inference-only model has no optimizer state worth
        sharding.
        """
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        # Deep-copy so the accelerator's own config is not modified.
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # NOTE(review): these dotted keys are added at the top level of the dict, not nested under
                    # "zero_optimization". Confirm DeepSpeed actually honors them in this form.
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Only the main process writes the card (`README.md` in `args.output_dir`). The caller's `tags` list is
        never modified.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        # A local directory is not a hub id, so only report the base model when it is not a path on disk.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # Copy into a fresh list so appending below never mutates a list owned by the caller.
        tags = [tags] if isinstance(tags, str) else list(tags or [])

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @inproceedings{agarwal2024on-policy,
            title        = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
            author       = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
            year         = 2024,
            booktitle    = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
            publisher    = {OpenReview.net},
            url          = {https://openreview.net/forum?id=3zKtaqxLhW},
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GKD",
            trainer_citation=citation,
            paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
            paper_id="2306.13649",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
|