# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile
import unittest
import numpy as np
import torch
from datasets import Dataset, Image, Sequence, load_dataset
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
LlavaForConditionalGeneration,
TrainingArguments,
is_vision_available,
)
from transformers.testing_utils import require_flash_attn, require_peft, require_vision
from transformers.utils import is_peft_available
from trl import SFTConfig, SFTTrainer
from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
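# Formatting helper passed to SFTTrainer/ConstantLengthDataset: builds a single prompt string from a raw example.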
def formatting_prompts_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text
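# For already-tokenized datasets, the "formatting" step simply returns the existing token ids unchanged.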
def formatting_func_for_pretokenized(example):
return example["input_ids"]
if is_peft_available():
from peft import LoraConfig, PeftModel, get_peft_model
if is_vision_available():
from PIL import Image as PILImage
class TestDataCollatorForLanguageModeling(unittest.TestCase):
def test_basic_padding(self):
"""Test basic padding functionality without completion masks."""
collator = DataCollatorForLanguageModeling(pad_token_id=0)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
result = collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
def test_completion_mask(self):
"""Test completion mask functionality."""
collator = DataCollatorForLanguageModeling(pad_token_id=0)
examples = [
{"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
{"input_ids": [4, 5], "completion_mask": [0, 1]},
]
result = collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))
def test_completion_only_loss_disabled(self):
"""Test behavior when completion_only_loss is disabled."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, completion_only_loss=False)
examples = [
{"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
{"input_ids": [4, 5], "completion_mask": [0, 1]},
]
result = collator(examples)
# Labels should not be masked when completion_only_loss=False
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
def test_padding_free_mode(self):
"""Test padding-free mode where sequences are concatenated."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
result = collator(examples)
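# In padding-free mode the sequences are concatenated into a single row; position_ids restart at each sequence boundary.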
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]]))
def test_padding_free_with_completion_mask(self):
"""Test padding-free mode with completion masks."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
examples = [
{"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
{"input_ids": [4, 5], "completion_mask": [1, 1]},
]
result = collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))
def test_pad_to_multiple_of(self):
"""Test padding to multiple of specified value."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
result = collator(examples)
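# Both sequences are padded up to length 4, the nearest multiple of pad_to_multiple_of.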
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))
def test_custom_position_ids(self):
"""Test handling of custom position IDs in examples."""
collator = DataCollatorForLanguageModeling(pad_token_id=0)
examples = [{"input_ids": [1, 2, 3], "position_ids": [0, 0, 1]}, {"input_ids": [4, 5], "position_ids": [0, 1]}]
result = collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
def test_single_example(self):
"""Test collator with a single example."""
collator = DataCollatorForLanguageModeling(pad_token_id=0)
examples = [{"input_ids": [1, 2, 3, 4]}]
result = collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))
def test_different_pad_token_id(self):
"""Test with different pad token ID."""
collator = DataCollatorForLanguageModeling(pad_token_id=999)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
result = collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
class SFTTrainerTester(unittest.TestCase):
r""" """
def setUp(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.dummy_dataset = Dataset.from_dict(
{
"question": [
"Does llamas know how to code?",
"Does llamas know how to fly?",
"Does llamas know how to talk?",
"Does llamas know how to code?",
"Does llamas know how to fly?",
"Does llamas know how to talk?",
"Does llamas know how to swim?",
],
"answer": [
"Yes, llamas are very good at coding.",
"No, llamas can't fly.",
"Yes, llamas are very good at talking.",
"Yes, llamas are very good at coding.",
"No, llamas can't fly.",
"Yes, llamas are very good at talking.",
"No, llamas can't swim.",
],
"text": [
"### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.",
"### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.",
"### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.",
"### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.",
"### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.",
"### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.",
"### Question: Does llamas know how to swim?\n ### Answer: No, llamas can't swim.",
],
}
)
self.dummy_tokenized_dataset = Dataset.from_dict(
{
"input_ids": [
self.tokenizer.encode(
"TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)."
)
]
* 10
}
)
self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
self.standard_prompt_completion_dataset = load_dataset(
"trl-internal-testing/zen", "standard_prompt_completion"
)
if is_vision_available():
self.dummy_vsft_instruction_dataset = Dataset.from_dict(
{
"messages": [
[
{
"role": "user",
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is random noise."}],
},
{
"role": "user",
"content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "2"}],
},
],
[
{
"role": "user",
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is random noise."}],
},
],
],
"images": [
[PILImage.fromarray((np.random.rand(40, 50, 3) * 255).astype("uint8")).convert("RGBA")],
[PILImage.fromarray((np.random.rand(50, 60, 3) * 255).astype("uint8")).convert("RGBA")],
],
}
)
self.dummy_vsft_instruction_dataset = self.dummy_vsft_instruction_dataset.cast_column(
"images", Sequence(Image())
)
self.train_dataset = ConstantLengthDataset(
self.tokenizer,
self.dummy_dataset,
formatting_func=formatting_prompts_func,
seq_length=16,
num_of_sequences=16,
)
self.eval_dataset = ConstantLengthDataset(
self.tokenizer,
self.dummy_dataset,
formatting_func=formatting_prompts_func,
seq_length=16,
num_of_sequences=16,
)
self.train_dataset_from_pretokenized = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
seq_length=16,
num_of_sequences=16,
formatting_func=formatting_func_for_pretokenized,
)
self.eval_dataset_from_pretokenized = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
seq_length=16,
num_of_sequences=16,
formatting_func=formatting_func_for_pretokenized,
)
def test_constant_length_dataset_with_pretokenized_data(self):
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
formatting_func=formatting_func_for_pretokenized,
)
self.assertEqual(len(constant_len_dataset), len(self.dummy_tokenized_dataset))
self.assertGreater(len(constant_len_dataset), 0)
for example in constant_len_dataset:
self.assertIn("input_ids", example)
self.assertIn("labels", example)
self.assertEqual(len(example["input_ids"]), constant_len_dataset.seq_length)
self.assertEqual(len(example["labels"]), constant_len_dataset.seq_length)
decoded_text = self.tokenizer.decode(example["input_ids"])
self.assertTrue(("TRL" in decoded_text) and ("(DPO)" in decoded_text))
def test_constant_length_dataset(self):
formatted_dataset = ConstantLengthDataset(
self.tokenizer,
self.dummy_dataset,
formatting_func=formatting_prompts_func,
)
self.assertEqual(len(formatted_dataset), len(self.dummy_dataset))
self.assertGreater(len(formatted_dataset), 0)
for example in formatted_dataset:
self.assertIn("input_ids", example)
self.assertIn("labels", example)
self.assertEqual(len(example["input_ids"]), formatted_dataset.seq_length)
self.assertEqual(len(example["labels"]), formatted_dataset.seq_length)
decoded_text = self.tokenizer.decode(example["input_ids"])
self.assertTrue(("Question" in decoded_text) and ("Answer" in decoded_text))
def test_backward_compatibility(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
hub_token="not_a_real_token",
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
formatting_func=formatting_prompts_func,
)
self.assertEqual(trainer.args.hub_token, training_args.hub_token)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
def test_with_pretokenized_data_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
packing=True,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset_from_pretokenized,
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
def test_uncorrect_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
# Should work as SFTTrainer natively supports conversational LM datasets
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=32, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
_ = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.conversational_lm_dataset["train"],
)
# Same, but without packing
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
packing=False,
report_to="none",
)
_ = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.conversational_lm_dataset["train"],
)
# Same, but with the prompt-completion dataset, packed with `max_length`
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=16, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
_ = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.standard_prompt_completion_dataset["train"],
)
# Same prompt-completion dataset, but without packing
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
packing=False,
report_to="none",
)
_ = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.standard_prompt_completion_dataset["train"],
)
# Should work as the dummy dataset is supported with a formatting function
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=32, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
_ = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
)
def test_sft_trainer_with_model_num_train_epochs(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
num_train_epochs=2,
per_device_train_batch_size=2,
packing=True,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
num_train_epochs=2,
max_length=16,
packing=True,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.dummy_dataset,
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
num_train_epochs=2,
per_device_train_batch_size=2,
max_length=16,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.dummy_dataset,
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
def test_with_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=16,
packing=True,
report_to="none",
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=16,
packing=True,
report_to="none",
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=16,
report_to="none",
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
def test_with_multiple_eval_datasets(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
eval_strategy="steps",
eval_steps=3,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
eval_dataset={
"data1": self.eval_dataset,
"data2": self.eval_dataset,
},
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_data1_loss"])
self.assertIsNotNone(trainer.state.log_history[1]["eval_data2_loss"])
def test_data_collator_completion_lm(self):
response_template = "### Response:\n"
data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=self.tokenizer, mlm=False)
text = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly."""
encoded_text = self.tokenizer(text)
examples = [encoded_text]
batch = data_collator(examples)
labels = batch["labels"]
last_pad_idx = np.where(labels == -100)[1][-1]
result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :])
self.assertEqual(result_text, "I have not been masked correctly.")
def test_data_collator_completion_lm_with_multiple_text(self):
tokenizer = copy.deepcopy(self.tokenizer)
tokenizer.padding_side = "left"
response_template = "### Response:\n"
data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, mlm=False)
text1 = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly."""
text2 = """\n\n### Instructions:\nThis is another longer text that should also be masked. This text is significantly longer than the previous one.\n\n### Response:\nI have not been masked correctly."""
encoded_text1 = tokenizer(text1)
encoded_text2 = tokenizer(text2)
examples = [encoded_text1, encoded_text2]
batch = data_collator(examples)
for i in range(2):
labels = batch["labels"][i]
last_pad_idx = np.where(labels == -100)[0][-1]
result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :])
self.assertEqual(result_text, "I have not been masked correctly.")
def test_data_collator_chat_completion_lm(self):
instruction_template = "### Human:"
assistant_template = "### Assistant:"
data_collator = DataCollatorForCompletionOnlyLM(
response_template=assistant_template,
instruction_template=instruction_template,
tokenizer=self.tokenizer,
mlm=False,
)
text = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too."""
encoded_text = self.tokenizer(text)
examples = [encoded_text]
batch = data_collator(examples)
labels = batch["labels"]
non_masked_tokens = batch["input_ids"][labels != -100]
result_text = self.tokenizer.decode(non_masked_tokens)
self.assertEqual(result_text, " I should not be masked. I should not be masked too.")
def test_data_collator_chat_completion_lm_with_multiple_text(self):
tokenizer = copy.deepcopy(self.tokenizer)
tokenizer.padding_side = "left"
instruction_template = "### Human:"
assistant_template = "### Assistant:"
data_collator = DataCollatorForCompletionOnlyLM(
response_template=assistant_template,
instruction_template=instruction_template,
tokenizer=tokenizer,
mlm=False,
)
text1 = """### Human: Hello all this should be masked.### Assistant: I should not be masked."""
text2 = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too."""
encoded_text1 = tokenizer(text1)
encoded_text2 = tokenizer(text2)
examples = [encoded_text1, encoded_text2]
batch = data_collator(examples)
labels = batch["labels"]
input_ids = batch["input_ids"]
non_masked_tokens1 = input_ids[0][labels[0] != -100]
result_text1 = tokenizer.decode(non_masked_tokens1)
self.assertEqual(result_text1, " I should not be masked.")
non_masked_tokens2 = input_ids[1][labels[1] != -100]
result_text2 = tokenizer.decode(non_masked_tokens2)
self.assertEqual(result_text2, " I should not be masked. I should not be masked too.")
def test_with_model_neftune(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
neftune_noise_alpha=5,
packing=True,
report_to="none",
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.train_dataset,
)
trainer.model = trainer._activate_neftune(trainer.model)
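# NEFTune registers a forward hook on the input embeddings that adds noise during training, so two passes over the same tokens with different seeds should produce different embeddings.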
device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()
torch.random.manual_seed(42)
embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
self.assertGreater(len(trainer.model.get_input_embeddings()._forward_hooks), 0)
trainer.neftune_hook_handle.remove()
trainer.train()
# Make sure forward pass works fine
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertEqual(len(trainer.model.get_input_embeddings()._forward_hooks), 0)
@require_peft
def test_peft_str(self):
with tempfile.TemporaryDirectory() as tmp_dir:
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
training_args = SFTConfig(
packing=True,
output_dir=tmp_dir,
report_to="none",
)
_ = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
peft_config=peft_config,
)
@require_peft
def test_peft_sft_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
packing=True,
report_to="none",
)
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
peft_config=peft_config,
)
self.assertTrue(isinstance(trainer.model, PeftModel))
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@require_peft
def test_peft_and_gradient_checkpointing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
gradient_checkpointing=True,
report_to="none",
)
peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
peft_config=peft_config,
)
self.assertIsInstance(trainer.model, PeftModel)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
@require_peft
def test_peft_neftune(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
neftune_noise_alpha=5,
packing=True,
report_to="none",
)
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
peft_config=peft_config,
)
trainer.model = trainer._activate_neftune(trainer.model)
self.assertIsInstance(trainer.model, PeftModel)
device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()
torch.random.manual_seed(42)
embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
self.assertGreater(len(trainer.model.get_input_embeddings()._forward_hooks), 0)
trainer.neftune_hook_handle.remove()
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Make sure forward pass works fine to check if embeddings forward is not broken.
trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertEqual(len(trainer.model.get_input_embeddings()._forward_hooks), 0)
@require_peft
def test_peft_tag(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
gradient_checkpointing=True,
packing=True,
report_to="none",
)
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
peft_config=peft_config,
)
for tag in ["sft", "trl"]:
self.assertIn(tag, trainer.model.model_tags)
@require_peft
def test_tag(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
gradient_checkpointing=True,
packing=True,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
)
for tag in ["sft", "trl"]:
self.assertIn(tag, trainer.model.model_tags)
def test_only_train_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
gradient_checkpointing=True,
packing=True,
max_length=128, # make sure there is at least 1 packed sequence
eval_packing=False,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.conversational_lm_dataset["train"],
eval_dataset=self.conversational_lm_dataset["test"],
)
self.assertEqual(len(trainer.train_dataset["input_ids"]), 7) # w/ this dataset, we end up with 46 seqs
self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
def test_eval_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=128, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.conversational_lm_dataset["train"],
eval_dataset=self.conversational_lm_dataset["test"],
)
self.assertEqual(len(trainer.train_dataset["input_ids"]), 7) # w/ this dataset, we end up with 46 seqs
self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1) # w/ this dataset, we end up with 6 seqs
def test_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=128, # make sure there is at least 1 packed sequence
packing=False,
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.conversational_lm_dataset["train"],
eval_dataset=self.conversational_lm_dataset["test"],
)
self.assertEqual(len(trainer.train_dataset["input_ids"]), len(self.conversational_lm_dataset["train"]))
self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
@require_vision
def test_skip_prepare_dataset(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
remove_unused_columns=False,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.dummy_vsft_instruction_dataset,
)
self.assertEqual(trainer.train_dataset.features, self.dummy_vsft_instruction_dataset.features)
def test_skip_prepare_dataset_with_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
remove_unused_columns=False,
packing=False,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.dummy_dataset,
)
self.assertEqual(trainer.train_dataset.features, self.dummy_dataset.features)
@require_vision
def test_llava(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
remove_unused_columns=False,
dataset_kwargs={"skip_prepare_dataset": True},
report_to="none",
)
tiny_llava = LlavaForConditionalGeneration.from_pretrained(
"trl-internal-testing/tiny-LlavaForConditionalGeneration"
)
processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlavaForConditionalGeneration")
processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
def collate_fn(examples):
# Get the texts and images, and apply the chat template
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
images = [example["images"][0] for example in examples]
# Tokenize the texts and process the images
batch = processor(texts, images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch
trainer = SFTTrainer(
model=tiny_llava,
args=training_args,
data_collator=collate_fn,
train_dataset=self.dummy_vsft_instruction_dataset,
)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
def test_torch_dtype(self):
# See https://github.com/huggingface/trl/issues/1751
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
model_init_kwargs={"torch_dtype": torch.float16},
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=self.train_dataset,
formatting_func=formatting_prompts_func,
)
self.assertEqual(trainer.model.config.torch_dtype, torch.float16)
# This new tester aims to replace the first one at some point
class SFTTrainerTester2(unittest.TestCase):
def test_train(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_model(self):
# Instantiate the model
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_model_torch_dtype(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(
output_dir=tmp_dir, model_init_kwargs={"torch_dtype": torch.float16}, report_to="none"
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# Check the torch dtype
self.assertEqual(new_param.dtype, torch.float16)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
@require_peft
def test_train_peft_model(self):
# Get the base model
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
# Get the base model parameter names
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
# Turn the model into a peft model
lora_config = LoraConfig()
model = get_peft_model(model, lora_config)
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
elif (
"base_layer" not in n
): # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_non_chatml_conversational_data(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
# Rename role/content to from/value to ensure SFT works with non-chatML conversational data
def rename_fields(example: dict):
return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}
dataset = dataset.map(rename_fields, remove_columns="messages")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_pretokenized_data(self):
# Get the dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
def tokenize_example(example):
return tokenizer(example["text"])
# Apply tokenization
tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_iterable_dataset(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train", streaming=True)
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_data_collator_for_completion_only_and_padding_free(self):
# Get the dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
tokenizer = AutoTokenizer.from_pretrained(model_id)
response_template = "<|im_start|>assistant\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, padding_free=True)
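# Completion-only collator: labels before the assistant response template are masked, and padding_free=True avoids padding the batch.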
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset, data_collator=collator)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
@require_flash_attn
def test_train_padding_free(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(
output_dir=tmp_dir,
padding_free=True,
model_init_kwargs={"attn_implementation": "flash_attention_2"},
bf16=True, # flash_attention_2 only supports bf16 and fp16
report_to="none",
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
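# "ffd" packs whole sequences using a first-fit-decreasing bin-packing strategy; "wrapped" concatenates examples and splits them at max_length.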
@parameterized.expand([("ffd",), ("wrapped",)])
def test_train_packing(self, packing_strategy):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(
output_dir=tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none"
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")