# Iterative Trainer

Iterative fine-tuning is a training method that enables you to perform custom actions (such as generation and filtering) between optimization steps. In TRL, we provide an easy-to-use API to fine-tune your models iteratively in just a few lines of code.

## Quickstart

To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:

```python
from trl import IterativeSFTConfig, IterativeSFTTrainer

# Using a model identifier
trainer = IterativeSFTTrainer(
    "facebook/opt-350m",
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
)

# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

trainer = IterativeSFTTrainer(
    model,
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
    processing_class=tokenizer,
)
```
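
Unlike a standard `transformers` `Trainer` run, there is no single call over a fixed dataset here: you typically drive training yourself by calling `trainer.step(...)` with fresh data between your own generation and filtering steps, as described below.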

## Usage

The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function:

Using a list of tensors as input:

```python
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

trainer.step(**inputs)
```
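
Here, `input_ids` and `attention_mask` are expected to be lists of tensors, one entry per example. They can come straight from a tokenizer, for instance (a minimal sketch; the example strings are placeholders):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Tokenize a batch, then split the padded 2-D tensors into per-example 1-D tensors
encodings = tokenizer(["Hello, world!", "Iterative fine-tuning"], padding=True, return_tensors="pt")
input_ids = list(encodings["input_ids"])
attention_mask = list(encodings["attention_mask"])
```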

Using a list of strings as input:

```python
inputs = {
    "texts": texts,
    "texts_labels": texts_labels,  # optional; defaults to texts
}

trainer.step(**inputs)
```

For causal language models, labels are created automatically from `input_ids` or from `texts`. When using sequence-to-sequence models, you must provide your own `labels` or `texts_labels`.
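
Putting this together, a typical iterative loop alternates generation, a custom filtering step, and an optimization step. Below is a minimal sketch reusing the `trainer` and `tokenizer` from the quickstart; `prompts`, `keep_best`, and `num_iterations` are hypothetical names you would define yourself:

```python
for _ in range(num_iterations):
    # Generate completions with the current model
    batch = tokenizer(prompts, padding=True, return_tensors="pt").to(trainer.model.device)
    generations = trainer.model.generate(**batch, max_new_tokens=64)
    texts = tokenizer.batch_decode(generations, skip_special_tokens=True)

    # Custom action between optimization steps, e.g. reward-based filtering
    selected = keep_best(texts)

    # One optimization step on the filtered data
    trainer.step(texts=selected)
```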

## Configuration

The [`IterativeSFTConfig`] class provides several parameters to customize the training:

```python
from trl import IterativeSFTConfig

config = IterativeSFTConfig(
    # Model initialization parameters
    model_init_kwargs={"torch_dtype": "bfloat16"},

    # Data preprocessing parameters
    max_length=512,
    truncation_mode="keep_end",

    # Training parameters
    output_dir="./output",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=1000,
    logging_steps=10,
    save_steps=100,
    optim="adamw_torch",
    report_to="wandb",
)
```

### Model Initialization

You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`:

```python
config = IterativeSFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
        "device_map": "auto",
        "trust_remote_code": True,
    }
)
```
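
These kwargs take effect when you pass a model identifier rather than a pre-instantiated model; they are forwarded to `from_pretrained`. The config above is roughly equivalent to instantiating the model yourself (a sketch, not the trainer's exact internals):

```python
import torch
from transformers import AutoModelForCausalLM

# Roughly what the trainer does with model_init_kwargs when given a model id
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
```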

### Data Preprocessing

The trainer supports two truncation modes:

- `keep_end`: truncates from the start of the sequence
- `keep_start`: truncates from the end of the sequence

```python
config = IterativeSFTConfig(
    max_length=512,
    truncation_mode="keep_end",  # or "keep_start"
)
```
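
To illustrate the difference on a toy sequence (plain Python for illustration, not the trainer's internal code):

```python
tokens = list(range(10))  # stand-in for a tokenized sequence of length 10
max_length = 6

tokens[-max_length:]  # keep_end:   [4, 5, 6, 7, 8, 9], the start is dropped
tokens[:max_length]   # keep_start: [0, 1, 2, 3, 4, 5], the end is dropped
```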

### Training Optimization

You can optimize CUDA cache usage for more memory-efficient training:

```python
config = IterativeSFTConfig(
    optimize_device_cache=True,
)
```

## IterativeSFTTrainer

[[autodoc]] IterativeSFTTrainer

## IterativeSFTConfig

[[autodoc]] IterativeSFTConfig