import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import is_bitsandbytes_available

from ..core import flatten_dict


@dataclass
class DDPOConfig:
r""" |
|
Configuration class for the [`DDPOTrainer`]. |
|
|
|
Using [`~transformers.HfArgumentParser`] we can turn this class into |
|
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the |
|
command line. |
|
|
|
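
    Example (a minimal usage sketch; the script name `train_ddpo.py` is illustrative):

    ```python
    # train_ddpo.py -- turn every DDPOConfig field into a command-line flag.
    from transformers import HfArgumentParser

    from trl import DDPOConfig

    parser = HfArgumentParser(DDPOConfig)
    ddpo_config = parser.parse_args_into_dataclasses()[0]
    # run e.g. `python train_ddpo.py --num_epochs 200 --train_learning_rate 1e-4`
    ```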

    Parameters:
        exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
            Name of this experiment (defaults to the file name without the extension).
        run_name (`str`, *optional*, defaults to `""`):
            Name of this run.
        seed (`int`, *optional*, defaults to `0`):
            Random seed.
        log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
            Log with either 'wandb' or 'tensorboard'; see
            https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
        tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
            Keyword arguments for the tracker (e.g. `wandb_project`).
        accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
            Keyword arguments for the accelerator.
        project_kwargs (`Dict`, *optional*, defaults to `{}`):
            Keyword arguments for the accelerator project config (e.g. `logging_dir`).
        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
            Name of project to use for tracking.
        logdir (`str`, *optional*, defaults to `"logs"`):
            Top-level logging directory for checkpoint saving.
        num_epochs (`int`, *optional*, defaults to `100`):
            Number of epochs to train.
        save_freq (`int`, *optional*, defaults to `1`):
            Number of epochs between saving model checkpoints.
        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
            Number of checkpoints to keep before overwriting old ones.
        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision training.
        allow_tf32 (`bool`, *optional*, defaults to `True`):
            Allow `tf32` on Ampere GPUs.
        resume_from (`str`, *optional*, defaults to `""`):
            Resume training from a checkpoint.
        sample_num_steps (`int`, *optional*, defaults to `50`):
            Number of sampler inference steps.
        sample_eta (`float`, *optional*, defaults to `1.0`):
            Eta parameter for the DDIM sampler.
        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
            Classifier-free guidance weight.
        sample_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for sampling.
        sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
            Number of batches to sample per epoch.
        train_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for training.
        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
            Use 8bit Adam optimizer from bitsandbytes.
        train_learning_rate (`float`, *optional*, defaults to `3e-4`):
            Learning rate.
        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
            Adam beta1.
        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
            Adam beta2.
        train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
            Adam weight decay.
        train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
            Adam epsilon.
        train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
            Number of gradient accumulation steps.
        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
            Maximum gradient norm for gradient clipping.
        train_num_inner_epochs (`int`, *optional*, defaults to `1`):
            Number of inner epochs per outer epoch.
        train_cfg (`bool`, *optional*, defaults to `True`):
            Whether to use classifier-free guidance during training.
        train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
            Maximum absolute value for clipping advantages.
        train_clip_range (`float`, *optional*, defaults to `1e-4`):
            PPO clip range.
        train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
            Fraction of timesteps to train on.
        per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
            Whether to track statistics for each prompt separately.
        per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
            Number of reward values to store in the buffer for each prompt.
        per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
            Minimum number of reward values to store in the buffer before per-prompt statistics are used.
        async_reward_computation (`bool`, *optional*, defaults to `False`):
            Whether to compute rewards asynchronously.
        max_workers (`int`, *optional*, defaults to `2`):
            Maximum number of workers to use for async reward computation.
        negative_prompts (`str`, *optional*, defaults to `""`):
            Comma-separated list of prompts to use as negative examples.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the final model checkpoint to the Hub.
    """

    exp_name: str = field(
        default=os.path.basename(sys.argv[0])[: -len(".py")],
        metadata={"help": "Name of this experiment (defaults to the file name without the extension)."},
    )
    run_name: str = field(
        default="",
        metadata={"help": "Name of this run."},
    )
    seed: int = field(
        default=0,
        metadata={"help": "Random seed."},
    )
    log_with: Optional[str] = field(
        default=None,
        metadata={
            "help": "Log with either 'wandb' or 'tensorboard'.",
            "choices": ["wandb", "tensorboard"],
        },
    )
    tracker_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Keyword arguments for the tracker (e.g. wandb_project)."},
    )
    accelerator_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Keyword arguments for the accelerator."},
    )
    project_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Keyword arguments for the accelerator project config (e.g. `logging_dir`)."},
    )
    tracker_project_name: str = field(
        default="trl",
        metadata={"help": "Name of project to use for tracking."},
    )
    logdir: str = field(
        default="logs",
        metadata={"help": "Top-level logging directory for checkpoint saving."},
    )
    num_epochs: int = field(
        default=100,
        metadata={"help": "Number of epochs to train."},
    )
    save_freq: int = field(
        default=1,
        metadata={"help": "Number of epochs between saving model checkpoints."},
    )
    num_checkpoint_limit: int = field(
        default=5,
        metadata={"help": "Number of checkpoints to keep before overwriting old ones."},
    )
    mixed_precision: str = field(
        default="fp16",
        metadata={"help": "Mixed precision training."},
    )
    allow_tf32: bool = field(
        default=True,
        metadata={"help": "Allow `tf32` on Ampere GPUs."},
    )
    resume_from: str = field(
        default="",
        metadata={"help": "Resume training from a checkpoint."},
    )
    sample_num_steps: int = field(
        default=50,
        metadata={"help": "Number of sampler inference steps."},
    )
    sample_eta: float = field(
        default=1.0,
        metadata={"help": "Eta parameter for the DDIM sampler."},
    )
    sample_guidance_scale: float = field(
        default=5.0,
        metadata={"help": "Classifier-free guidance weight."},
    )
    sample_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size (per GPU) to use for sampling."},
    )
    sample_num_batches_per_epoch: int = field(
        default=2,
        metadata={"help": "Number of batches to sample per epoch."},
    )
    train_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size (per GPU) to use for training."},
    )
    train_use_8bit_adam: bool = field(
        default=False,
        metadata={"help": "Use 8bit Adam optimizer from bitsandbytes."},
    )
    train_learning_rate: float = field(
        default=3e-4,
        metadata={"help": "Learning rate."},
    )
    train_adam_beta1: float = field(
        default=0.9,
        metadata={"help": "Adam beta1."},
    )
    train_adam_beta2: float = field(
        default=0.999,
        metadata={"help": "Adam beta2."},
    )
    train_adam_weight_decay: float = field(
        default=1e-4,
        metadata={"help": "Adam weight decay."},
    )
    train_adam_epsilon: float = field(
        default=1e-8,
        metadata={"help": "Adam epsilon."},
    )
    train_gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of gradient accumulation steps."},
    )
    train_max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Maximum gradient norm for gradient clipping."},
    )
    train_num_inner_epochs: int = field(
        default=1,
        metadata={"help": "Number of inner epochs per outer epoch."},
    )
    train_cfg: bool = field(
        default=True,
        metadata={"help": "Whether to use classifier-free guidance during training."},
    )
    train_adv_clip_max: float = field(
        default=5.0,
        metadata={"help": "Maximum absolute value for clipping advantages."},
    )
    train_clip_range: float = field(
        default=1e-4,
        metadata={"help": "PPO clip range."},
    )
    train_timestep_fraction: float = field(
        default=1.0,
        metadata={"help": "Fraction of timesteps to train on."},
    )
    per_prompt_stat_tracking: bool = field(
        default=False,
        metadata={"help": "Whether to track statistics for each prompt separately."},
    )
    per_prompt_stat_tracking_buffer_size: int = field(
        default=16,
        metadata={"help": "Number of reward values to store in the buffer for each prompt."},
    )
    per_prompt_stat_tracking_min_count: int = field(
        default=16,
        metadata={"help": "Minimum number of reward values to store in the buffer before per-prompt statistics are used."},
    )
    async_reward_computation: bool = field(
        default=False,
        metadata={"help": "Whether to compute rewards asynchronously."},
    )
    max_workers: int = field(
        default=2,
        metadata={"help": "Maximum number of workers to use for async reward computation."},
    )
    negative_prompts: str = field(
        default="",
        metadata={"help": "Comma-separated list of prompts to use as negative examples."},
    )
    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether to push the final model checkpoint to the Hub."},
    )

    def to_dict(self):
        output_dict = {}
        for key, value in self.__dict__.items():
            output_dict[key] = value
        return flatten_dict(output_dict)

    def __post_init__(self):
        if self.train_use_8bit_adam and not is_bitsandbytes_available():
            raise ImportError(
                "You need to install bitsandbytes to use 8bit Adam. "
                "You can install it with `pip install bitsandbytes`."
            )
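
# Illustrative sketch (comments only, not executed as part of this module): constructing
# the config directly in Python and flattening it with `to_dict()` for an experiment
# tracker. The specific field overrides shown are arbitrary examples.
#
#     from trl import DDPOConfig
#
#     config = DDPOConfig(num_epochs=200, train_learning_rate=1e-4, log_with="wandb")
#     tracker_config = config.to_dict()  # flattened dict, e.g. passed to the tracker as run config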