import dataclasses |
|
import os |
|
import warnings |
|
from collections import defaultdict |
|
from dataclasses import dataclass |
|
from typing import Any, Callable, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import transformers |
|
from accelerate import PartialState |
|
from datasets import Dataset, IterableDataset |
|
from packaging import version |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
BaseImageProcessor, |
|
DataCollator, |
|
DataCollatorWithFlattening, |
|
FeatureExtractionMixin, |
|
PreTrainedModel, |
|
PreTrainedTokenizerBase, |
|
ProcessorMixin, |
|
Trainer, |
|
TrainingArguments, |
|
is_wandb_available, |
|
) |
|
from transformers.data.data_collator import DataCollatorMixin |
|
from transformers.trainer_callback import TrainerCallback |
|
from transformers.trainer_utils import EvalPrediction |
|
from transformers.utils import is_peft_available |
|
|
|
from ..data_utils import ( |
|
apply_chat_template, |
|
is_conversational, |
|
maybe_convert_to_chatml, |
|
pack_dataset, |
|
truncate_dataset, |
|
) |
|
from .sft_config import SFTConfig |
|
from .utils import ( |
|
ConstantLengthDataset, |
|
generate_model_card, |
|
get_comet_experiment_url, |
|
pad, |
|
peft_module_casting_to_bf16, |
|
) |
|
|
|
|
|
if is_peft_available(): |
|
import peft |
|
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training |
|
|
|
if is_wandb_available(): |
|
import wandb |
|
|
|
|
|
@dataclass |
|
class DataCollatorForLanguageModeling(DataCollatorMixin): |
|
""" |
|
Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch if |
|
they are not all of the same length. |
|
|
|
Args: |
|
pad_token_id (`int`): |
|
Token ID to use for padding. |
|
completion_only_loss (`bool`, *optional*, defaults to `True`): |
|
When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens |
|
that are not in the completion. |
|
return_tensors (`str`, *optional*, defaults to `"pt"`): |
|
Type of Tensor to return. Only `"pt"` is currently supported. |
|
|
|
Examples: |
|
```python |
|
>>> from trl import DataCollatorForLanguageModeling |
|
>>> collator = DataCollatorForLanguageModeling(pad_token_id=0) |
|
>>> examples = [ |
|
... {"input_ids": [1, 2, 3]}, |
|
... {"input_ids": [4, 5]} |
|
... ] |
|
>>> collator(examples) |
|
{'input_ids': tensor([[ 1, 2, 3], |
|
[ 4, 5, 0]]), |
|
'attention_mask': tensor([[ 1, 1, 1], |
|
[ 1, 1, 0]]), |
|
'labels': tensor([[ 1, 2, 3], |
|
[ 4, 5, -100]])} |
|
>>> # With completion mask |
|
>>> examples = [ |
|
... {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, |
|
... {"input_ids": [4, 5], "completion_mask": [0, 1]} |
|
... ] |
|
>>> collator(examples) |
|
{'input_ids': tensor([[ 1, 2, 3], |
|
[ 4, 5, 0]]), |
|
'attention_mask': tensor([[ 1, 1, 1], |
|
[ 1, 1, 0]]), |
|
'labels': tensor([[-100, 2, 3], |
|
[-100, 5, -100]])} |
|
``` |
|
""" |
|
|
|
pad_token_id: int |
|
completion_only_loss: bool = True |
|
return_tensors: str = "pt" |
|
|
|
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: |
|
|
|
input_ids = [torch.tensor(example["input_ids"]) for example in examples] |
|
        attention_mask = [torch.ones_like(ids) for ids in input_ids]
|
labels = [torch.tensor(example["input_ids"]) for example in examples] |
|
if self.completion_only_loss and "completion_mask" in examples[0]: |
|
completion_mask = [torch.tensor(example["completion_mask"]) for example in examples] |
|
|
|
|
|
output = {} |
|
output["input_ids"] = pad(input_ids, padding_value=self.pad_token_id, padding_side="right") |
|
output["attention_mask"] = pad(attention_mask, padding_value=0, padding_side="right") |
|
output["labels"] = pad(labels, padding_value=-100, padding_side="right") |
|
if self.completion_only_loss and "completion_mask" in examples[0]: |
|
completion_mask = pad(completion_mask, padding_value=0, padding_side="right") |
|
output["labels"][completion_mask == 0] = -100 |
|
|
|
return output |
|
|
|
|
|
class SFTTrainer(Trainer): |
|
""" |
|
Trainer for Supervised Fine-Tuning (SFT) method. |
|
|
|
This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. |
|
|
|
Example: |
|
|
|
```python |
|
from datasets import load_dataset |
|
from trl import SFTTrainer |
|
|
|
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") |
|
|
|
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) |
|
trainer.train() |
|
``` |
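
    A minimal sketch of the same workflow with a custom `formatting_func` (the dataset and field name below are
    illustrative, not a requirement of the trainer):

    ```python
    from datasets import load_dataset
    from trl import SFTTrainer

    dataset = load_dataset("stanfordnlp/imdb", split="train[:1%]")

    def formatting_func(example):
        # Build the training text from one raw example
        return f"Review: {example['text']}"

    trainer = SFTTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        train_dataset=dataset,
        formatting_func=formatting_func,
    )
    trainer.train()
    ```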
|
|
|
Args: |
|
model (`Union[str, PreTrainedModel]`): |
|
Model to be trained. Can be either: |
|
|
|
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or |
|
a path to a *directory* containing model weights saved using |
|
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is |
|
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
|
in `args.model_init_kwargs`. |
|
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. |
|
args ([`SFTConfig`], *optional*, defaults to `None`): |
|
Configuration for this trainer. If `None`, a default configuration is used. |
|
data_collator (`DataCollator`, *optional*): |
|
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. |
|
            Will default to a [`DataCollatorForLanguageModeling`] built from the resolved padding token, which

            dynamically pads the batch and, when completion-only loss is enabled, masks prompt tokens in the

            labels.
|
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): |
|
Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and |
|
[prompt-completion](#prompt-completion) type. The format of the samples can be either: |
|
|
|
- [Standard](dataset_formats#standard): Each sample contains plain text. |
|
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role |
|
and content). |
|
|
|
The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. |
|
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): |
|
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. |
|
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): |
|
Processing class used to process the data. If `None`, the processing class is loaded from the model's name |
|
with [`~transformers.AutoTokenizer.from_pretrained`]. |
|
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): |
|
List of callbacks to customize the training loop. Will add those to the list of default callbacks |
|
            detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
|
|
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] |
|
method. |
|
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): |
|
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your |
|
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. |
|
optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`): |
|
A tuple containing the optimizer class and keyword arguments to use. |
|
Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. |
|
|
|
Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer. |
|
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`): |
|
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
|
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made |
|
by this function will be reflected in the predictions received by `compute_metrics`. |
|
|
|
Note that the labels (second parameter) will be `None` if the dataset does not have them. |
|
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): |
|
PEFT configuration used to wrap the model. If `None`, the model is not wrapped. |
|
formatting_func (`Optional[Callable]`): |
|
            Formatting function applied to the dataset before tokenization. It is applied to each example and

            should return the text to be tokenized.
|
""" |
|
|
|
_tag_names = ["trl", "sft"] |
|
|
|
def __init__( |
|
self, |
|
model: Union[str, nn.Module, PreTrainedModel], |
|
args: Optional[Union[SFTConfig, TrainingArguments]] = None, |
|
data_collator: Optional[DataCollator] = None, |
|
train_dataset: Optional[Union[Dataset, IterableDataset]] = None, |
|
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, |
|
processing_class: Optional[ |
|
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] |
|
] = None, |
|
compute_loss_func: Optional[Callable] = None, |
|
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, |
|
callbacks: Optional[list[TrainerCallback]] = None, |
|
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), |
|
optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, |
|
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, |
|
peft_config: Optional["PeftConfig"] = None, |
|
formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None, |
|
): |
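        # Args: build a default `SFTConfig` named after the model if none is given, or promote plain
        # `TrainingArguments` to an `SFTConfig`.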
|
|
|
model_id = model if isinstance(model, str) else model.config._name_or_path |
|
if args is None: |
|
model_name = model_id.split("/")[-1] |
|
args = SFTConfig(f"{model_name}-SFT") |
|
elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): |
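            # `to_dict()` hides the real `hub_token` and keeps a `push_to_hub_token` entry, so restore the token
            # and drop the extra key before rebuilding the config.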
|
dict_args = args.to_dict() |
|
dict_args["hub_token"] = args.hub_token |
|
dict_args.pop("push_to_hub_token") |
|
args = SFTConfig(**dict_args) |
|
|
|
|
|
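        # Handle the tokenizer: load it from the model ID if not provided, and optionally override the EOS token.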
if processing_class is None: |
|
processing_class = AutoTokenizer.from_pretrained(model_id) |
|
|
|
if args.eos_token is not None: |
|
eos_token = args.eos_token |
|
eos_token_id = processing_class.convert_tokens_to_ids(eos_token) |
|
if eos_token_id is None: |
|
raise ValueError( |
|
f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " |
|
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " |
|
"in the vocabulary before using it as an EOS token." |
|
) |
|
processing_class.eos_token_id = eos_token_id |
|
|
|
|
|
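        # Model: only instantiate from a string/path here; `model_init_kwargs` are ignored for an already
        # instantiated model.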
if args.model_init_kwargs is not None and not isinstance(model, str): |
|
warnings.warn( |
|
"You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. " |
|
"The `model_init_kwargs` will be ignored." |
|
) |
|
if isinstance(model, str): |
|
model = self._create_model_from_path(model, args) |
|
|
|
|
|
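        # PEFT configuration and model wrapping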
if peft_config is not None: |
|
model = self._prepare_peft_model(model, peft_config, args) |
|
|
|
|
|
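        # Data collator: padding-free training flattens each batch into a single sequence, relies on
        # FlashAttention-2, and is incompatible with a custom collator.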
if args.padding_free: |
|
if data_collator is not None: |
|
raise ValueError("Passing a custom data collator is not supported when using padding-free.") |
|
if args.packing: |
|
                warnings.warn(

                    "You are passing `packing=True` and `padding_free=True`, which is not recommended. Please "

                    "refer to the documentation to understand why."
|
) |
|
if model.config._attn_implementation != "flash_attention_2": |
|
warnings.warn( |
|
"Padding-free training is enabled, but the attention implementation is not set to " |
|
"'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " |
|
"'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " |
|
"other implementations may lead to unexpected behavior. To ensure compatibility, set " |
|
"`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " |
|
"attention mechanism can handle flattened sequences." |
|
) |
|
if args.per_device_train_batch_size == 1: |
|
                warnings.warn(

                    "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "

                    "of 1 annihilates the benefits of padding-free training. Please consider increasing the batch size "
|
"to at least 2." |
|
) |
|
data_collator = DataCollatorWithFlattening() |
|
|
|
if args.completion_only_loss is None: |
|
first_example = next(iter(train_dataset)) |
|
self.completion_only_loss = "prompt" in first_example |
|
else: |
|
self.completion_only_loss = args.completion_only_loss |
|
if data_collator is None: |
|
|
|
|
|
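            # Resolve the padding token: prefer `args.pad_token`, then the tokenizer's pad token, then its EOS token.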
pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token |
|
pad_token_id = processing_class.convert_tokens_to_ids(pad_token) |
|
if pad_token_id is None: |
|
raise ValueError( |
|
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " |
|
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " |
|
"in the vocabulary before using it as a padding token." |
|
) |
|
data_collator = DataCollatorForLanguageModeling(pad_token_id, self.completion_only_loss) |
|
|
|
|
|
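        # Dataset preparation: formatting, chat template, tokenization, and packing/truncation, unless skipped
        # via `dataset_kwargs["skip_prepare_dataset"]`.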
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False) |
|
if preprocess_dataset: |
|
train_dataset = self._prepare_dataset( |
|
train_dataset, processing_class, args, args.packing, formatting_func, "train" |
|
) |
|
if eval_dataset is not None: |
|
packing = args.packing if args.eval_packing is None else args.eval_packing |
|
if isinstance(eval_dataset, dict): |
|
eval_dataset = { |
|
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) |
|
for key, dataset in eval_dataset.items() |
|
} |
|
else: |
|
eval_dataset = self._prepare_dataset( |
|
eval_dataset, processing_class, args, packing, formatting_func, "eval" |
|
) |
|
|
|
|
|
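        # Initialize the metrics accumulated across training/evaluation steps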
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} |
|
self._total_train_tokens = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
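        # `optimizer_cls_and_kwargs` is only supported by the parent `Trainer` starting with transformers 4.47.0.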
super_init_kwargs = {} |
|
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): |
|
super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs |
|
else: |
|
if optimizer_cls_and_kwargs is not None: |
|
warnings.warn( |
|
"The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. " |
|
"The default optimizer will be used. " |
|
"Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`." |
|
) |
|
super().__init__( |
|
model=model, |
|
args=args, |
|
data_collator=data_collator, |
|
train_dataset=train_dataset, |
|
eval_dataset=eval_dataset, |
|
processing_class=processing_class, |
|
compute_loss_func=compute_loss_func, |
|
compute_metrics=compute_metrics, |
|
callbacks=callbacks, |
|
optimizers=optimizers, |
|
preprocess_logits_for_metrics=preprocess_logits_for_metrics, |
|
**super_init_kwargs, |
|
) |
|
|
|
|
|
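        # Add tags for the model card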
if hasattr(self.model, "add_model_tags"): |
|
self.model.add_model_tags(self._tag_names) |
|
|
|
def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel: |
|
"""Creates a model from a path or model identifier.""" |
|
model_init_kwargs = args.model_init_kwargs or {} |
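        # Handle torch dtype: accept a `torch.dtype`, the string "auto", a dtype name such as "float32", or None.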
|
|
|
torch_dtype = model_init_kwargs.get("torch_dtype") |
|
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: |
|
pass |
|
elif isinstance(torch_dtype, str): |
|
torch_dtype = getattr(torch, torch_dtype) |
|
model_init_kwargs["torch_dtype"] = torch_dtype |
|
else: |
|
raise ValueError( |
|
"Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " |
|
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." |
|
) |
|
|
|
|
|
|
|
|
|
|
|
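        # Instantiate the causal LM from the path or model ID with the resolved init kwargs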
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) |
|
return model |
|
|
|
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel: |
|
"""Prepares a model for PEFT training.""" |
|
if not is_peft_available(): |
|
raise ImportError("To use PeftModel, you need to install the `peft` library.") |
|
|
|
if not isinstance(peft_config, PeftConfig): |
|
raise ValueError( |
|
f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need " |
|
"to pass a PeftConfig object to the SFTTrainer." |
|
) |
|
|
|
if isinstance(model, PeftModel): |
|
return model |
|
|
|
|
|
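        # Handle quantized models (QLoRA)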
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) |
|
|
|
is_sharded_qlora = False |
|
if getattr(model, "is_loaded_in_4bit", False): |
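            # Detect sharded QLoRA: 4-bit parameters that are still on the CPU or meta device.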
|
|
|
for _, param in model.named_parameters(): |
|
if param.__class__.__name__ == "Params4bit": |
|
is_sharded_qlora = param.data.device.type in {"cpu", "meta"} |
|
break |
|
|
|
|
|
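        # Prepare the quantized model for k-bit training (skipped for sharded QLoRA)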
if is_qlora and not is_sharded_qlora: |
|
model = self._prepare_model_for_kbit_training(model, args) |
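            # Gradient checkpointing is handled by `prepare_model_for_kbit_training`, so avoid enabling it twice.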
|
|
|
args = dataclasses.replace(args, gradient_checkpointing=False) |
|
elif args.gradient_checkpointing: |
|
model = self._enable_gradient_checkpointing(model, args) |
|
|
|
|
|
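        # Create the PEFT model; for sharded QLoRA with PEFT >= 0.12, skip autocasting the adapter dtype.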
if ( |
|
version.parse(peft.__version__) >= version.parse("0.12") |
|
and getattr(model, "is_loaded_in_4bit", False) |
|
and is_sharded_qlora |
|
): |
|
model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) |
|
else: |
|
model = get_peft_model(model, peft_config) |
|
|
|
|
|
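        # Cast PEFT modules to bf16 when training a non-sharded 4-bit model with bf16 enabled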
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora: |
|
peft_module_casting_to_bf16(model) |
|
|
|
return model |
|
|
|
def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel: |
|
"""Prepares a quantized model for kbit training.""" |
|
prepare_model_kwargs = { |
|
"use_gradient_checkpointing": args.gradient_checkpointing, |
|
"gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {}, |
|
} |
|
|
|
return prepare_model_for_kbit_training(model, **prepare_model_kwargs) |
|
|
|
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel: |
|
"""Enables gradient checkpointing for the model.""" |
|
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} |
|
use_reentrant = ( |
|
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] |
|
) |
|
|
|
if use_reentrant: |
|
if hasattr(model, "enable_input_require_grads"): |
|
model.enable_input_require_grads() |
|
else: |
|
|
|
def make_inputs_require_grad(module, input, output): |
|
output.requires_grad_(True) |
|
|
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
|
|
|
return model |
|
|
|
def _prepare_dataset( |
|
self, |
|
dataset: Union[Dataset, IterableDataset], |
|
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], |
|
args: SFTConfig, |
|
packing: bool, |
|
formatting_func: Optional[Callable[[dict], str]], |
|
dataset_name: str, |
|
) -> Union[Dataset, IterableDataset]: |
|
|
|
if isinstance(dataset, ConstantLengthDataset): |
|
return dataset |
|
|
|
|
|
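        # If the dataset is already tokenized (contains an `input_ids` column), most processing steps are skipped.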
column_names = list(next(iter(dataset)).keys()) |
|
is_processed = "input_ids" in column_names |
|
|
|
|
|
map_kwargs = {} |
|
if isinstance(dataset, Dataset): |
|
map_kwargs["num_proc"] = args.dataset_num_proc |
|
|
|
with PartialState().main_process_first(): |
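            # Apply the formatting function, if any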
|
|
|
if formatting_func is not None and is_processed: |
|
warnings.warn( |
|
"You passed a dataset that is already processed (contains an `input_ids` field) together with a " |
|
"formatting function. Therefore `formatting_func` will be ignored. Either remove the " |
|
"`formatting_func` or pass a dataset that is not already processed.", |
|
UserWarning, |
|
) |
|
|
|
if formatting_func is not None and not is_processed: |
|
if isinstance(dataset, Dataset): |
|
map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" |
|
|
|
def _func(example): |
|
return {"text": formatting_func(example)} |
|
|
|
try: |
|
dataset = dataset.map(_func, batched=False, **map_kwargs) |
|
except Exception as e: |
|
warnings.warn( |
|
f"Failed to apply the formatting function due to the following error: {e}. This may be " |
|
"because the function is designed for batched input. Please update it to process one example " |
|
"at a time (i.e., accept and return a single example). For now, we will attempt to apply the " |
|
"function in batched mode, but note that batched formatting is deprecated and will be removed " |
|
"in version 0.21.", |
|
DeprecationWarning, |
|
) |
|
dataset = dataset.map(_func, batched=True, **map_kwargs) |
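            # Convert to ChatML, apply the chat template, and tokenize, unless the dataset is already processed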
|
|
|
if not is_processed: |
|
|
|
if isinstance(dataset, Dataset): |
|
map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" |
|
column_names = next(iter(dataset)).keys() |
|
dataset = dataset.map( |
|
maybe_convert_to_chatml, |
|
remove_columns="conversations" if "conversations" in column_names else None, |
|
**map_kwargs, |
|
) |
|
|
|
|
|
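                # Apply the chat template if the dataset is conversational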
first_example = next(iter(dataset)) |
|
if is_conversational(first_example): |
|
if isinstance(dataset, Dataset): |
|
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" |
|
column_names = first_example.keys() |
|
dataset = dataset.map( |
|
apply_chat_template, |
|
fn_kwargs={"tokenizer": processing_class}, |
|
remove_columns="messages" if "messages" in column_names else None, |
|
**map_kwargs, |
|
) |
|
|
|
|
|
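                    # The chat template already adds the special tokens, so skip them during tokenization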
add_special_tokens = False |
|
|
|
|
|
|
|
else: |
|
if isinstance(dataset, Dataset): |
|
map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" |
|
|
|
def add_eos(example, eos_token): |
|
if "text" in example and not example["text"].endswith(eos_token): |
|
example["text"] = example["text"] + eos_token |
|
elif "completion" in example and not example["completion"].endswith(eos_token): |
|
example["completion"] = example["completion"] + eos_token |
|
return example |
|
|
|
dataset = dataset.map( |
|
add_eos, |
|
fn_kwargs={"eos_token": processing_class.eos_token}, |
|
remove_columns="messages" if "messages" in column_names else None, |
|
**map_kwargs, |
|
) |
|
|
|
|
|
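                    # Plain-text datasets still rely on the tokenizer to add the special tokens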
add_special_tokens = True |
|
|
|
|
|
if isinstance(dataset, Dataset): |
|
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" |
|
|
|
def tokenize(example, processing_class, dataset_text_field, add_special_tokens): |
|
if "prompt" in example: |
|
processed_prompt = processing_class( |
|
text=example["prompt"], |
|
add_special_tokens=add_special_tokens, |
|
) |
|
processed = processing_class( |
|
text=example["prompt"] + example["completion"], add_special_tokens=add_special_tokens |
|
) |
|
|
|
|
|
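                        # Verify that the tokenized prompt is a prefix of the tokenized prompt + completion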
prompt_ids = processed_prompt["input_ids"] |
|
prompt_completion_ids = processed["input_ids"] |
|
                        if prompt_completion_ids[: len(prompt_ids)] != prompt_ids:
|
warnings.warn( |
|
"Mismatch between tokenized prompt and the start of tokenized prompt+completion. " |
|
"This may be due to unexpected tokenizer behavior, whitespace issues, or special " |
|
"token handling. Verify that the tokenizer is processing text consistently." |
|
) |
|
|
|
|
|
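                        # Mark completion tokens with 1 so the collator can restrict the loss to the completion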
completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids)) |
|
processed = {**processed, "completion_mask": completion_mask} |
|
|
|
else: |
|
processed = processing_class( |
|
text=example[dataset_text_field], add_special_tokens=add_special_tokens |
|
) |
|
return processed |
|
|
|
dataset = dataset.map( |
|
tokenize, |
|
fn_kwargs={ |
|
"processing_class": processing_class, |
|
"dataset_text_field": args.dataset_text_field, |
|
"add_special_tokens": add_special_tokens, |
|
}, |
|
**map_kwargs, |
|
) |
|
|
|
|
|
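            # Pack or truncate the tokenized sequences to `max_length`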
if packing: |
|
if args.max_length is None: |
|
raise ValueError("When packing is enabled, `max_length` can't be `None`.") |
|
if isinstance(dataset, Dataset): |
|
map_kwargs["desc"] = f"Packing {dataset_name} dataset" |
|
dataset = dataset.select_columns("input_ids") |
|
dataset = pack_dataset(dataset, args.max_length, map_kwargs) |
|
elif args.max_length is not None: |
|
if isinstance(dataset, Dataset): |
|
map_kwargs["desc"] = f"Truncating {dataset_name} dataset" |
|
dataset = truncate_dataset(dataset, args.max_length, map_kwargs) |
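            # For the Liger kernel, keep only `input_ids`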
|
|
|
if args.use_liger_kernel: |
|
dataset = dataset.select_columns("input_ids") |
|
|
|
return dataset |
|
|
|
def _set_signature_columns_if_needed(self): |
|
|
|
|
|
|
|
|
|
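        # Override the parent method so that `completion_mask` (needed by the collator) is not dropped by
        # `remove_unused_columns`.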
if self._signature_columns is None: |
|
self._signature_columns = ["input_ids", "attention_mask", "completion_mask"] |
|
|
|
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): |
|
""" |
|
Compute training loss and additionally compute token accuracies |
|
""" |
|
mode = "eval" if self.control.should_evaluate else "train" |
|
(loss, outputs) = super().compute_loss( |
|
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch |
|
) |
|
if mode == "train": |
|
|
|
|
|
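            # Track the number of tokens seen; padding-free batches carry `position_ids` instead of `attention_mask`.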
if "attention_mask" in inputs: |
|
num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() |
|
elif "position_ids" in inputs: |
|
local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device) |
|
num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item() |
|
else: |
|
raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") |
|
self._total_train_tokens += num_tokens_in_batch |
|
self._metrics[mode]["num_tokens"] = [self._total_train_tokens] |
|
|
|
|
|
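        # Compute token accuracy when labels are available and logits are materialized (not the case with Liger)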
if "labels" in inputs and not self.args.use_liger_kernel: |
|
shift_logits = outputs.logits[..., :-1, :].contiguous() |
|
shift_labels = inputs["labels"][..., 1:].contiguous() |
|
|
|
|
|
predictions = shift_logits.argmax(dim=-1) |
|
|
|
|
|
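            # Only score positions whose label is not the ignore index (-100)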
mask = shift_labels != -100 |
|
|
|
|
|
correct_predictions = (predictions == shift_labels) & mask |
|
total_tokens = mask.sum() |
|
correct_tokens = correct_predictions.sum() |
|
|
|
|
|
correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) |
|
total_tokens = self.accelerator.gather_for_metrics(total_tokens) |
|
|
|
|
|
total_sum = total_tokens.sum() |
|
accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 |
|
self._metrics[mode]["mean_token_accuracy"].append(accuracy) |
|
|
|
return (loss, outputs) if return_outputs else loss |
|
|
|
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: |
|
mode = "eval" if self.control.should_evaluate else "train" |
|
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} |
|
|
|
|
|
|
|
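        # In evaluation, prefix the aggregated metric keys with "eval_" to match the Trainer logging convention.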
if mode == "eval": |
|
metrics = {f"eval_{key}": val for key, val in metrics.items()} |
|
|
|
logs = {**logs, **metrics} |
|
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): |
|
super().log(logs, start_time) |
|
else: |
|
super().log(logs) |
|
self._metrics[mode].clear() |
|
|
|
def create_model_card( |
|
self, |
|
model_name: Optional[str] = None, |
|
dataset_name: Optional[str] = None, |
|
tags: Union[str, list[str], None] = None, |
|
): |
|
""" |
|
Creates a draft of a model card using the information available to the `Trainer`. |
|
|
|
Args: |
|
model_name (`str` or `None`, *optional*, defaults to `None`): |
|
Name of the model. |
|
dataset_name (`str` or `None`, *optional*, defaults to `None`): |
|
Name of the dataset used for training. |
|
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): |
|
Tags to be associated with the model card. |
|
""" |
|
if not self.is_world_process_zero(): |
|
return |
|
|
|
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): |
|
base_model = self.model.config._name_or_path |
|
else: |
|
base_model = None |
|
|
|
tags = tags or [] |
|
if isinstance(tags, str): |
|
tags = [tags] |
|
|
|
if hasattr(self.model.config, "unsloth_version"): |
|
tags.append("unsloth") |
|
|
|
model_card = generate_model_card( |
|
base_model=base_model, |
|
model_name=model_name, |
|
hub_model_id=self.hub_model_id, |
|
dataset_name=dataset_name, |
|
tags=tags, |
|
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, |
|
comet_url=get_comet_experiment_url(), |
|
trainer_name="SFT", |
|
) |
|
|
|
model_card.save(os.path.join(self.args.output_dir, "README.md")) |
|
|