# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
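
"""Tests for TRL's `DataCollatorForLanguageModeling`, `DataCollatorForCompletionOnlyLM`, `ConstantLengthDataset`, and `SFTTrainer`."""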

import copy
import tempfile
import unittest

import numpy as np
import torch
from datasets import Dataset, Image, Sequence, load_dataset
from parameterized import parameterized
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    LlavaForConditionalGeneration,
    TrainingArguments,
    is_vision_available,
)
from transformers.testing_utils import require_flash_attn, require_peft, require_vision
from transformers.utils import is_peft_available

from trl import SFTConfig, SFTTrainer
from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling


def formatting_prompts_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text


def formatting_func_for_pretokenized(example):
    return example["input_ids"]


if is_peft_available():
    from peft import LoraConfig, PeftModel, get_peft_model

if is_vision_available():
    from PIL import Image as PILImage


class TestDataCollatorForLanguageModeling(unittest.TestCase):
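    """Unit tests for `DataCollatorForLanguageModeling`: padding, completion masks, padding-free batching, and position IDs."""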

    def test_basic_padding(self):
        """Test basic padding functionality without completion masks."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_completion_mask(self):
        """Test completion mask functionality."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [
            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
            {"input_ids": [4, 5], "completion_mask": [0, 1]},
        ]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))

    def test_completion_only_loss_disabled(self):
        """Test behavior when completion_only_loss is disabled."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, completion_only_loss=False)
        examples = [
            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
            {"input_ids": [4, 5], "completion_mask": [0, 1]},
        ]

        result = collator(examples)

        # Labels should not be masked when completion_only_loss=False
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_padding_free_mode(self):
        """Test padding-free mode where sequences are concatenated."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]]))

    def test_padding_free_with_completion_mask(self):
        """Test padding-free mode with completion masks."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
        examples = [
            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
            {"input_ids": [4, 5], "completion_mask": [1, 1]},
        ]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))

    def test_pad_to_multiple_of(self):
        """Test padding to a multiple of the specified value."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))

    def test_custom_position_ids(self):
        """Test handling of custom position IDs in examples."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [
            {"input_ids": [1, 2, 3], "position_ids": [0, 0, 1]},
            {"input_ids": [4, 5], "position_ids": [0, 1]},
        ]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_single_example(self):
        """Test collator with a single example."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [{"input_ids": [1, 2, 3, 4]}]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))

    def test_different_pad_token_id(self):
        """Test with a different pad token ID."""
        collator = DataCollatorForLanguageModeling(pad_token_id=999)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))


class SFTTrainerTester(unittest.TestCase):
    r"""End-to-end tests for `SFTTrainer`: datasets, packing, completion-only collators, PEFT, and NEFTune."""

    def setUp(self):
        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.dummy_dataset = Dataset.from_dict(
            {
                "question": [
                    "Does llamas know how to code?",
                    "Does llamas know how to fly?",
                    "Does llamas know how to talk?",
                    "Does llamas know how to code?",
                    "Does llamas know how to fly?",
                    "Does llamas know how to talk?",
                    "Does llamas know how to swim?",
                ],
                "answer": [
                    "Yes, llamas are very good at coding.",
                    "No, llamas can't fly.",
                    "Yes, llamas are very good at talking.",
                    "Yes, llamas are very good at coding.",
                    "No, llamas can't fly.",
                    "Yes, llamas are very good at talking.",
                    "No, llamas can't swim.",
                ],
                "text": [
                    "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.",
                    "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.",
                    "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.",
                    "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.",
                    "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.",
                    "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.",
                    "### Question: Does llamas know how to swim?\n ### Answer: No, llamas can't swim.",
                ],
            }
        )
        self.dummy_tokenized_dataset = Dataset.from_dict(
            {
                "input_ids": [
                    self.tokenizer.encode(
                        "TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)."
                    )
                ]
                * 10
            }
        )
        self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
        self.standard_prompt_completion_dataset = load_dataset(
            "trl-internal-testing/zen", "standard_prompt_completion"
        )

        if is_vision_available():
            self.dummy_vsft_instruction_dataset = Dataset.from_dict(
                {
                    "messages": [
                        [
                            {
                                "role": "user",
                                "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
                            },
                            {
                                "role": "assistant",
                                "content": [{"type": "text", "text": "It is random noise."}],
                            },
                            {
                                "role": "user",
                                "content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
                            },
                            {
                                "role": "assistant",
                                "content": [{"type": "text", "text": "2"}],
                            },
                        ],
                        [
                            {
                                "role": "user",
                                "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
                            },
                            {
                                "role": "assistant",
                                "content": [{"type": "text", "text": "It is random noise."}],
                            },
                        ],
                    ],
                    "images": [
                        [PILImage.fromarray((np.random.rand(40, 50, 3) * 255).astype("uint8")).convert("RGBA")],
                        [PILImage.fromarray((np.random.rand(50, 60, 3) * 255).astype("uint8")).convert("RGBA")],
                    ],
                }
            )
            # `cast_column` returns a new dataset, so the result must be assigned back
            self.dummy_vsft_instruction_dataset = self.dummy_vsft_instruction_dataset.cast_column(
                "images", Sequence(Image())
            )

        self.train_dataset = ConstantLengthDataset(
            self.tokenizer,
            self.dummy_dataset,
            formatting_func=formatting_prompts_func,
            seq_length=16,
            num_of_sequences=16,
        )
        self.eval_dataset = ConstantLengthDataset(
            self.tokenizer,
            self.dummy_dataset,
            formatting_func=formatting_prompts_func,
            seq_length=16,
            num_of_sequences=16,
        )
        self.train_dataset_from_pretokenized = ConstantLengthDataset(
            self.tokenizer,
            self.dummy_tokenized_dataset,
            seq_length=16,
            num_of_sequences=16,
            formatting_func=formatting_func_for_pretokenized,
        )
        self.eval_dataset_from_pretokenized = ConstantLengthDataset(
            self.tokenizer,
            self.dummy_tokenized_dataset,
            seq_length=16,
            num_of_sequences=16,
            formatting_func=formatting_func_for_pretokenized,
        )

    def test_constant_length_dataset_with_pretokenized_data(self):
        constant_len_dataset = ConstantLengthDataset(
            self.tokenizer,
            self.dummy_tokenized_dataset,
            formatting_func=formatting_func_for_pretokenized,
        )

        self.assertEqual(len(constant_len_dataset), len(self.dummy_tokenized_dataset))
        self.assertGreater(len(constant_len_dataset), 0)

        for example in constant_len_dataset:
            self.assertIn("input_ids", example)
            self.assertIn("labels", example)
            self.assertEqual(len(example["input_ids"]), constant_len_dataset.seq_length)
            self.assertEqual(len(example["labels"]), constant_len_dataset.seq_length)

            decoded_text = self.tokenizer.decode(example["input_ids"])
            self.assertTrue(("TRL" in decoded_text) and ("(DPO)" in decoded_text))

    def test_constant_length_dataset(self):
        formatted_dataset = ConstantLengthDataset(
            self.tokenizer,
            self.dummy_dataset,
            formatting_func=formatting_prompts_func,
        )

        self.assertEqual(len(formatted_dataset), len(self.dummy_dataset))
        self.assertGreater(len(formatted_dataset), 0)

        for example in formatted_dataset:
            self.assertIn("input_ids", example)
            self.assertIn("labels", example)
            self.assertEqual(len(example["input_ids"]), formatted_dataset.seq_length)
            self.assertEqual(len(example["labels"]), formatted_dataset.seq_length)

            decoded_text = self.tokenizer.decode(example["input_ids"])
            self.assertTrue(("Question" in decoded_text) and ("Answer" in decoded_text))

    def test_backward_compatibility(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                hub_token="not_a_real_token",
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
                formatting_func=formatting_prompts_func,
            )
            self.assertEqual(trainer.args.hub_token, training_args.hub_token)

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_with_pretokenized_data_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset_from_pretokenized,
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    def test_uncorrect_data(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Should work, as SFTTrainer natively supports conversational lm datasets
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=32,  # make sure there is at least 1 packed sequence
                packing=True,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
            )

            # Same, but without packing
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                packing=False,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
            )

            # Same, but with a prompt-completion dataset, packed with `max_length`
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=16,  # make sure there is at least 1 packed sequence
                packing=True,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.standard_prompt_completion_dataset["train"],
            )

            # Same prompt-completion dataset, but without packing
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                packing=False,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.standard_prompt_completion_dataset["train"],
            )

            # Should work, as the dummy dataset is supported with a formatting function
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=32,  # make sure there is at least 1 packed sequence
                packing=True,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.dummy_dataset,
                formatting_func=formatting_prompts_func,
            )

    def test_sft_trainer_with_model_num_train_epochs(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                num_train_epochs=2,
                per_device_train_batch_size=2,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                num_train_epochs=2,
                max_length=16,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.dummy_dataset,
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                num_train_epochs=2,
                per_device_train_batch_size=2,
                max_length=16,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.dummy_dataset,
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    def test_with_model(self):
        # with packing
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=16,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dummy_dataset,
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # with formatting_func + packed
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=16,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dummy_dataset,
                formatting_func=formatting_prompts_func,
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # without packing
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=16,
                report_to="none",
            )
            trainer = SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dummy_dataset,
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    def test_with_multiple_eval_datasets(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                eval_strategy="steps",
                eval_steps=3,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
                eval_dataset={
                    "data1": self.eval_dataset,
                    "data2": self.eval_dataset,
                },
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            self.assertIsNotNone(trainer.state.log_history[0]["eval_data1_loss"])
            self.assertIsNotNone(trainer.state.log_history[1]["eval_data2_loss"])

    def test_data_collator_completion_lm(self):
        response_template = "### Response:\n"
        data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=self.tokenizer, mlm=False)

        text = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly."""
        encoded_text = self.tokenizer(text)
        examples = [encoded_text]

        batch = data_collator(examples)
        labels = batch["labels"]
        # Everything up to and including the response template is masked with -100,
        # so only the completion should survive after the last masked position
        last_pad_idx = np.where(labels == -100)[1][-1]
        result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :])
        self.assertEqual(result_text, "I have not been masked correctly.")

    def test_data_collator_completion_lm_with_multiple_text(self):
        tokenizer = copy.deepcopy(self.tokenizer)
        tokenizer.padding_side = "left"
        response_template = "### Response:\n"
        data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, mlm=False)

        text1 = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly."""
        text2 = """\n\n### Instructions:\nThis is another longer text that should also be masked. This text is significantly longer than the previous one.\n\n### Response:\nI have not been masked correctly."""
        encoded_text1 = tokenizer(text1)
        encoded_text2 = tokenizer(text2)
        examples = [encoded_text1, encoded_text2]

        batch = data_collator(examples)

        for i in range(2):
            labels = batch["labels"][i]
            last_pad_idx = np.where(labels == -100)[0][-1]
            result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :])
            self.assertEqual(result_text, "I have not been masked correctly.")

    def test_data_collator_chat_completion_lm(self):
        instruction_template = "### Human:"
        assistant_template = "### Assistant:"
        data_collator = DataCollatorForCompletionOnlyLM(
            response_template=assistant_template,
            instruction_template=instruction_template,
            tokenizer=self.tokenizer,
            mlm=False,
        )

        text = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too."""
        encoded_text = self.tokenizer(text)
        examples = [encoded_text]

        batch = data_collator(examples)
        labels = batch["labels"]
        non_masked_tokens = batch["input_ids"][labels != -100]
        result_text = self.tokenizer.decode(non_masked_tokens)
        self.assertEqual(result_text, " I should not be masked. I should not be masked too.")

    def test_data_collator_chat_completion_lm_with_multiple_text(self):
        tokenizer = copy.deepcopy(self.tokenizer)
        tokenizer.padding_side = "left"
        instruction_template = "### Human:"
        assistant_template = "### Assistant:"
        data_collator = DataCollatorForCompletionOnlyLM(
            response_template=assistant_template,
            instruction_template=instruction_template,
            tokenizer=tokenizer,
            mlm=False,
        )

        text1 = """### Human: Hello all this should be masked.### Assistant: I should not be masked."""
        text2 = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too."""
        encoded_text1 = tokenizer(text1)
        encoded_text2 = tokenizer(text2)
        examples = [encoded_text1, encoded_text2]

        batch = data_collator(examples)
        labels = batch["labels"]
        input_ids = batch["input_ids"]

        non_masked_tokens1 = input_ids[0][labels[0] != -100]
        result_text1 = tokenizer.decode(non_masked_tokens1)
        self.assertEqual(result_text1, " I should not be masked.")

        non_masked_tokens2 = input_ids[1][labels[1] != -100]
        result_text2 = tokenizer.decode(non_masked_tokens2)
        self.assertEqual(result_text2, " I should not be masked. I should not be masked too.")

    def test_with_model_neftune(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                neftune_noise_alpha=5,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.train_dataset,
            )

            trainer.model = trainer._activate_neftune(trainer.model)

            device = trainer.model.get_input_embeddings().weight.device
            trainer.model.train()

            torch.random.manual_seed(42)
            embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

            torch.random.manual_seed(24)
            embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

            self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
            self.assertGreater(len(trainer.model.get_input_embeddings()._forward_hooks), 0)

            trainer.neftune_hook_handle.remove()

            trainer.train()

            # Make sure the forward pass works fine
            _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
            self.assertEqual(len(trainer.model.get_input_embeddings()._forward_hooks), 0)

    @require_peft
    def test_peft_str(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            peft_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                task_type="CAUSAL_LM",
            )
            training_args = SFTConfig(
                packing=True,
                output_dir=tmp_dir,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
                peft_config=peft_config,
            )

    @require_peft
    def test_peft_sft_trainer(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                packing=True,
                report_to="none",
            )
            peft_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                task_type="CAUSAL_LM",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
                peft_config=peft_config,
            )
            self.assertIsInstance(trainer.model, PeftModel)

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    @require_peft
    def test_peft_and_gradient_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                gradient_checkpointing=True,
                report_to="none",
            )
            peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
                peft_config=peft_config,
            )
            self.assertIsInstance(trainer.model, PeftModel)

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    @require_peft
    def test_peft_neftune(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                neftune_noise_alpha=5,
                packing=True,
                report_to="none",
            )
            peft_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                task_type="CAUSAL_LM",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
                peft_config=peft_config,
            )

            trainer.model = trainer._activate_neftune(trainer.model)
            self.assertIsInstance(trainer.model, PeftModel)

            device = trainer.model.get_input_embeddings().weight.device
            trainer.model.train()

            torch.random.manual_seed(42)
            embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

            torch.random.manual_seed(24)
            embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

            self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
            self.assertGreater(len(trainer.model.get_input_embeddings()._forward_hooks), 0)

            trainer.neftune_hook_handle.remove()

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Make sure the forward pass works fine, to check that the embeddings forward is not broken
            trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
            self.assertEqual(len(trainer.model.get_input_embeddings()._forward_hooks), 0)

    @require_peft
    def test_peft_tag(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                gradient_checkpointing=True,
                packing=True,
                report_to="none",
            )
            peft_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                task_type="CAUSAL_LM",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
                peft_config=peft_config,
            )

            for tag in ["sft", "trl"]:
                self.assertIn(tag, trainer.model.model_tags)

    def test_tag(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                gradient_checkpointing=True,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
            )

            for tag in ["sft", "trl"]:
                self.assertIn(tag, trainer.model.model_tags)

    def test_only_train_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                gradient_checkpointing=True,
                packing=True,
                max_length=128,  # make sure there is at least 1 packed sequence
                eval_packing=False,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
                eval_dataset=self.conversational_lm_dataset["test"],
            )

            # w/ this dataset and max_length=128, packing yields 7 training sequences
            self.assertEqual(len(trainer.train_dataset["input_ids"]), 7)
            # eval_packing=False, so the eval dataset is left unpacked
            self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))

    def test_eval_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=128,  # make sure there is at least 1 packed sequence
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
                eval_dataset=self.conversational_lm_dataset["test"],
            )

            # w/ this dataset and max_length=128, packing yields 7 training sequences
            self.assertEqual(len(trainer.train_dataset["input_ids"]), 7)
            # and a single packed eval sequence
            self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1)

    def test_no_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=128,
                packing=False,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
                eval_dataset=self.conversational_lm_dataset["test"],
            )

            # Without packing, both datasets keep their original number of examples
            self.assertEqual(len(trainer.train_dataset["input_ids"]), len(self.conversational_lm_dataset["train"]))
            self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))

    @require_vision
    def test_skip_prepare_dataset(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                remove_unused_columns=False,
                dataset_kwargs={"skip_prepare_dataset": True},
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.dummy_vsft_instruction_dataset,
            )
            self.assertEqual(trainer.train_dataset.features, self.dummy_vsft_instruction_dataset.features)

    def test_skip_prepare_dataset_with_no_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                remove_unused_columns=False,
                packing=False,
                dataset_kwargs={"skip_prepare_dataset": True},
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.dummy_dataset,
            )
            self.assertEqual(trainer.train_dataset.features, self.dummy_dataset.features)

    @require_vision
    def test_llava(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                remove_unused_columns=False,
                dataset_kwargs={"skip_prepare_dataset": True},
                report_to="none",
            )
            tiny_llava = LlavaForConditionalGeneration.from_pretrained(
                "trl-internal-testing/tiny-LlavaForConditionalGeneration"
            )
            processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlavaForConditionalGeneration")
            processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""

            def collate_fn(examples):
                # Get the texts and images, and apply the chat template
                texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
                images = [example["images"][0] for example in examples]

                # Tokenize the texts and process the images
                batch = processor(texts, images, return_tensors="pt", padding=True)

                # The labels are the input_ids, and we mask the padding tokens in the loss computation
                labels = batch["input_ids"].clone()
                labels[labels == processor.tokenizer.pad_token_id] = -100
                batch["labels"] = labels

                return batch

            trainer = SFTTrainer(
                model=tiny_llava,
                args=training_args,
                data_collator=collate_fn,
                train_dataset=self.dummy_vsft_instruction_dataset,
            )
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    def test_torch_dtype(self):
        # See https://github.com/huggingface/trl/issues/1751
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                model_init_kwargs={"torch_dtype": torch.float16},
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.train_dataset,
                formatting_func=formatting_prompts_func,
            )
            self.assertEqual(trainer.model.config.torch_dtype, torch.float16)


# This new tester aims to replace the first one at some point
class SFTTrainerTester2(unittest.TestCase):
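    """Leaner end-to-end `SFTTrainer` tests, one scenario per method."""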

    def test_train(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_model(self):
        # Instantiate the model
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_model_torch_dtype(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir, model_init_kwargs={"torch_dtype": torch.float16}, report_to="none"
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                # Check the torch dtype
                self.assertEqual(new_param.dtype, torch.float16)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_peft
    def test_train_peft_model(self):
        # Get the base model
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id)

        # Get the base model parameter names
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]

        # Turn the model into a peft model
        lora_config = LoraConfig()
        model = get_peft_model(model, lora_config)

        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the peft params have changed and the base model params have not
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model parameters to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
                elif "base_layer" not in n:  # We expect the peft parameters to be different (except the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_non_chatml_conversational_data(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

        # Rename role/content to from/value to ensure SFT works with non-chatML conversational data
        def rename_fields(example: dict):
            return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}

        dataset = dataset.map(rename_fields, remove_columns="messages")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_pretokenized_data(self):
        # Get the dataset and tokenize it
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        def tokenize_example(example):
            return tokenizer(example["text"])

        # Apply tokenization
        tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset)

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_iterable_dataset(self):
        # Get the dataset as a streaming (iterable) dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train", streaming=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_data_collator_for_completion_only_and_padding_free(self):
        # Get the dataset and a completion-only, padding-free collator
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        response_template = "<|im_start|>assistant\n"
        collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, padding_free=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset, data_collator=collator)

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_flash_attn
    def test_train_padding_free(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir,
                padding_free=True,
                model_init_kwargs={"attn_implementation": "flash_attention_2"},
                bf16=True,  # flash_attention_2 only supports bf16 and fp16
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    # The method takes `packing_strategy`, so the parameterization below restores the missing decorator;
    # the strategy names are assumed to be the ones supported by this TRL version
    @parameterized.expand([("bfd",), ("wrapped",)])
    def test_train_packing(self, packing_strategy):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none"
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )

            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            # Train the model
            trainer.train()

            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")