Quickstart

TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO).

Quick Examples

Get started instantly with TRL's most popular trainers. Each example uses compact models for quick experimentation.

Supervised Fine-Tuning

from trl import SFTTrainer
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()

Group Relative Policy Optimization

from trl import GRPOTrainer
from datasets import load_dataset

# Define a simple reward function (counts unique characters as a toy example)
def reward_function(completions, **kwargs):
    return [len(set(completion.lower())) for completion in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # Start from SFT model
    train_dataset=load_dataset("trl-lib/tldr", split="train"),
    reward_funcs=reward_function,
)
trainer.train()

Direct Preference Optimization

from trl import DPOTrainer
from datasets import load_dataset

trainer = DPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # Use your SFT model
    ref_model="Qwen/Qwen2.5-0.5B-Instruct",  # Original base model
    train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
)
trainer.train()
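
If you omit ref_model, recent TRL versions have DPOTrainer build a frozen reference copy of the policy model for you. A minimal sketch, assuming that default behavior:

from trl import DPOTrainer
from datasets import load_dataset

# ref_model omitted: DPOTrainer creates a frozen reference copy of `model`
trainer = DPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
)
trainer.train()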

Command Line Interface

Skip the code entirely and train directly from your terminal:

# SFT: Fine-tune on instructions
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara

# DPO: Align with preferences  
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name trl-lib/ultrafeedback_binarized

What's Next?

📚 Learn More

🚀 Scale Up

💡 Examples

Troubleshooting

Out of Memory?

Reduce batch size and enable optimizations:

from trl import SFTConfig

training_args = SFTConfig(
    per_device_train_batch_size=1,  # Start small
    gradient_accumulation_steps=8,  # Maintain effective batch size
    gradient_checkpointing=True,  # Recompute activations to save memory
)
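
The same settings apply to DPO; a minimal sketch using DPOConfig, which inherits these fields from transformers.TrainingArguments:

from trl import DPOConfig

training_args = DPOConfig(
    per_device_train_batch_size=1,  # Start small
    gradient_accumulation_steps=8,  # Maintain effective batch size
    gradient_checkpointing=True,  # Recompute activations to save memory
)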

Loss not decreasing?

Try adjusting the learning rate:

training_args = SFTConfig(learning_rate=2e-5)  # Good starting point
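
Pass the config to the trainer through its args parameter; a minimal sketch reusing the SFT example from above:

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

training_args = SFTConfig(learning_rate=2e-5)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,  # Wire the config into the trainer
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()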

For more help, see our Training FAQ or open an issue on GitHub.
