|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import dataclasses |
|
import importlib.resources as pkg_resources |
|
import json |
|
import random |
|
import warnings |
|
from collections import deque |
|
from dataclasses import dataclass, field |
|
from importlib.metadata import version |
|
from typing import Any, Literal, Optional, Union |
|
|
|
import datasets |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import torch.nn.functional as F |
|
import torch.utils.data |
|
from accelerate import Accelerator, PartialState |
|
from accelerate.state import AcceleratorState |
|
from huggingface_hub import ModelCard, ModelCardData |
|
from torch.nn.utils.rnn import pad_sequence |
|
from torch.utils.data import IterableDataset |
|
from transformers import ( |
|
BitsAndBytesConfig, |
|
DataCollatorForLanguageModeling, |
|
EvalPrediction, |
|
GenerationConfig, |
|
PreTrainedTokenizerBase, |
|
TrainerState, |
|
TrainingArguments, |
|
is_comet_available, |
|
) |
|
from transformers.utils import ( |
|
is_peft_available, |
|
is_torch_mlu_available, |
|
is_torch_npu_available, |
|
is_torch_xpu_available, |
|
) |
|
|
|
from ..import_utils import is_rich_available |
|
from ..trainer.model_config import ModelConfig |
|
|
|
|
|
if is_rich_available(): |
|
from rich.console import Console |
|
from rich.panel import Panel |
|
from rich.table import Table |
|
from rich.text import Text |
|
|
|
if is_comet_available(): |
|
import comet_ml |
|
|
|
if is_peft_available(): |
|
from peft import LoraConfig, PeftConfig |
|
|
|
|
|
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): |
|
""" |
|
Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index' |
|
    when they do not come from the assistant. This ensures that the loss is only

    calculated on the completion made by the assistant.
|
|
|
Args: |
|
response_template (`Union[str, list[int]]`): the template form that indicates the start of the response, typically something like |
|
'### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response |
|
differently if it does not have proper context. |
|
instruction_template (`Union[str, list[int]]`): the template form that indicates the start of the human instruction, typically something like |
|
'### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids. |
|
mlm (`bool`, *optional*, defaults to `False`): Whether to use masked language modeling in the underlying |
|
`DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present |
|
for flexibility and backwards-compatibility. |
|
ignore_index (`int`, *optional*, defaults to `-100`): |
|
            The index used to mask the tokens that are not part of the completion when computing the loss.
|
""" |
|
|
|
def __init__( |
|
self, |
|
response_template: Union[str, list[int]], |
|
instruction_template: Optional[Union[str, list[int]]] = None, |
|
*args, |
|
mlm: bool = False, |
|
ignore_index: int = -100, |
|
padding_free: bool = False, |
|
**kwargs, |
|
): |
|
super().__init__(*args, mlm=mlm, **kwargs) |
|
warnings.warn( |
|
"This class is deprecated and will be removed in version 0.20.0. To train on completion only, please use " |
|
"the parameter `completion_only_loss` of `SFTConfig` instead.", |
|
DeprecationWarning, |
|
) |
|
|
|
self.instruction_template = instruction_template |
|
if isinstance(instruction_template, str): |
|
|
|
self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False) |
|
else: |
|
|
|
self.instruction_token_ids = instruction_template |
|
|
|
self.response_template = response_template |
|
if isinstance(response_template, str): |
|
|
|
self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False) |
|
else: |
|
|
|
self.response_token_ids = response_template |
|
|
|
if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: |
|
warnings.warn( |
|
"The pad_token_id and eos_token_id values of this tokenizer are identical. " |
|
"If you are planning for multi-turn training, " |
|
"it can result in the model continuously generating questions and answers without eos token. " |
|
"To avoid this, set the pad_token_id to a different value.", |
|
UserWarning, |
|
) |
|
|
|
self.ignore_index = ignore_index |
|
self.padding_free = padding_free |
|
|
|
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: |
|
batch = super().torch_call(examples) |
|
|
|
if self.instruction_template is None: |
|
for i in range(len(examples)): |
|
response_token_ids_start_idx = None |
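                # Locate the response template: match on its first token, then verify the full template at that position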
|
|
|
for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: |
|
|
|
if ( |
|
self.response_token_ids |
|
== batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist() |
|
): |
|
response_token_ids_start_idx = idx |
|
|
|
if response_token_ids_start_idx is None: |
|
warnings.warn( |
|
f"Could not find response key `{self.response_template}` in the following instance: " |
|
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " |
|
"calculation. Note, if this happens often, consider increasing the `max_length`.", |
|
UserWarning, |
|
) |
|
batch["labels"][i, :] = self.ignore_index |
|
else: |
|
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids) |
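                    # Make the loss ignore everything up to and including the end of the response template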
|
|
|
|
|
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index |
|
|
|
else: |
|
for i in range(len(examples)): |
|
response_token_ids_idxs = [] |
|
human_token_ids_idxs = [] |
|
|
|
for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: |
|
|
|
if ( |
|
self.response_token_ids |
|
== batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist() |
|
): |
|
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids)) |
|
|
|
if len(response_token_ids_idxs) == 0: |
|
warnings.warn( |
|
f"Could not find response key `{self.response_template}` in the following instance: " |
|
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " |
|
"calculation. Note, if this happens often, consider increasing the `max_length`.", |
|
UserWarning, |
|
) |
|
batch["labels"][i, :] = self.ignore_index |
|
|
|
human_token_ids = self.instruction_token_ids |
|
for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]: |
|
|
|
if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist(): |
|
human_token_ids_idxs.append(human_idx) |
|
|
|
if len(human_token_ids_idxs) == 0: |
|
warnings.warn( |
|
f"Could not find instruction key `{self.instruction_template}` in the following instance: " |
|
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " |
|
"calculation. Note, if this happens often, consider increasing the `max_length`.", |
|
UserWarning, |
|
) |
|
batch["labels"][i, :] = self.ignore_index |
|
|
|
if ( |
|
len(human_token_ids_idxs) > 0 |
|
and len(response_token_ids_idxs) > 0 |
|
and human_token_ids_idxs[0] > response_token_ids_idxs[0] |
|
): |
|
human_token_ids_idxs = [0] + human_token_ids_idxs |
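                # Mask each span from the start of a human turn up to the start of the next assistant completion
                # (for the first turn, mask from the beginning of the sequence)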
|
|
|
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)): |
|
|
|
if idx != 0: |
|
batch["labels"][i, start:end] = self.ignore_index |
|
else: |
|
batch["labels"][i, :end] = self.ignore_index |
|
|
|
if len(response_token_ids_idxs) < len(human_token_ids_idxs): |
|
batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index |
|
|
|
if self.padding_free: |
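            # Padding-free mode: drop the attention mask and flatten the batch into a single packed row,
            # using position_ids (restarting at 0) to mark where each sequence begins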
|
|
|
attn_mask = batch.pop("attention_mask") |
|
batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0) |
|
batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1 |
|
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0) |
|
batch["labels"][batch["position_ids"] == 0] = self.ignore_index |
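            # Build cumulative sequence lengths (cu_seq_lens) from the packed position_ids so that
            # variable-length attention kernels can recover the per-sequence boundaries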
|
|
|
|
|
flattened_position_ids = batch["position_ids"].flatten() |
|
indices_q = torch.arange( |
|
flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32 |
|
) |
|
batch["cu_seq_lens_q"] = torch.cat( |
|
( |
|
indices_q[flattened_position_ids == 0], |
|
torch.tensor( |
|
flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32 |
|
), |
|
) |
|
).unsqueeze(0) |
|
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"] |
|
|
|
|
|
batch["max_length_k"] = torch.tensor([flattened_position_ids.max().item() + 1]) |
|
batch["max_length_q"] = batch["max_length_k"] |
|
|
|
return batch |
|
|
|
|
|
@dataclass |
|
class DataCollatorForChatML: |
|
""" |
|
    Data collator for ChatML format datasets. Tokenizes the conversations with the chat template, left-pads the

    batch, and builds `labels` in which the prompt tokens are masked with `ignore_index` so that only the completion

    contributes to the loss.
|
""" |
|
|
|
tokenizer: PreTrainedTokenizerBase |
|
ignore_index: int = -100 |
|
    max_length: Optional[int] = None
|
prompt_key: str = "prompt" |
|
messages_key: str = "messages" |
|
|
|
def __post_init__(self): |
|
if self.tokenizer.pad_token_id is None: |
|
raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.") |
|
if self.max_length is None: |
|
|
|
self.max_length = min(self.tokenizer.model_max_length, 1024) |
|
|
|
def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: |
|
input_ids = [] |
|
attention_mask = [] |
|
prompts_input_ids = [] |
|
prompt_attention_mask = [] |
|
labels = [] |
|
|
|
for example in examples: |
|
formatted_prompt = example.get(self.prompt_key, None) |
|
if formatted_prompt is None: |
|
prompt = example[self.messages_key][:-1] |
|
formatted_prompt = self.tokenizer.apply_chat_template( |
|
prompt, tokenize=False, add_generation_prompt=True |
|
) |
|
|
|
if "input_ids" not in example: |
|
message = example[self.messages_key] |
|
formatted_message = self.tokenizer.apply_chat_template( |
|
message, tokenize=False, add_generation_prompt=False |
|
) |
|
tokenized_message = self.tokenizer( |
|
formatted_message, |
|
truncation=True, |
|
max_length=self.max_length, |
|
padding=False, |
|
return_tensors=None, |
|
add_special_tokens=False, |
|
) |
|
input_ids.append(tokenized_message["input_ids"]) |
|
attention_mask.append(tokenized_message["attention_mask"]) |
|
else: |
|
input_ids.append(example["input_ids"]) |
|
attention_mask.append(example["attention_mask"]) |
|
|
|
tokenized_prompt = self.tokenizer( |
|
formatted_prompt, |
|
truncation=True, |
|
max_length=len(input_ids[-1]), |
|
padding=False, |
|
return_tensors=None, |
|
add_special_tokens=False, |
|
) |
|
|
|
prompts_input_ids.append(tokenized_prompt["input_ids"]) |
|
prompt_attention_mask.append(tokenized_prompt["attention_mask"]) |
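            # Build labels: mask every prompt token with ignore_index so that only the completion contributes to the loss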
|
|
|
|
|
label = [self.ignore_index] * len(input_ids[-1]) |
|
completion_start_idx = len(tokenized_prompt["input_ids"]) |
|
label[completion_start_idx:] = input_ids[-1][completion_start_idx:] |
|
labels.append(label) |
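        # Convert everything to tensors and left-pad to the longest sequence in the batch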
|
|
|
|
|
input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids] |
|
attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask] |
|
labels = [torch.tensor(label, dtype=torch.long) for label in labels] |
|
input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) |
|
attention_mask = pad(attention_mask, padding_side="left", padding_value=0) |
|
labels = pad(labels, padding_side="left", padding_value=self.ignore_index) |
|
|
|
prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids] |
|
prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask] |
|
prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) |
|
prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0) |
|
|
|
return { |
|
"input_ids": input_ids, |
|
"attention_mask": attention_mask, |
|
"labels": labels, |
|
"prompts": prompts_input_ids, |
|
"prompt_attention_mask": prompt_attention_mask, |
|
} |
|
|
|
|
|
@dataclass |
|
class RewardDataCollatorWithPadding: |
|
r""" |
|
Reward DataCollator class that pads the inputs to the maximum length of the batch. |
|
|
|
Args: |
|
tokenizer (`PreTrainedTokenizerBase`): |
|
The tokenizer used for encoding the data. |
|
        padding (`bool`, `str` or `PaddingStrategy`, *optional*, defaults to `True`):

            Padding strategy to pass to the tokenizer.

        pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):

            If set, pads the sequence to a multiple of the provided value.

        return_tensors (`str`, *optional*, defaults to `"pt"`):
|
The tensor type to use. |
|
""" |
|
|
|
tokenizer: PreTrainedTokenizerBase |
|
padding: Union[bool, str] = True |
|
pad_to_multiple_of: Optional[int] = None |
|
return_tensors: str = "pt" |
|
|
|
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: |
|
features_chosen = [] |
|
features_rejected = [] |
|
margin = [] |
|
|
|
has_margin = "margin" in features[0] |
|
for feature in features: |
|
|
|
if ( |
|
"input_ids_chosen" not in feature |
|
or "input_ids_rejected" not in feature |
|
or "attention_mask_chosen" not in feature |
|
or "attention_mask_rejected" not in feature |
|
): |
|
raise ValueError( |
|
"The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`" |
|
) |
|
|
|
features_chosen.append( |
|
{ |
|
"input_ids": feature["input_ids_chosen"], |
|
"attention_mask": feature["attention_mask_chosen"], |
|
} |
|
) |
|
features_rejected.append( |
|
{ |
|
"input_ids": feature["input_ids_rejected"], |
|
"attention_mask": feature["attention_mask_rejected"], |
|
} |
|
) |
|
if has_margin: |
|
margin.append(feature["margin"]) |
|
batch_chosen = self.tokenizer.pad( |
|
features_chosen, |
|
padding=self.padding, |
|
pad_to_multiple_of=self.pad_to_multiple_of, |
|
return_tensors=self.return_tensors, |
|
) |
|
batch_rejected = self.tokenizer.pad( |
|
features_rejected, |
|
padding=self.padding, |
|
pad_to_multiple_of=self.pad_to_multiple_of, |
|
return_tensors=self.return_tensors, |
|
) |
|
batch = { |
|
"input_ids_chosen": batch_chosen["input_ids"], |
|
"attention_mask_chosen": batch_chosen["attention_mask"], |
|
"input_ids_rejected": batch_rejected["input_ids"], |
|
"attention_mask_rejected": batch_rejected["attention_mask"], |
|
"return_loss": True, |
|
} |
|
if has_margin: |
|
margin = torch.tensor(margin, dtype=torch.float) |
|
batch["margin"] = margin |
|
return batch |
|
|
|
|
|
def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str = "right") -> torch.Tensor: |
|
""" |
|
Pads a list of tensors to the same shape along the first dimension. |
|
|
|
Args: |
|
tensors (`list[torch.Tensor]`): |
|
List of input tensors to pad. |
|
padding_value (`int`): |
|
Value to use for padding. Default is 0. |
|
padding_side (`str`): |
|
Side on which to add padding. Must be 'left' or 'right'. Default is 'right'. |
|
|
|
Returns: |
|
`torch.Tensor`: |
|
A single tensor containing the padded tensors. |
|
|
|
Examples: |
|
>>> import torch |
|
>>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]) |
|
tensor([[1, 2, 3], |
|
[4, 5, 0]]) |
|
>>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])]) |
|
tensor([[[1, 2], |
|
[3, 4]], |
|
|
|
[[5, 6], |
|
[0, 0]]]) |
|
""" |
|
|
|
output_shape = np.max([t.shape for t in tensors], 0).tolist() |
|
|
|
|
|
output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device) |
|
|
|
for i, t in enumerate(tensors): |
|
|
|
if padding_side == "left": |
|
seq_slice = slice(output_shape[0] - t.shape[0], output_shape[0]) |
|
elif padding_side == "right": |
|
seq_slice = slice(0, t.shape[0]) |
|
else: |
|
raise ValueError("padding_side must be 'left' or 'right'") |
|
|
|
slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:]) |
|
output[i][slices] = t |
|
|
|
return output |
|
|
|
|
|
@dataclass |
|
class DPODataCollatorWithPadding: |
|
r""" |
|
DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. |
|
|
|
Args: |
|
        pad_token_id (`int`, *optional*, defaults to `0`):

            The tokenizer's pad_token_id.

        label_pad_token_id (`int`, *optional*, defaults to `-100`):

            The label used for masking.

        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `False`):

            Whether your model has an encoder-decoder architecture.
|
""" |
|
|
|
pad_token_id: int = 0 |
|
label_pad_token_id: int = -100 |
|
is_encoder_decoder: Optional[bool] = False |
|
|
|
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: |
|
|
|
padded_batch = {} |
|
for k in features[0].keys(): |
|
if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")): |
|
if self.is_encoder_decoder: |
|
to_pad = [torch.LongTensor(ex[k]) for ex in features] |
|
|
|
if (k.startswith("prompt")) and (k.endswith("input_ids")): |
|
if self.pad_token_id is None: |
|
raise ValueError( |
|
"Padding is enabled, but the tokenizer is not configured with a padding token." |
|
" Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" |
|
" before calling the trainer." |
|
) |
|
padding_value = self.pad_token_id |
|
elif k.endswith("_attention_mask"): |
|
padding_value = 0 |
|
elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k): |
|
padding_value = self.label_pad_token_id |
|
else: |
|
raise ValueError(f"Unexpected key in batch '{k}'") |
|
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) |
|
else: |
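                    # Decoder-only path: choose the padding value based on what the key stores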
|
|
|
if k.endswith("_input_ids"): |
|
if self.pad_token_id is None: |
|
raise ValueError( |
|
"Padding is enabled, but the tokenizer is not configured with a padding token." |
|
" Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" |
|
" before calling the trainer." |
|
) |
|
padding_value = self.pad_token_id |
|
elif k.endswith("_labels"): |
|
padding_value = self.label_pad_token_id |
|
elif k.endswith("_attention_mask"): |
|
padding_value = 0 |
|
elif k.endswith("_pixel_values"): |
|
padding_value = 0 |
|
else: |
|
raise ValueError(f"Unexpected key in batch '{k}'") |
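                    # Prompt tensors are left-padded (so the prompt ends right before the completion); everything else is right-padded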
|
|
|
|
|
if k in ["prompt_input_ids", "prompt_attention_mask"]: |
|
padding_side = "left" |
|
else: |
|
padding_side = "right" |
|
|
|
|
|
if k.endswith("_pixel_values"): |
|
dtype = torch.float32 |
|
else: |
|
dtype = torch.int64 |
|
|
|
|
|
to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features] |
|
padded_batch[k] = pad(to_pad, padding_value=padding_value, padding_side=padding_side) |
|
elif k.endswith("_logps"): |
|
|
|
padded_batch[k] = torch.tensor([ex[k] for ex in features]) |
|
else: |
|
padded_batch[k] = [ex[k] for ex in features] |
|
|
|
return padded_batch |
|
|
|
|
|
class ConstantLengthDataset(IterableDataset): |
|
""" |
|
    Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
|
The dataset also formats the text before tokenization with a specific format that is provided |
|
by the user. |
|
|
|
Args: |
|
tokenizer (`transformers.PreTrainedTokenizer`): |
|
The processor used for processing the data. |
|
dataset (`dataset.Dataset`): |
|
Dataset with text files. |
|
dataset_text_field (`str` or `None`, *optional*, defaults to `None`): |
|
Name of the field in the dataset that contains the text. Only one of `dataset_text_field` and |
|
`formatting_func` should be provided. |
|
formatting_func (`Callable`, *optional*): |
|
Function that formats the text before tokenization. Usually it is recommended to follow a certain |
|
pattern such as `"### Question: {question} ### Answer: {answer}"`. Only one of `dataset_text_field` and |
|
`formatting_func` should be provided. |
|
infinite (`bool`, *optional*, defaults to `False`): |
|
            If `True`, the iterator restarts from the beginning of the dataset when the end is reached; otherwise, iteration stops.
|
seq_length (`int`, *optional*, defaults to `1024`): |
|
Length of token sequences to return. |
|
num_of_sequences (`int`, *optional*, defaults to `1024`): |
|
Number of token sequences to keep in buffer. |
|
        chars_per_token (`float`, *optional*, defaults to `3.6`):
|
Number of characters per token used to estimate number of tokens in text buffer. |
|
eos_token_id (`int`, *optional*, defaults to `0`): |
|
            ID of the end-of-sequence token to use if the passed tokenizer does not have an EOS token.
|
        shuffle (`bool`, *optional*, defaults to `True`):

            Whether to shuffle the examples before they are returned.

        append_concat_token (`bool`, *optional*, defaults to `True`):

            If `True`, appends `eos_token_id` at the end of each sample being packed.

        add_special_tokens (`bool`, *optional*, defaults to `True`):

            If `True`, the tokenizer adds special tokens to each sample being packed.
|
""" |
|
|
|
def __init__( |
|
self, |
|
tokenizer, |
|
dataset, |
|
dataset_text_field=None, |
|
formatting_func=None, |
|
infinite=False, |
|
seq_length=1024, |
|
num_of_sequences=1024, |
|
chars_per_token=3.6, |
|
eos_token_id=0, |
|
shuffle=True, |
|
append_concat_token=True, |
|
add_special_tokens=True, |
|
): |
|
warnings.warn( |
|
"This class is deprecated and will be removed in version 0.20.0. To use packing, use the argument " |
|
"`packing` of `SFTConfig` instead.", |
|
DeprecationWarning, |
|
) |
|
self.tokenizer = tokenizer |
|
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id |
|
self.dataset = dataset |
|
self.seq_length = seq_length |
|
self.infinite = infinite |
|
self.current_size = 0 |
|
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences |
|
self.shuffle = shuffle |
|
self.append_concat_token = append_concat_token |
|
self.add_special_tokens = add_special_tokens |
|
|
|
if dataset_text_field is not None and formatting_func is not None: |
|
warnings.warn( |
|
"Only one of `dataset_text_field` and `formatting_func` should be provided. " |
|
"Ignoring `dataset_text_field` and using `formatting_func`.", |
|
UserWarning, |
|
) |
|
|
|
if formatting_func is not None: |
|
self.formatting_func = formatting_func |
|
elif dataset_text_field is not None: |
|
self.formatting_func = lambda x: x[dataset_text_field] |
|
else: |
|
raise ValueError("Either `dataset_text_field` or `formatting_func` should be provided.") |
|
|
|
self.pretokenized = False |
|
column_names = ( |
|
dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None |
|
) |
|
if column_names is not None and "input_ids" in column_names: |
|
self.pretokenized = True |
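        # For pre-tokenized datasets the buffer stores token ids, so the buffer size is measured in tokens
        # rather than in estimated characters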
|
|
|
self.max_buffer_size = seq_length * num_of_sequences |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __iter__(self): |
|
iterator = iter(self.dataset) |
|
more_examples = True |
|
while more_examples: |
|
buffer, buffer_len = [], 0 |
|
while True: |
|
if buffer_len >= self.max_buffer_size: |
|
break |
|
try: |
|
buffer.append(self.formatting_func(next(iterator))) |
|
buffer_len += len(buffer[-1]) |
|
except StopIteration: |
|
if self.infinite: |
|
iterator = iter(self.dataset) |
|
else: |
|
more_examples = False |
|
break |
|
if self.shuffle: |
|
random.shuffle(buffer) |
|
if self.pretokenized: |
|
tokenized_inputs = buffer |
|
else: |
|
tokenized_inputs = self.tokenizer( |
|
buffer, add_special_tokens=self.add_special_tokens, truncation=False |
|
)["input_ids"] |
|
all_token_ids = [] |
|
for tokenized_input in tokenized_inputs: |
|
if self.append_concat_token: |
|
tokenized_input = tokenized_input + [self.concat_token_id] |
|
all_token_ids.extend(tokenized_input) |
|
examples = [] |
|
for i in range(0, len(all_token_ids), self.seq_length): |
|
input_ids = all_token_ids[i : i + self.seq_length] |
|
if len(input_ids) == self.seq_length: |
|
examples.append(input_ids) |
|
if self.shuffle: |
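                # Shuffle the fixed-length chunks so that pieces of the same document are not yielded back-to-back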
|
|
|
random.shuffle(examples) |
|
for example in examples: |
|
self.current_size += 1 |
|
yield { |
|
"input_ids": torch.LongTensor(example), |
|
"labels": torch.LongTensor(example), |
|
} |
|
|
|
|
|
@dataclass |
|
class RunningMoments: |
|
""" |
|
Calculates the running mean and standard deviation of a data stream. Reference: |
|
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75 |
|
""" |
|
|
|
accelerator: Accelerator |
|
mean: float = 0 |
|
std: float = 1 |
|
var: float = 1 |
|
count: float = 1e-24 |
|
|
|
@torch.no_grad() |
|
def update(self, xs: torch.Tensor) -> tuple[float, float]: |
|
""" |
|
        Updates the running moments from the batch's moments computed across ranks.
|
""" |
|
if self.accelerator.use_distributed: |
|
xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs) |
|
else: |
|
xs_count = xs.numel() |
|
xs_var, xs_mean = torch.var_mean(xs, unbiased=False) |
|
xs_mean, xs_var = xs_mean.float(), xs_var.float() |
|
|
|
delta = xs_mean - self.mean |
|
tot_count = self.count + xs_count |
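        # Combine the running statistics with the batch statistics using the standard parallel mean/variance update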
|
|
|
new_sum = xs_var * xs_count |
|
|
|
old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count |
|
tot_sum = old_sum + new_sum |
|
|
|
self.mean += (delta * xs_count / tot_count).item() |
|
new_var = tot_sum / tot_count |
|
self.std = (new_var * tot_count / (tot_count - 1)).float().sqrt().item() |
|
self.var = new_var.item() |
|
self.count = tot_count |
|
|
|
return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item() |
|
|
|
def save_to_json(self, json_path: str): |
|
"""Save the content of this instance in JSON format inside `json_path`.""" |
|
|
|
if self.accelerator.is_main_process: |
|
save_dict = dataclasses.asdict(self, dict_factory=lambda x: {k: v for (k, v) in x if k != "accelerator"}) |
|
json_string = json.dumps(save_dict, indent=2, sort_keys=True) + "\n" |
|
with open(json_path, "w", encoding="utf-8") as f: |
|
f.write(json_string) |
|
|
|
@classmethod |
|
def load_from_json(cls, accelerator: Accelerator, json_path: str): |
|
"""Create an instance from the content of `json_path`.""" |
|
|
|
with open(json_path, encoding="utf-8") as f: |
|
text = f.read() |
|
return cls(accelerator=accelerator, **json.loads(text)) |
|
|
|
|
|
@torch.no_grad() |
|
def get_global_statistics( |
|
accelerator, xs: torch.Tensor, mask=None, device="cpu" |
|
) -> tuple[torch.Tensor, torch.Tensor, int]: |
|
""" |
|
Computes element-wise mean and variance of the tensor across processes. Reference: |
|
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 |
|
""" |
|
xs = xs.to(accelerator.device) |
|
sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device) |
|
sum_and_count = accelerator.reduce(sum_and_count) |
|
global_sum, count = sum_and_count |
|
global_mean = global_sum / count |
|
|
|
sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask)) |
|
sum_var = accelerator.reduce(sum_var) |
|
global_var = sum_var / count |
|
|
|
return global_mean.to(device), global_var.to(device), count.item() |
|
|
|
|
|
def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]: |
|
predictions, labels = eval_pred |
|
if predictions.ndim == 3: |
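        # Token-classification shaped predictions (batch, seq_len, num_labels): take the argmax over labels and
        # drop every position labelled -100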
|
|
|
|
|
predictions = np.argmax(predictions, axis=2) |
|
|
|
|
|
predictions = np.array( |
|
[p for prediction, label in zip(predictions, labels) for (p, lbl) in zip(prediction, label) if lbl != -100] |
|
) |
|
labels = np.array([lbl for label in labels for lbl in label if lbl != -100]) |
|
|
|
else: |
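        # Pairwise preference case: each row holds the scores of the two candidates; ties are dropped before
        # computing accuracy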
|
|
|
|
|
equal_mask = predictions[:, 0] == predictions[:, 1] |
|
equal_predictions_count = int(equal_mask.sum()) |
|
|
|
if equal_predictions_count > 0: |
|
warnings.warn( |
|
f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions " |
|
"for both options are equal. These instances are ignored in the accuracy computation.", |
|
UserWarning, |
|
) |
|
|
|
|
|
predictions = predictions[~equal_mask] |
|
labels = labels[~equal_mask] |
|
|
|
|
|
predictions = np.argmax(predictions, axis=1) |
|
|
|
accuracy = np.array(predictions == labels, dtype=float).mean().item() |
|
return {"accuracy": accuracy} |
|
|
|
|
|
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: |
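    """
    Pads `tensor` along dimension `dim` with `pad_value` until it reaches `length`. If the tensor is already at
    least `length` long, it is returned unchanged.

    Example:
        >>> import torch
        >>> pad_to_length(torch.tensor([[1, 2], [3, 4]]), length=4, pad_value=0)
        tensor([[1, 2, 0, 0],
                [3, 4, 0, 0]])
    """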
|
if tensor.size(dim) >= length: |
|
return tensor |
|
else: |
|
pad_size = list(tensor.shape) |
|
pad_size[dim] = length - tensor.size(dim) |
|
return torch.cat( |
|
[ |
|
tensor, |
|
pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), |
|
], |
|
dim=dim, |
|
) |
|
|
|
|
|
def disable_dropout_in_model(model: torch.nn.Module) -> None: |
|
for module in model.modules(): |
|
if isinstance(module, torch.nn.Dropout): |
|
module.p = 0 |
|
|
|
|
|
def exact_div(a, b, custom_error_message=""): |
|
q = a // b |
|
if a != q * b: |
|
raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}") |
|
return q |
|
|
|
|
|
|
|
class PerPromptStatTracker: |
|
r""" |
|
    Class for tracking statistics per prompt. Mainly used to calculate advantages for the DDPO algorithm.
|
|
|
Args: |
|
buffer_size (`int`): |
|
Size of the buffer to keep for each prompt. |
|
min_count (`int`): |
|
Minimum number of samples to keep in the buffer before calculating the mean and std. |
|
""" |
|
|
|
def __init__(self, buffer_size, min_count): |
|
self.buffer_size = buffer_size |
|
self.min_count = min_count |
|
self.stats = {} |
|
|
|
def update(self, prompts, rewards): |
|
prompts = np.array(prompts) |
|
rewards = np.array(rewards) |
|
unique = np.unique(prompts) |
|
advantages = np.empty_like(rewards) |
|
for prompt in unique: |
|
prompt_rewards = rewards[prompts == prompt] |
|
if prompt not in self.stats: |
|
self.stats[prompt] = deque(maxlen=self.buffer_size) |
|
self.stats[prompt].extend(prompt_rewards) |
|
|
|
if len(self.stats[prompt]) < self.min_count: |
|
mean = np.mean(rewards) |
|
std = np.std(rewards) + 1e-6 |
|
else: |
|
mean = np.mean(self.stats[prompt]) |
|
std = np.std(self.stats[prompt]) + 1e-6 |
|
advantages[prompts == prompt] = (prompt_rewards - mean) / std |
|
|
|
return advantages |
|
|
|
def get_stats(self): |
|
return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()} |
|
|
|
|
|
def peft_module_casting_to_bf16(model): |
|
for name, module in model.named_modules(): |
|
if isinstance(module, torch.nn.LayerNorm) or "norm" in name: |
|
module = module.to(torch.float32) |
|
elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): |
|
if hasattr(module, "weight"): |
|
if module.weight.dtype == torch.float32: |
|
module = module.to(torch.bfloat16) |
|
|
|
|
|
def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]: |
|
if model_args.load_in_4bit: |
|
quantization_config = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_compute_dtype=model_args.torch_dtype, |
|
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, |
|
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, |
|
bnb_4bit_quant_storage=model_args.torch_dtype, |
|
) |
|
elif model_args.load_in_8bit: |
|
quantization_config = BitsAndBytesConfig( |
|
load_in_8bit=True, |
|
) |
|
else: |
|
quantization_config = None |
|
|
|
return quantization_config |
|
|
|
|
|
def get_kbit_device_map() -> Optional[dict[str, int]]: |
|
if is_torch_xpu_available(): |
|
return {"": f"xpu:{PartialState().local_process_index}"} |
|
elif torch.cuda.is_available(): |
|
return {"": PartialState().local_process_index} |
|
else: |
|
return None |
|
|
|
|
|
def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]": |
|
if model_args.use_peft is False: |
|
return None |
|
|
|
if not is_peft_available(): |
|
raise ValueError( |
|
"You need to have PEFT library installed in your environment, make sure to install `peft`. " |
|
"Make sure to run `pip install -U peft`." |
|
) |
|
|
|
peft_config = LoraConfig( |
|
task_type=model_args.lora_task_type, |
|
r=model_args.lora_r, |
|
target_modules=model_args.lora_target_modules, |
|
lora_alpha=model_args.lora_alpha, |
|
lora_dropout=model_args.lora_dropout, |
|
bias="none", |
|
use_rslora=model_args.use_rslora, |
|
use_dora=model_args.use_dora, |
|
modules_to_save=model_args.lora_modules_to_save, |
|
) |
|
|
|
return peft_config |
|
|
|
|
|
def get_exp_cap(value, decimal=4): |
|
""" |
|
Get the exponent cap of a value. This is used to cap the exponent of a value to avoid overflow. |
|
    The formula is: `log(value.dtype.max)`.

    E.g. for the float32 data type, the maximum exponent value is 88.7228 to 4 decimal points.
|
|
|
Args: |
|
value (`torch.Tensor`): |
|
The input tensor to obtain the data type |
|
decimal (`int`): |
|
The number of decimal points of the output exponent cap. |
|
            E.g. directly calling `exp(log(torch.float32.max))` results in `inf`,

            so we cap the exponent at 88.7228 to avoid overflow.
|
""" |
|
vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max |
|
vdtype_log_max = torch.log(vdtype_max).to(value.device) |
|
return torch.floor(vdtype_log_max * 10**decimal) / 10**decimal if decimal > 0 else vdtype_log_max |
|
|
|
|
|
def cap_exp(value, cap=-1): |
|
|
|
cap = get_exp_cap(value) if cap < 0 else cap |
|
return torch.exp(torch.clamp(value, max=cap)) |
|
|
|
|
|
def print_rich_table(df: pd.DataFrame) -> None:
|
console = Console() |
|
table = Table(show_lines=True) |
|
for column in df.columns: |
|
table.add_column(column) |
|
for _, row in df.iterrows(): |
|
table.add_row(*row.astype(str).tolist()) |
|
console.print(table) |
|
|
|
|
|
SIMPLE_SFT_CHAT_TEMPLATE = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}{{ eos_token }}" |
|
|
|
|
|
SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" |
|
|
|
|
|
@dataclass |
|
class OnlineTrainerState(TrainerState): |
|
episode: int = 0 |
|
|
|
|
|
@dataclass |
|
class OnPolicyConfig(TrainingArguments): |
|
r""" |
|
Base configuration class for on-policy trainers. |
|
|
|
Using [`~transformers.HfArgumentParser`] we can turn this class into |
|
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the |
|
command line. |
|
|
|
Parameters: |
|
run_name (`str` or `None`, *optional*, defaults to `None`): |
|
Name of the run. |
|
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): |
|
Number of processes to use for processing the dataset. |
|
num_mini_batches (`int`, *optional*, defaults to `1`): |
|
Number of minibatches to split a batch into. |
|
total_episodes (`int` or `None`, *optional*, defaults to `None`): |
|
Total number of episodes in the dataset. |
|
local_rollout_forward_batch_size (`int`, *optional*, defaults to `64`): |
|
            Per-rank batch size for the no-grad forward passes in the rollout phase.
|
num_sample_generations (`int`, *optional*, defaults to `10`): |
|
Number of debugging samples generations (i.e., `generate_completions` calls) throughout training. |
|
response_length (`int`, *optional*, defaults to `53`): |
|
Length of the response. |
|
stop_token (`str` or `None`, *optional*, defaults to `None`): |
|
Specifies the stop token to use for text generation. This parameter is mutually exclusive with |
|
`stop_token_id`. |
|
|
|
- `None`: No stop token is applied, unless `stop_token_id` is specified. |
|
- `'eos'`: Uses the tokenizer's `eos_token`. |
|
|
|
stop_token_id (`int` or `None`, *optional*, defaults to `None`): |
|
Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, |
|
unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. |
|
temperature (`float`, *optional*, defaults to `0.7`): |
|
Sampling temperature. |
|
missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): |
|
Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage |
|
            the model to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
|
value. |
|
sft_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): |
|
Path to the SFT model. |
|
world_size (`int` or `None`, *optional*, defaults to `None`): |
|
Number of processes (GPUs) to use for the training. |
|
num_total_batches (`int` or `None`, *optional*, defaults to `None`): |
|
Number of total batches to train. |
|
micro_batch_size (`int` or `None`, *optional*, defaults to `None`): |
|
Micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`). |
|
local_batch_size (`int` or `None`, *optional*, defaults to `None`): |
|
Batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`). |
|
batch_size (`int` or `None`, *optional*, defaults to `None`): |
|
Batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`). |
|
local_mini_batch_size (`int` or `None`, *optional*, defaults to `None`): |
|
Mini batch size per GPU. |
|
mini_batch_size (`int` or `None`, *optional*, defaults to `None`): |
|
Mini batch size across GPUs. |
|
push_to_hub (`bool`, *optional*, defaults to `False`): |
|
Whether to push the model to the Hub after training. |
|
""" |
|
|
|
run_name: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "Name of the run."}, |
|
) |
|
dataset_num_proc: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Number of processes to use for processing the dataset."}, |
|
) |
|
num_mini_batches: int = field( |
|
default=1, |
|
metadata={"help": "Number of minibatches to split a batch into."}, |
|
) |
|
total_episodes: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Total number of episodes in the dataset."}, |
|
) |
|
local_rollout_forward_batch_size: int = field( |
|
default=64, |
|
        metadata={"help": "Per-rank batch size for the no-grad forward passes in the rollout phase."},
|
) |
|
num_sample_generations: int = field( |
|
default=10, |
|
metadata={ |
|
"help": "Number of debugging samples generations (i.e., `generate_completions` calls) throughout training." |
|
}, |
|
) |
|
response_length: int = field( |
|
default=53, |
|
metadata={"help": "Length of the response."}, |
|
) |
|
stop_token: Optional[Literal["eos"]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " |
|
"`stop_token_id`." |
|
}, |
|
) |
|
stop_token_id: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " |
|
"applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." |
|
}, |
|
) |
|
temperature: float = field( |
|
default=0.7, |
|
metadata={"help": "Sampling temperature."}, |
|
) |
|
missing_eos_penalty: Optional[float] = field( |
|
default=None, |
|
        metadata={

            "help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to "

            "encourage the model to generate completions shorter than the maximum length (`max_new_tokens`). The "

            "penalty must be a positive value."
|
}, |
|
) |
|
sft_model_path: str = field( |
|
default="EleutherAI/pythia-160m", |
|
metadata={"help": "Path to the SFT model."}, |
|
) |
|
world_size: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Number of processes (GPUs) to use for the training."}, |
|
) |
|
num_total_batches: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Number of total batches to train."}, |
|
) |
|
micro_batch_size: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)."}, |
|
) |
|
local_batch_size: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)."}, |
|
) |
|
batch_size: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Batch size across devices (HF's `per_device_train_batch_size` * `world_size` * " |
|
"`gradient_accumulation_steps`)." |
|
}, |
|
) |
|
local_mini_batch_size: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Mini batch size per GPU."}, |
|
) |
|
mini_batch_size: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Mini batch size across GPUs."}, |
|
) |
|
push_to_hub: bool = field( |
|
default=False, |
|
metadata={"help": "Whether to push the model to the Hub after training."}, |
|
) |
|
|
|
|
|
def first_true_indices(bools: torch.Tensor, dtype=torch.long): |
|
""" |
|
Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving |
|
the position of the first True in each "row". |
|
|
|
Returns the length of the rows (bools.size(-1)) if no element is True in a given row. |
|
|
|
Args: |
|
bools (`torch.Tensor`): |
|
An N-dimensional boolean tensor. |
|
dtype (`torch.dtype`, optional): |
|
The desired data type of the output tensor. Defaults to `torch.long`. |
|
|
|
Returns: |
|
`torch.Tensor`: |
|
An (N-1)-dimensional tensor of integers indicating the position of the first True |
|
in each row. If no True value is found in a row, returns the length of the row. |
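
    Example:
        >>> import torch
        >>> first_true_indices(torch.tensor([[False, True, False], [False, False, False]]))
        tensor([1, 3])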
|
""" |
|
row_len = bools.size(-1) |
|
zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) |
|
return torch.min(zero_or_index, dim=-1).values |
|
|
|
|
|
def get_reward( |
|
model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Computes the reward logits and the rewards for a given model and query responses. |
|
|
|
Args: |
|
model (`torch.nn.Module`): |
|
The model used to compute the reward logits. |
|
query_responses (`torch.Tensor`): |
|
The tensor containing the query responses. |
|
pad_token_id (`int`): |
|
The token ID representing the pad token. |
|
context_length (`int`): |
|
The length of the context in the query responses. |
|
|
|
Returns: |
|
tuple: |
|
- `reward_logits` (`torch.Tensor`): |
|
The logits for the reward model. |
|
- `final_rewards` (`torch.Tensor`): |
|
The final rewards for each query response. |
|
- `sequence_lengths` (`torch.Tensor`): |
|
The lengths of the sequences in the query responses. |
|
""" |
|
attention_mask = query_responses != pad_token_id |
|
position_ids = attention_mask.cumsum(1) - attention_mask.long() |
|
lm_backbone = getattr(model, model.base_model_prefix) |
|
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) |
|
output = lm_backbone( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
return_dict=True, |
|
output_hidden_states=True, |
|
use_cache=False, |
|
) |
|
reward_logits = model.score(output.hidden_states[-1]) |
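    # The scalar reward for each sequence is the logit at the last non-padding token of the response
    # (one position before the first pad token that follows the context)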
|
sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length |
|
|
|
return ( |
|
reward_logits, |
|
reward_logits[ |
|
torch.arange(reward_logits.size(0), device=reward_logits.device), |
|
sequence_lengths, |
|
].squeeze(-1), |
|
sequence_lengths, |
|
) |
|
|
|
|
|
def forward( |
|
model: torch.nn.Module, |
|
query_responses: torch.Tensor, |
|
pad_token_id: int, |
|
):
|
""" |
|
Performs a forward pass through the model with the given query responses and pad token ID. |
|
|
|
Args: |
|
model (`torch.nn.Module`): |
|
The model to perform the forward pass. |
|
query_responses (`torch.Tensor`): |
|
The tensor containing the query responses. |
|
pad_token_id (`int`): |
|
The token ID representing the pad token. |
|
|
|
Returns: |
|
        `transformers.utils.ModelOutput`:
|
The output of the model, including hidden states. |
|
""" |
|
attention_mask = query_responses != pad_token_id |
|
position_ids = attention_mask.cumsum(1) - attention_mask.long() |
|
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) |
|
return model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
return_dict=True, |
|
output_hidden_states=True, |
|
) |
|
|
|
|
|
def prepare_deepspeed( |
|
model: torch.nn.Module, per_device_train_batch_size: int, fp16: bool = False, bf16: bool = False |
|
): |
|
""" |
|
    Prepares the model for training with DeepSpeed (both ZeRO stage 2 and stage 3), configuring the appropriate settings based on the model and
|
batch size. |
|
|
|
Args: |
|
model (`torch.nn.Module`): |
|
The model to be prepared for DeepSpeed training. |
|
per_device_train_batch_size (`int`): |
|
            The training batch size per device.

        fp16 (`bool`, *optional*, defaults to `False`):

            Whether to enable fp16 mixed precision in the DeepSpeed config.

        bf16 (`bool`, *optional*, defaults to `False`):

            Whether to enable bf16 mixed precision in the DeepSpeed config.
|
|
|
Returns: |
|
`torch.nn.Module`: |
|
The model initialized and configured with DeepSpeed for training. |
|
""" |
|
import deepspeed |
|
|
|
deepspeed_plugin = AcceleratorState().deepspeed_plugin |
|
config_kwargs = deepspeed_plugin.deepspeed_config |
|
if config_kwargs["zero_optimization"]["stage"] != 3: |
|
config_kwargs["train_micro_batch_size_per_gpu"] = per_device_train_batch_size |
|
config_kwargs = { |
|
"train_micro_batch_size_per_gpu": config_kwargs["train_micro_batch_size_per_gpu"], |
|
"prescale_gradients": False, |
|
"wall_clock_breakdown": False, |
|
} |
|
if bf16: |
|
config_kwargs["bf16"] = {"enabled": True} |
|
elif fp16: |
|
config_kwargs["fp16"] = {"enabled": True} |
|
else: |
|
if hasattr(model, "config"): |
|
hidden_size = ( |
|
max(model.config.hidden_sizes) |
|
if getattr(model.config, "hidden_sizes", None) |
|
else getattr(model.config, "hidden_size", None) |
|
) |
|
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: |
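                # For ZeRO stage 3, size the ZeRO buckets from the model's hidden size and set the prefetch bucket size to 0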
|
|
|
|
|
config_kwargs.update( |
|
{ |
|
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size, |
|
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, |
|
"zero_optimization.stage3_prefetch_bucket_size": 0, |
|
} |
|
) |
|
model, *_ = deepspeed.initialize(model=model, config=config_kwargs) |
|
model.eval() |
|
return model |
|
|
|
|
|
def truncate_response(stop_token_id: int, pad_token_id: int, responses: torch.Tensor): |
|
""" |
|
Truncates the responses at the first occurrence of the stop token, filling the rest with pad tokens. |
|
|
|
Args: |
|
stop_token_id (`int`): |
|
The token ID representing the stop token where truncation occurs. |
|
pad_token_id (`int`): |
|
The token ID representing the pad token used to fill the truncated responses. |
|
responses (`torch.Tensor`): |
|
The tensor containing the responses to be truncated. |
|
|
|
Returns: |
|
`torch.Tensor`: |
|
The truncated responses tensor with pad tokens filled after the stop token. |
|
""" |
|
trunc_idxs = first_true_indices(responses == stop_token_id).unsqueeze(-1) |
|
new_size = [1] * (len(responses.size()) - 1) + [responses.shape[1]] |
|
idxs = torch.arange(responses.shape[1], device=responses.device).view(*new_size) |
|
postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id) |
|
return postprocessed_responses |
|
|
|
|
|
def generate( |
|
lm_backbone: torch.nn.Module, queries: torch.Tensor, pad_token_id: int, generation_config: GenerationConfig |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Generates sequences from the language model backbone in a way that does not affect padding tokens. |
|
|
|
Args: |
|
lm_backbone (`torch.nn.Module`): |
|
The language model backbone used for generation. |
|
queries (`torch.Tensor`): |
|
The tensor containing the input queries. |
|
pad_token_id (`int`): |
|
The token ID representing the pad token. |
|
generation_config (`GenerationConfig`): |
|
The configuration for the generation process. |
|
|
|
Returns: |
|
tuple: |
|
- `generated_sequences` (`torch.Tensor`): |
|
The concatenated tensor of input queries and generated sequences. |
|
- `logits` (`torch.Tensor`): |
|
The logits output from the generation process. |
|
""" |
|
context_length = queries.shape[1] |
|
attention_mask = queries != pad_token_id |
|
input_ids = torch.masked_fill(queries, ~attention_mask, 0) |
|
output = lm_backbone.generate( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
|
|
|
|
generation_config=generation_config, |
|
return_dict_in_generate=True, |
|
output_scores=True, |
|
) |
|
logits = torch.stack(output.scores, 1) |
|
return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits |
|
|
|
|
|
@torch.no_grad() |
|
def batch_generation( |
|
model: torch.nn.Module, |
|
queries: torch.Tensor, |
|
local_rollout_forward_batch_size: int, |
|
pad_token_id: int, |
|
generation_config: GenerationConfig, |
|
): |
|
query_responses = [] |
|
logitss = [] |
|
batch_size = queries.shape[0] |
|
for i in range(0, batch_size, local_rollout_forward_batch_size): |
|
query = queries[i : i + local_rollout_forward_batch_size] |
|
query_response, logits = generate( |
|
model, |
|
query, |
|
pad_token_id, |
|
generation_config, |
|
) |
|
query_responses.append(query_response) |
|
logitss.append(logits) |
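    # Responses from different mini-batches can have different lengths: pad them to a common length before
    # stacking, then flatten back to (batch_size, ...)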
|
|
|
|
|
padded_query_responses = pad(query_responses, padding_value=pad_token_id, padding_side="right") |
|
padded_logitss = pad(logitss, padding_value=0, padding_side="right") |
|
|
|
|
|
padded_query_responses = padded_query_responses.view(-1, padded_query_responses.shape[-1])[:batch_size] |
|
padded_logitss = padded_logitss.view(-1, *padded_logitss.shape[2:])[:batch_size] |
|
|
|
return padded_query_responses, padded_logitss |
|
|
|
|
|
def add_bos_token_if_needed( |
|
bos_token_id: Optional[int], |
|
prompt_len_input_ids: int, |
|
prompt_tokens: dict[str, list[int]], |
|
chosen_prompt_len_input_ids: int, |
|
chosen_tokens: dict[str, list[int]], |
|
rejected_prompt_len_input_ids: int, |
|
rejected_tokens: dict[str, list[int]], |
|
): |
|
if bos_token_id is not None: |
|
if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]: |
|
prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"] |
|
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] |
|
if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]: |
|
chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"] |
|
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] |
|
if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]: |
|
rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"] |
|
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] |
|
return prompt_tokens, chosen_tokens, rejected_tokens |
|
|
|
|
|
def add_eos_token_if_needed( |
|
eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]] |
|
): |
|
if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]: |
|
chosen_tokens["input_ids"].append(eos_token_id) |
|
chosen_tokens["attention_mask"].append(1) |
|
if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]: |
|
rejected_tokens["input_ids"].append(eos_token_id) |
|
rejected_tokens["attention_mask"].append(1) |
|
return chosen_tokens, rejected_tokens |
|
|
|
|
|
def truncate_right( |
|
input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Truncates the input tensor from the right side after the first occurrence of the stop token. |
|
|
|
Args: |
|
input_ids (`torch.Tensor`): |
|
The tensor containing the responses to be truncated |
|
stop_token_id (`int`): |
|
The token ID representing the stop token where truncation occurs |
|
pad_token_id (`int`): |
|
The token ID representing the pad token used to fill the truncated responses |
|
|
|
Returns: |
|
tuple: |
|
- `output_ids` (`torch.Tensor`): |
|
The truncated responses tensor with pad tokens filled after the stop token |
|
- `mask` (`torch.Tensor`): |
|
The mask tensor to indicate the padding tokens |
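
    Example:
        >>> import torch
        >>> ids, mask = truncate_right(torch.tensor([[1, 2, 99, 5, 6]]), stop_token_id=99, pad_token_id=0)
        >>> ids
        tensor([[ 1,  2, 99,  0,  0]])
        >>> mask
        tensor([[1, 1, 1, 0, 0]])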
|
""" |
|
trunc_idxs = first_true_indices(input_ids == stop_token_id).unsqueeze(-1) |
|
new_size = [1] * (len(input_ids.size()) - 1) + [input_ids.shape[1]] |
|
idxs = torch.arange(input_ids.shape[1], device=input_ids.device).view(*new_size) |
|
output_ids = torch.masked_fill(input_ids, idxs > trunc_idxs, pad_token_id) |
|
mask = torch.masked_fill(torch.ones_like(input_ids), idxs > trunc_idxs, 0) |
|
return output_ids, mask |
|
|
|
|
|
def empty_cache() -> None: |
|
"""Empties the cache of the available torch device. |
|
|
|
This function checks for the availability of different torch devices (XPU, MLU, NPU, CUDA) |
|
and empties the cache of the first available device it finds. |
|
|
|
If none of the specific devices are available, it defaults to emptying the CUDA cache. |
|
""" |
|
if is_torch_xpu_available(): |
|
torch.xpu.empty_cache() |
|
elif is_torch_mlu_available(): |
|
torch.mlu.empty_cache() |
|
elif is_torch_npu_available(): |
|
torch.npu.empty_cache() |
|
else: |
|
torch.cuda.empty_cache() |
|
|
|
|
|
def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenizerBase) -> list[str]: |
|
""" |
|
Decodes the input tensor and strips the padding tokens. |
|
|
|
Args: |
|
inputs (`torch.Tensor`): |
|
The input tensor to be decoded. |
|
tokenizer (`transformers.PreTrainedTokenizerBase`): |
|
The tokenizer used to decode the input tensor. |
|
|
|
Returns: |
|
`list[str]`: |
|
The list of decoded strings with padding tokens stripped. |
|
""" |
|
decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False) |
|
return [d.replace(tokenizer.pad_token, "") for d in decoded] |
|
|
|
|
|
def generate_model_card( |
|
base_model: Optional[str], |
|
model_name: str, |
|
hub_model_id: str, |
|
dataset_name: Optional[str], |
|
tags: list[str], |
|
wandb_url: Optional[str], |
|
trainer_name: str, |
|
trainer_citation: Optional[str] = None, |
|
paper_title: Optional[str] = None, |
|
paper_id: Optional[str] = None, |
|
comet_url: Optional[str] = None, |
|
) -> ModelCard: |
|
""" |
|
Generate a `ModelCard` from a template. |
|
|
|
Args: |
|
base_model (`str` or `None`): |
|
Base model name. |
|
model_name (`str`): |
|
Model name. |
|
hub_model_id (`str`): |
|
Hub model ID as `username/model_id`. |
|
dataset_name (`str` or `None`): |
|
Dataset name. |
|
tags (`list[str]`): |
|
Tags. |
|
wandb_url (`str` or `None`): |
|
Weights & Biases run URL. |
|
comet_url (`str` or `None`): |
|
Comet experiment URL. |
|
trainer_name (`str`): |
|
Trainer name. |
|
trainer_citation (`str` or `None`, defaults to `None`): |
|
Trainer citation as a BibTeX entry. |
|
paper_title (`str` or `None`, defaults to `None`): |
|
Paper title. |
|
paper_id (`str` or `None`, defaults to `None`): |
|
ArXiv paper ID as `YYMM.NNNNN`. |
|
|
|
Returns: |
|
`ModelCard`: |
|
A ModelCard object. |
|
""" |
|
card_data = ModelCardData( |
|
base_model=base_model, |
|
datasets=dataset_name, |
|
library_name="transformers", |
|
licence="license", |
|
model_name=model_name, |
|
tags=["generated_from_trainer", *tags], |
|
) |
|
card = ModelCard.from_template( |
|
card_data, |
|
template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")), |
|
base_model=base_model, |
|
model_name=model_name, |
|
hub_model_id=hub_model_id, |
|
dataset_name=dataset_name, |
|
wandb_url=wandb_url, |
|
comet_url=comet_url, |
|
trainer_name=trainer_name, |
|
trainer_citation=trainer_citation, |
|
paper_title=paper_title, |
|
paper_id=paper_id, |
|
trl_version=version("trl"), |
|
transformers_version=version("transformers"), |
|
pytorch_version=version("torch"), |
|
datasets_version=version("datasets"), |
|
tokenizers_version=version("tokenizers"), |
|
) |
|
return card |
|
|
|
|
|
def get_comet_experiment_url() -> Optional[str]: |
|
""" |
|
If Comet integration is enabled, return the URL of the current Comet experiment; otherwise, return `None`. |
|
""" |
|
if not is_comet_available(): |
|
return None |
|
|
|
if comet_ml.get_running_experiment() is not None: |
|
return comet_ml.get_running_experiment().url |
|
|
|
return None |
|
|
|
|
|
def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: |
|
""" |
|
    If Comet integration is enabled, logs a table to the currently running Comet experiment.
|
|
|
Args: |
|
name (`str`): |
|
Table name. |
|
table (`pd.DataFrame`): |
|
The Pandas DataFrame containing the table to log. |
|
""" |
|
if not is_comet_available(): |
|
        raise ModuleNotFoundError("comet-ml is not installed. Please install it first: pip install comet-ml")
|
|
|
experiment = comet_ml.get_running_experiment() |
|
if experiment is not None: |
|
experiment.log_table(tabular_data=table, filename=name) |
|
|
|
|
|
def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: |
|
""" |
|
Shift non-zero elements in the mask and corresponding tensors to the left. |
|
|
|
This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask. |
|
For each row, non-zero values are shifted to the leftmost positions. Then, columns that contain only zeros across |
|
all rows are truncated from the mask and tensors. Visually, this operation can be represented as follows: |
|
|
|
``` |
|
[[0, 0, x, x, x, x], -> [[x, x, x, x], |
|
[0, x, x, x, 0, 0]] [x, x, x, 0]] |
|
``` |
|
|
|
Args: |
|
|
|
mask (`torch.Tensor`): |
|
2D tensor (binary mask) with shape `(N, M)`. |
|
        *tensors (`torch.Tensor`):
|
One or more 2D tensors with the same shape as `mask`. These tensors will be processed alongside `mask`, |
|
with non-zero values shifted and excess zero columns truncated in the same manner. |
|
|
|
Returns: |
|
`torch.Tensor`: |
|
Updated binary mask with non-zero values flushed to the left and trailing zero columns removed. |
|
        `*torch.Tensor`:
|
Updated tensors, processed in the same way as the mask. |
|
|
|
Example: |
|
```python |
|
>>> mask = torch.tensor([[0, 0, 1, 1, 1], |
|
... [0, 1, 1, 0, 0]]) |
|
>>> tensor = torch.tensor([[9, 9, 2, 3, 4], |
|
... [9, 5, 6, 9, 9]]) |
|
>>> new_mask, new_tensor = flush_left(mask, tensor) |
|
>>> print(new_mask) |
|
tensor([[1, 1, 1], |
|
[1, 1, 0]]) |
|
>>> print(new_tensor) |
|
tensor([[2, 3, 4], |
|
[5, 6, 0]]) |
|
``` |
|
""" |
|
|
|
mask = mask.clone() |
|
tensors = [t.clone() for t in tensors] |
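    # Roll each row left so that its first non-zero mask entry lands at column 0, then drop the columns that
    # became all-zero across every row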
|
|
|
|
|
for i in range(mask.size(0)): |
|
first_one_idx = torch.nonzero(mask[i])[0].item() |
|
mask[i] = torch.roll(mask[i], shifts=-first_one_idx) |
|
for tensor in tensors: |
|
tensor[i] = torch.roll(tensor[i], shifts=-first_one_idx) |
|
|
|
|
|
empty_cols = torch.sum(mask, dim=0) == 0 |
|
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1) |
|
mask = mask[:, :first_empty_col] |
|
for i, tensor in enumerate(tensors): |
|
tensors[i] = tensor[:, :first_empty_col] |
|
|
|
if not tensors: |
|
return mask |
|
else: |
|
return mask, *tensors |
|
|
|
|
|
def selective_log_softmax(logits, index): |
|
""" |
|
A memory-efficient implementation of the common `log_softmax -> gather` operation. |
|
|
|
This function is equivalent to the following naive implementation: |
|
```python |
|
logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) |
|
``` |
|
|
|
Args: |
|
logits (`torch.Tensor`): |
|
Logits tensor of shape `(..., num_classes)`. |
|
index (`torch.Tensor`): |
|
Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. |
|
|
|
Returns: |
|
`torch.Tensor`: |
|
Gathered log probabilities with the same shape as `index`. |
|
""" |
|
if logits.dtype in [torch.float32, torch.float64]: |
|
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) |
|
|
|
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) |
|
per_token_logps = selected_logits - logsumexp_values |
|
else: |
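        # In reduced precision (e.g. bf16), logsumexp over the full vocabulary can be numerically unstable,
        # so fall back to a per-row log_softmax + gather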
|
|
|
per_token_logps = [] |
|
for row_logits, row_labels in zip(logits, index): |
|
row_logps = F.log_softmax(row_logits, dim=-1) |
|
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) |
|
per_token_logps.append(row_per_token_logps) |
|
per_token_logps = torch.stack(per_token_logps) |
|
return per_token_logps |
|
|
|
|
|
def print_prompt_completions_sample( |
|
    prompts: list[str], completions: list[str], rewards: dict[str, list[float]], step: int, num_samples: Optional[int] = None
|
) -> None: |
|
""" |
|
Print out a sample of model completions to the console with multiple reward metrics. |
|
|
|
This function creates a nicely formatted table showing prompt-completion pairs, useful for monitoring model outputs |
|
during training. It requires the `rich` library to be installed. |
|
|
|
Args: |
|
prompts (`list[str]`): |
|
List of prompts. |
|
completions (`list[str]`): |
|
List of completions corresponding to the prompts. |
|
rewards (`dict[str, list[float]]`): |
|
Dictionary where keys are reward names and values are lists of rewards. |
|
step (`int`): |
|
Current training step number, used in the output title. |
|
num_samples (`int` or `None`, *optional*, defaults to `None`): |
|
Number of random samples to display. If `None` (default), all items will be displayed. |
|
|
|
Example: |
|
```python |
|
>>> from trl.trainer.utils import print_prompt_completions_sample |
|
>>> prompts = ["The sky is", "The sun is"] |
|
>>> completions = [" blue.", " in the sky."] |
|
>>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} |
|
>>> print_prompt_completions_sample(prompts, completions, rewards, 42) |
|
โญโโโโโโโโโโโโโโโโโโโโโโ Step 42 โโโโโโโโโโโโโโโโโโโโโโโโฎ |
|
โ โโโโโโโโโโโโโโณโโโโโโโโโโโโโโโณโโโโโโโโโโโโโโณโโโโโโโโโ โ |
|
โ โ Prompt โ Completion โ Correctness โ Format โ โ |
|
โ โกโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฉ โ |
|
โ โ The sky is โ blue. โ 0.12 โ 0.79 โ โ |
|
โ โโโโโโโโโโโโโโผโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโผโโโโโโโโโค โ |
|
โ โ The sun is โ in the sky. โ 0.46 โ 0.10 โ โ |
|
โ โโโโโโโโโโโโโโดโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโดโโโโโโโโโ โ |
|
โฐโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฏ |
|
``` |
|
""" |
|
console = Console() |
|
table = Table(show_header=True, header_style="bold white", expand=True) |
|
|
|
|
|
table.add_column("Prompt", style="bright_yellow") |
|
table.add_column("Completion", style="bright_green") |
|
for reward_name in rewards.keys(): |
|
table.add_column(reward_name, style="bold cyan", justify="right") |
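    # Normalize num_samples: show every row if it covers the whole batch, and nothing if it is not positive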
|
|
|
|
|
if num_samples is not None: |
|
if num_samples >= len(prompts): |
|
num_samples = None |
|
elif num_samples <= 0: |
|
return |
|
|
|
|
|
if num_samples is not None: |
|
indices = random.sample(range(len(prompts)), num_samples) |
|
prompts = [prompts[i] for i in indices] |
|
completions = [completions[i] for i in indices] |
|
rewards = {key: [val[i] for i in indices] for key, val in rewards.items()} |
|
|
|
for i in range(len(prompts)): |
|
reward_values = [f"{rewards[key][i]:.2f}" for key in rewards.keys()] |
|
table.add_row(Text(prompts[i]), Text(completions[i]), *reward_values) |
|
table.add_section() |
|
|
|
panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white") |
|
console.print(panel) |
|
|