from dataclasses import dataclass, field
from typing import Any, Optional

from .sft_config import SFTConfig


@dataclass
class GKDConfig(SFTConfig):
""" |
|
Configuration class for [`GKDTrainer`]. |
|
|
|
Args: |
|
temperature (`float`, *optional*, defaults to `0.9`): |
|
Temperature for sampling. The higher the temperature, the more random the completions. |
|
lmbda (`float`, *optional*, defaults to `0.5`): |
|
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy |
|
student-generated outputs). |
|
beta (`float`, *optional*, defaults to `0.5`): |
|
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When |
|
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence. |
|
max_new_tokens (`int`, *optional*, defaults to `128`): |
|
Maximum number of tokens to generate per completion. |
|
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`): |
|
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model |
|
being trained. |
|
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`): |
|
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model |
|
from a string. |
|
disable_dropout (`bool`, *optional*, defaults to `True`): |
|
Whether to disable dropout in the model. |
|
seq_kd (`bool`, *optional*, defaults to `False`): |
|
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT |
|
on teacher-generated output). |
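
    Example (a minimal sketch; `output_dir` and the overridden values are illustrative):

    ```python
    from trl import GKDConfig

    training_args = GKDConfig(output_dir="gkd-model", lmbda=0.5, beta=0.5, max_new_tokens=128)
    ```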
|
""" |

    temperature: float = field(
        default=0.9,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    lmbda: float = field(
        default=0.5,
        metadata={
            "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy "
            "student-generated outputs)."
        },
    )
    beta: float = field(
        default=0.5,
        metadata={
            "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence "
            "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL "
            "Divergence."
        },
    )
    max_new_tokens: int = field(
        default=128,
        metadata={"help": "Maximum number of tokens to generate per completion."},
    )
    teacher_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
            "model being trained."
        },
    )
    teacher_model_init_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={
            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
            "teacher model from a string."
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "Whether to disable dropout in the model."},
    )
    seq_kd: bool = field(
        default=False,
        metadata={
            "help": "Whether to perform Sequence-Level KD (which can be viewed as supervised fine-tuning on "
            "teacher-generated output)."
        },
    )

    def __post_init__(self):
        super().__post_init__()

        if self.lmbda < 0.0 or self.lmbda > 1.0:
            raise ValueError("lmbda must be in the range [0.0, 1.0].")
        if self.beta < 0.0 or self.beta > 1.0:
            raise ValueError("beta must be in the range [0.0, 1.0].")
|
|
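# A minimal sketch (illustration only, not the trainer's implementation) of the generalized
# Jensen-Shannon divergence that `beta` interpolates. For 0 < beta < 1, with the mixture
# M = beta * P_teacher + (1 - beta) * P_student, the loss is
# beta * KL(P_teacher || M) + (1 - beta) * KL(P_student || M); the endpoints are
# special-cased to the forward KL (beta=0.0) and the inverse KL (beta=1.0), matching
# the docstring above. Inputs are assumed to be log-probabilities over the same vocabulary.
def _generalized_jsd_sketch(student_log_probs, teacher_log_probs, beta):
    import torch
    import torch.nn.functional as F

    if beta == 0.0:
        # KL(teacher || student): `F.kl_div` computes KL(target || input) for log inputs.
        return F.kl_div(student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True)
    if beta == 1.0:
        # Inverse KL(student || teacher).
        return F.kl_div(teacher_log_probs, student_log_probs, reduction="batchmean", log_target=True)
    beta_t = torch.tensor(beta, dtype=student_log_probs.dtype)
    # log(beta * p_teacher + (1 - beta) * p_student), computed stably in log space.
    mixture_log_probs = torch.logsumexp(
        torch.stack([teacher_log_probs + torch.log(beta_t), student_log_probs + torch.log(1 - beta_t)]),
        dim=0,
    )
    kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="batchmean", log_target=True)
    kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="batchmean", log_target=True)
    return beta_t * kl_teacher + (1 - beta_t) * kl_student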