import warnings |
|
from dataclasses import dataclass, field |
|
from typing import Optional, Union |
|
|
|
import transformers |
|
from packaging import version |
|
from transformers import TrainingArguments |
|
|
|
|
|
@dataclass |
|
class GRPOConfig(TrainingArguments): |
|
r""" |
|
Configuration class for the [`GRPOTrainer`]. |
|
|
|
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the |
|
[`~transformers.TrainingArguments`] documentation. |
|
|
|
Using [`~transformers.HfArgumentParser`] we can turn this class into |
|
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the |
|
command line. |
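
    Example (a minimal sketch; the script name and flag values are placeholders):

    ```python
    # train_grpo.py
    from transformers import HfArgumentParser

    from trl import GRPOConfig

    parser = HfArgumentParser(GRPOConfig)
    (training_args,) = parser.parse_args_into_dataclasses()
    # invoked as, e.g.: python train_grpo.py --output_dir grpo-run --num_generations 4
    ```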
|
|
|
Parameters: |
|
> Parameters that control the model and reference model |
|
|
|
model_init_kwargs (`str`, `dict[str, Any]` or `None`, *optional*, defaults to `None`): |
|
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` |
|
argument of the [`GRPOTrainer`] is provided as a string. |
|
disable_dropout (`bool`, *optional*, defaults to `False`): |
|
Whether to disable dropout in the model. This is useful for training with a reference model, as it |
|
prevents the model from generating different logprobs for the same input. |
|
|
|
> Parameters that control the data preprocessing |
|
|
|
remove_unused_columns (`bool`, *optional*, defaults to `False`): |
|
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that |
|
            requires any column other than `"prompts"` and `"completions"`, you should keep this set to `False`.
|
max_prompt_length (`int` or `None`, *optional*, defaults to `512`): |
|
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. |
|
num_generations (`int` or `None`, *optional*, defaults to `8`): |
|
Number of generations per prompt to sample. The effective batch size (num_processes * |
|
per_device_batch_size * gradient_accumulation_steps) must be evenly divisible by this value. |
|
max_completion_length (`int` or `None`, *optional*, defaults to `256`): |
|
Maximum length of the generated completion. |
|
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): |
|
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, |
|
improving generation speed. However, disabling this option allows training models that exceed the VRAM |
|
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible |
|
with vLLM generation. |
|
shuffle_dataset (`bool`, *optional*, defaults to `True`): |
|
Whether to shuffle the training dataset. |
|
|
|
> Parameters that control generation |
|
|
|
        temperature (`float`, *optional*, defaults to `0.9`):
|
Temperature for sampling. The higher the temperature, the more random the completions. |
|
top_p (`float`, *optional*, defaults to `1.0`): |
|
Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to |
|
`1.0` to consider all tokens. |
|
top_k (`int` or `None`, *optional*, defaults to `50`): |
|
Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is |
|
disabled. |
|
min_p (`float` or `None`, *optional*, defaults to `None`): |
|
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a |
|
value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. |
|
repetition_penalty (`float`, *optional*, defaults to `1.0`): |
|
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. |
|
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat |
|
tokens. |
|
cache_implementation (`str` or `None`, *optional*, defaults to `None`): |
|
Implementation of the cache method for faster generation when use_vllm is set to False. |
|
|
|
> Parameters that control generation acceleration powered by vLLM |
|
|
|
use_vllm (`bool`, *optional*, defaults to `False`): |
|
            Whether to use vLLM for generating completions. If set to `True`, ensure that a vLLM server is running. To

            run the server, install vLLM (`pip install vllm`) and start it with the `trl vllm-serve` command.
|
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): |
|
Host of the vLLM server to connect to. |
|
vllm_server_port (`int`, *optional*, defaults to `8000`): |
|
Port of the vLLM server to connect to. |
|
        vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
|
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the |
|
timeout, a `ConnectionError` is raised. |
|
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): |
|
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. |
|
|
|
> Parameters that control the training |
|
|
|
learning_rate (`float`, *optional*, defaults to `1e-6`): |
|
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of |
|
[`~transformers.TrainingArguments`]. |
|
beta (`float`, *optional*, defaults to `0.04`): |
|
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training

            speed, but this may be numerically unstable for long training runs.
|
num_iterations (`int`, *optional*, defaults to `1`): |
|
            Number of iterations per batch (denoted as μ in the algorithm).
|
epsilon (`float`, *optional*, defaults to `0.2`): |
|
Epsilon value for clipping. |
|
epsilon_high (`float` or `None`, *optional*, defaults to `None`): |
|
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound |
|
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. |
|
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`): |
|
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are |
|
weighted equally with weight `1.0`. |
|
scale_rewards (`bool`, *optional*, defaults to `True`): |
|
Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), the rewards |
|
are normalized by the standard deviation, ensuring they have unit variance. If `False`, no scaling is |
|
applied. The [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) recommends not scaling the rewards, |
|
as scaling by the standard deviation introduces a question-level difficulty bias. |
|
loss_type (`str`, *optional*, defaults to `"bnpo"`): |
|
Specifies the loss formulation to use. Supported values are: |
|
|
|
- `"grpo"`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to |
|
              length bias: this approach tends to prefer shorter completions with positive advantages and longer ones
|
with negative advantages. |
|
            - `"bnpo"`: Aggregates token-level losses by normalizing by the number of active tokens in the local batch.
|
Note that normalization is performed over the local batch only, so results may slightly vary depending |
|
on the local batch size, despite a constant effective batch size. When using |
|
`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. |
|
- `"dr_grpo"`: Aggregates token-level losses by normalizing with a global constant. This method was |
|
introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) to eliminate length bias. |
|
              The value of the constant corresponds to `max_completion_length`. A short aggregation sketch of the

              three options is given after the parameter list below.
|
mask_truncated_completions (`bool`, *optional*, defaults to `False`): |
|
When enabled, truncated completions are excluded from the loss calculation, preventing them from being |
|
incorrectly penalized and introducing noise during training. According to the |
|
[DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability. |
|
sync_ref_model (`bool`, *optional*, defaults to `False`): |
|
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using |
|
            the `ref_model_mixup_alpha` parameter. This synchronization originates from the
|
[TR-DPO](https://huggingface.co/papers/2404.09656) paper. |
|
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): |
|
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
between the current policy and the previous reference policy during updates. The reference policy is |
|
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
|
must set `sync_ref_model=True`. |
|
ref_model_sync_steps (`int`, *optional*, defaults to `512`): |
|
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
|
frequently the current policy is synchronized with the reference policy. To use this parameter, you must |
|
set `sync_ref_model=True`. |
|
use_liger_loss (`bool`, *optional*, defaults to `False`): |
|
Whether to use the Liger GRPO loss. |
|
|
|
> Parameters that control the logging |
|
|
|
log_completions (`bool`, *optional*, defaults to `False`): |
|
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is |
|
installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`. |
|
num_completions_to_print (`int` or `None`, *optional*, defaults to `None`): |
|
Number of completions to print with `rich`. If `None`, all completions are logged. |
|
wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`): |
|
Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all |
|
prompts are logged. |
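
    As a rough sketch of how the three `loss_type` options aggregate per-token losses (the tensor names, shapes, and
    values below are placeholders for illustration, not the exact trainer internals):

    ```python
    import torch

    per_token_loss = torch.randn(4, 16)  # (batch_size, max_completion_length), placeholder values
    completion_mask = torch.ones(4, 16)  # 1 for completion tokens, 0 for padding
    max_completion_length = 16

    # "grpo": per-sequence mean over tokens, then mean over the batch (length-biased)
    grpo = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1)).mean()
    # "bnpo": mean over all active tokens in the local batch
    bnpo = (per_token_loss * completion_mask).sum() / completion_mask.sum()
    # "dr_grpo": sum divided by a global constant, batch_size * max_completion_length
    dr_grpo = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * max_completion_length)
    ```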
|
""" |
|
|
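    # Register `model_init_kwargs` as a dict-valued field so that `TrainingArguments` parses a JSON string passed on
    # the command line (e.g. `--model_init_kwargs '{"attn_implementation": "flash_attention_2"}'`) into a dict.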
|
if version.parse(transformers.__version__) <= version.parse("4.50.3"): |
|
from transformers.training_args import _VALID_DICT_FIELDS |
|
|
|
_VALID_DICT_FIELDS.append("model_init_kwargs") |
|
else: |
|
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] |
|
|
|
|
|
model_init_kwargs: Optional[Union[dict, str]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " |
|
"argument of the `GRPOTrainer` is provided as a string." |
|
}, |
|
) |
|
disable_dropout: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "Whether to disable dropout in the model. This is useful for training with a reference model, as " |
|
"it prevents the model from generating different logprobs for the same input." |
|
}, |
|
) |
|
|
|
|
|
|
|
|
|
remove_unused_columns: Optional[bool] = field( |
|
default=False, |
|
        metadata={

            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "

            "that requires any column other than 'prompts' and 'completions', you should keep this set to `False`."
|
}, |
|
) |
|
max_prompt_length: Optional[int] = field( |
|
default=512, |
|
metadata={ |
|
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left." |
|
}, |
|
) |
|
num_generations: Optional[int] = field( |
|
default=8, |
|
        metadata={

            "help": "Number of generations per prompt to sample. The effective batch size (num_processes * "

            "per_device_batch_size * gradient_accumulation_steps) must be evenly divisible by this value."
|
}, |
|
) |
|
max_completion_length: Optional[int] = field( |
|
default=256, |
|
metadata={"help": "Maximum length of the generated completion."}, |
|
) |
|
ds3_gather_for_generation: bool = field( |
|
default=True, |
|
metadata={ |
|
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " |
|
"generation, improving generation speed. However, disabling this option allows training models that " |
|
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " |
|
"is not compatible with vLLM generation." |
|
}, |
|
) |
|
shuffle_dataset: Optional[bool] = field( |
|
default=True, |
|
metadata={"help": "Whether to shuffle the training dataset."}, |
|
) |
|
|
|
|
|
temperature: float = field( |
|
default=0.9, |
|
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, |
|
) |
|
top_p: float = field( |
|
default=1.0, |
|
metadata={ |
|
"help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " |
|
"Set to 1.0 to consider all tokens." |
|
}, |
|
) |
|
top_k: Optional[int] = field( |
|
default=50, |
|
metadata={ |
|
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, " |
|
"top-k-filtering is disabled." |
|
}, |
|
) |
|
min_p: Optional[float] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " |
|
"must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." |
|
}, |
|
) |
|
repetition_penalty: float = field( |
|
default=1.0, |
|
metadata={ |
|
"help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " |
|
"text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " |
|
"to repeat tokens." |
|
}, |
|
) |
|
cache_implementation: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, |
|
) |
|
|
|
|
|
use_vllm: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a vLLM server is " |
|
"running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`." |
|
}, |
|
) |
|
vllm_server_host: str = field( |
|
default="0.0.0.0", |
|
metadata={"help": "Host of the vLLM server to connect to."}, |
|
) |
|
vllm_server_port: int = field( |
|
default=8000, |
|
metadata={"help": "Port of the vLLM server to connect to."}, |
|
) |
|
vllm_server_timeout: float = field( |
|
default=240.0, |
|
metadata={ |
|
"help": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " |
|
"after the timeout, a `ConnectionError` is raised." |
|
}, |
|
) |
|
vllm_guided_decoding_regex: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, |
|
) |
|
|
|
|
|
learning_rate: float = field( |
|
default=1e-6, |
|
metadata={ |
|
"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " |
|
"`transformers.TrainingArguments`." |
|
}, |
|
) |
|
beta: float = field( |
|
default=0.04, |
|
        metadata={

            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "

            "training speed, but this may be numerically unstable for long training runs."
|
}, |
|
) |
|
num_iterations: int = field( |
|
default=1, |
|
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
|
) |
|
epsilon: float = field( |
|
default=0.2, |
|
metadata={"help": "Epsilon value for clipping."}, |
|
) |
|
epsilon_high: Optional[float] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " |
|
"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`." |
|
}, |
|
) |
|
reward_weights: Optional[list[float]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all " |
|
"rewards are weighted equally with weight `1.0`." |
|
}, |
|
) |
|
scale_rewards: bool = field( |
|
default=True, |
|
metadata={ |
|
"help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), " |
|
"the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no " |
|
"scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard " |
|
"deviation introduces a question-level difficulty bias." |
|
}, |
|
) |
|
loss_type: str = field( |
|
default="bnpo", |
|
        metadata={

            "help": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`. "

            "`'grpo'`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to "

            "length bias: this approach tends to prefer shorter completions with positive advantages and longer ones "

            "with negative advantages. "

            "`'bnpo'`: Aggregates token-level losses by normalizing by the number of active tokens in the local batch. "
|
"Note that normalization is performed over the local batch only, so results may slightly vary depending " |
|
"on the local batch size, despite a constant effective batch size. When using " |
|
"`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. " |
|
"`'dr_grpo'`: Aggregates token-level losses by normalizing with a global constant. This method was " |
|
"introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to " |
|
"`max_completion_length`." |
|
}, |
|
) |
|
mask_truncated_completions: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from " |
|
"being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is " |
|
"a good practice for training stability." |
|
}, |
|
) |
|
sync_ref_model: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " |
|
"steps, using the `ref_model_mixup_alpha` parameter." |
|
}, |
|
) |
|
ref_model_mixup_alpha: float = field( |
|
default=0.6, |
|
        metadata={

            "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "

            "previous reference policy during updates. The reference policy is updated according to the equation: "

            "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
|
}, |
|
) |
|
ref_model_sync_steps: int = field( |
|
default=512, |
|
        metadata={

            "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
|
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." |
|
}, |
|
) |
|
use_liger_loss: bool = field( |
|
default=False, |
|
metadata={"help": "Whether to use the Liger GRPO loss."}, |
|
) |
|
|
|
|
|
log_completions: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is " |
|
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." |
|
}, |
|
) |
|
num_completions_to_print: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, |
|
) |
|
wandb_log_unique_prompts: Optional[bool] = field( |
|
default=False, |
|
metadata={ |
|
"help": "Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, " |
|
"all prompts are logged." |
|
}, |
|
) |
|
|
|
|
|
vllm_device: Optional[str] = field( |
|
default=None, |
|
metadata={ |
|
"help": "This parameter is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM " |
|
"server with the `trl vllm-serve` command." |
|
}, |
|
) |
|
vllm_gpu_memory_utilization: Optional[float] = field( |
|
default=None, |
|
metadata={ |
|
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the GPU memory " |
|
"utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server " |
|
"configuration." |
|
}, |
|
) |
|
vllm_dtype: Optional[str] = field( |
|
default=None, |
|
metadata={ |
|
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for " |
|
"vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration." |
|
}, |
|
) |
|
vllm_max_model_len: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the " |
|
"`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server " |
|
"configuration." |
|
}, |
|
) |
|
vllm_enable_prefix_caching: Optional[bool] = field( |
|
default=None, |
|
metadata={ |
|
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control prefix caching in " |
|
"vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server configuration." |
|
}, |
|
) |
|
|
|
def __post_init__(self): |
|
super().__post_init__() |
|
|
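        # The `vllm_*` parameters below are kept only for backward compatibility: generation now goes through a
        # separate vLLM server (started with `trl vllm-serve`), so `__post_init__` only emits a deprecation warning
        # when any of them is set.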
|
if self.vllm_device is not None: |
|
warnings.warn( |
|
"`vllm_device` is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM server " |
|
"with the `trl vllm-serve` command.", |
|
DeprecationWarning, |
|
) |
|
|
|
if self.vllm_gpu_memory_utilization is not None: |
|
            warnings.warn(

                "`vllm_gpu_memory_utilization` is deprecated and will be removed in version 0.18.0. To control the "

                "GPU memory utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the "

                "vLLM server configuration.",
|
DeprecationWarning, |
|
) |
|
|
|
if self.vllm_dtype is not None: |
|
warnings.warn( |
|
"`vllm_dtype` is deprecated and will be removed in version 0.18.0. To control the data type for vLLM " |
|
"generation, you should now use the `dtype` parameter in the vLLM server configuration.", |
|
DeprecationWarning, |
|
) |
|
|
|
if self.vllm_max_model_len is not None: |
|
warnings.warn( |
|
"`vllm_max_model_len` is deprecated and will be removed in version 0.18.0. To control the " |
|
"`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server " |
|
"configuration.", |
|
DeprecationWarning, |
|
) |
|
|
|
if self.vllm_enable_prefix_caching is not None: |
|
warnings.warn( |
|
"`vllm_enable_prefix_caching` is deprecated and will be removed in version 0.18.0. To control prefix " |
|
"caching in vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server " |
|
"configuration.", |
|
DeprecationWarning, |
|
) |
|
|