|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib |
|
import inspect |
|
import logging |
|
import os |
|
import subprocess |
|
import sys |
|
from collections.abc import Iterable |
|
from dataclasses import dataclass, field |
|
from typing import Optional, Union |
|
|
|
import yaml |
|
from transformers import HfArgumentParser |
|
from transformers.hf_argparser import DataClass, DataClassType |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass
class ScriptArguments:
    """
    Dataset and training options shared by every script.

    Args:
        dataset_name (`str`):
            Name of the dataset. This field has no default, so it must always be supplied.
        dataset_config (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset configuration. Corresponds to the `name` argument of the [`~datasets.load_dataset`]
            function.
        dataset_train_split (`str`, *optional*, defaults to `"train"`):
            Name of the split used for training.
        dataset_test_split (`str`, *optional*, defaults to `"test"`):
            Name of the split used for evaluation.
        gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`):
            Whether gradient checkpointing should be applied with `use_reentrant`.
        ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
            Debugging switch for distributed training, working around DDP issues with LM bias/mask buffers (invalid
            scalar type, inplace operation). See
            https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
    """

    # Required: no default value.
    dataset_name: str = field(
        metadata={"help": "Dataset name."},
    )

    dataset_config: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Dataset configuration name. Corresponds to the `name` argument of the "
                "`datasets.load_dataset` function."
            )
        },
    )

    dataset_train_split: str = field(
        default="train",
        metadata={"help": "Dataset split to use for training."},
    )

    dataset_test_split: str = field(
        default="test",
        metadata={"help": "Dataset split to use for evaluation."},
    )

    gradient_checkpointing_use_reentrant: bool = field(
        default=False,
        metadata={"help": "Whether to apply `use_reentrant` for gradient checkpointing."},
    )

    ignore_bias_buffers: bool = field(
        default=False,
        metadata={
            "help": (
                "Debug argument for distributed training. "
                "Fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. "
                "See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992."
            )
        },
    )
|
|
|
|
|
def init_zero_verbose():
    """
    Perform zero verbose init - use this method on top of the CLI modules to make them quiet.

    The root logger is configured with `logging.basicConfig` to only emit `ERROR`-level records, using a
    [`rich.logging.RichHandler`] with a bare `%(message)s` format. Python warnings are redirected to
    `logging.warning`, so they are filtered out at that level as well.

    Note that `logging.basicConfig` does nothing if the root logger already has handlers; in that case only the
    warning redirection takes effect.
    """
    import warnings

    from rich.logging import RichHandler

    FORMAT = "%(message)s"
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.ERROR)

    def warning_handler(message, category, filename, lineno, file=None, line=None):
        # Replacement for `warnings.showwarning`: turn each warning into a WARNING-level log record. Lazy %-args
        # avoid building the message when the record is filtered out by the ERROR level configured above.
        logging.warning("%s:%s: %s: %s", filename, lineno, category.__name__, message)

    warnings.showwarning = warning_handler
|
|
|
|
|
class TrlParser(HfArgumentParser):
    """
    A subclass of [`transformers.HfArgumentParser`] designed for parsing command-line arguments with dataclass-backed
    configurations, while also supporting configuration file loading and environment variable management.

    Args:
        dataclass_types (`Union[DataClassType, Iterable[DataClassType]]` or `None`, *optional*, defaults to `None`):
            Dataclass types to use for argument parsing. Either a single dataclass type or any iterable of them
            (including a one-shot generator).
        **kwargs:
            Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor.

    Examples:

    ```yaml
    # config.yaml
    env:
        VAR1: value1
    arg1: 23
    ```

    ```python
    # main.py
    import os
    from dataclasses import dataclass
    from trl import TrlParser


    @dataclass
    class MyArguments:
        arg1: int
        arg2: str = "alpha"


    parser = TrlParser(dataclass_types=[MyArguments])
    training_args = parser.parse_args_and_config()

    print(training_args, os.environ.get("VAR1"))
    ```

    ```bash
    $ python main.py --config config.yaml
    (MyArguments(arg1=23, arg2='alpha'),) value1

    $ python main.py --arg1 5 --arg2 beta
    (MyArguments(arg1=5, arg2='beta'),) None
    ```
    """

    def __init__(
        self,
        dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None,
        **kwargs,
    ):
        # Normalize to a concrete list. A single type is wrapped. Any other iterable is materialized, because it is
        # iterated twice: once for the validation below and once more by the parent constructor. A generator would
        # otherwise be exhausted before reaching the parent.
        if dataclass_types is None:
            dataclass_types = []
        elif not isinstance(dataclass_types, Iterable):
            dataclass_types = [dataclass_types]
        else:
            dataclass_types = list(dataclass_types)

        # `--config` is consumed by `parse_args_and_config` itself, so no dataclass may declare a `config` field.
        for dataclass_type in dataclass_types:
            if "config" in dataclass_type.__dataclass_fields__:
                raise ValueError(
                    f"Dataclass {dataclass_type.__name__} has a field named 'config'. This field is reserved for the "
                    f"config file path and should not be used in the dataclass."
                )

        super().__init__(dataclass_types=dataclass_types, **kwargs)

    def parse_args_and_config(
        self,
        args: Optional[Iterable[str]] = None,
        return_remaining_strings: bool = False,
        fail_with_unknown_args: bool = True,
    ) -> tuple[DataClass, ...]:
        """
        Parse command-line args and config file into instances of the specified dataclass types.

        This method wraps [`transformers.HfArgumentParser.parse_args_into_dataclasses`]. It also parses the config
        file specified with the `--config PATH` (or `--config=PATH`) flag. The config file (in YAML format) provides
        argument values that replace the default values in the dataclasses. Command line arguments can override values
        set by the config file. The method also sets any environment variables specified in the `env` field of the
        config file.

        Args:
            args (`Iterable[str]` or `None`, *optional*, defaults to `None`):
                Arguments to parse. If `None`, `sys.argv[1:]` is used.
            return_remaining_strings (`bool`, *optional*, defaults to `False`):
                If `True`, the last element of the returned tuple is the list of strings that were not consumed:
                unknown config-file keys (rendered as `--key value` pairs) followed by the unknown command-line
                strings.
            fail_with_unknown_args (`bool`, *optional*, defaults to `True`):
                Only used when `return_remaining_strings=False`. If `True`, raise when the config file contains keys
                that match no dataclass field, instead of silently ignoring them.

        Returns:
            `tuple`: One instance per dataclass type, followed by the list of remaining strings when
            `return_remaining_strings=True`.

        Raises:
            ValueError: If `--config` is given without a path, if the config file's top level is neither empty nor a
                mapping, if its `env` field is not a mapping, or if it contains unknown keys while
                `fail_with_unknown_args=True` and `return_remaining_strings=False`.
        """
        args = list(args) if args is not None else sys.argv[1:]
        config_path = self._pop_config_path(args)

        if config_path is not None:
            config = self._load_config(config_path)

            # Export the optional `env` section to the process environment.
            env_vars = config.pop("env", {})
            if not isinstance(env_vars, dict):
                raise ValueError("`env` field should be a dict in the YAML file.")
            for key, value in env_vars.items():
                os.environ[key] = str(value)

            # Config values become the new parser defaults; keys matching no argument are handed back.
            config_remaining_strings = self.set_defaults_with_config(**config)
        else:
            config_remaining_strings = []

        # Command-line values are parsed on top of the (possibly updated) defaults, so they take precedence.
        output = self.parse_args_into_dataclasses(args=args, return_remaining_strings=return_remaining_strings)

        if return_remaining_strings:
            # Merge leftovers from the config file with leftovers from the command line.
            args_remaining_strings = output[-1]
            return output[:-1] + (config_remaining_strings + args_remaining_strings,)
        if fail_with_unknown_args and config_remaining_strings:
            raise ValueError(
                f"Unknown arguments from config file: {config_remaining_strings}. Please remove them, add them to the "
                "dataclass, or set `fail_with_unknown_args=False`."
            )
        return output

    @staticmethod
    def _pop_config_path(args: list[str]) -> Optional[str]:
        """Remove the first `--config PATH` / `--config=PATH` option from `args` (in place) and return `PATH`, or `None`."""
        for index, arg in enumerate(args):
            if arg == "--config":
                if index + 1 >= len(args):
                    raise ValueError("`--config` requires a path to a YAML config file.")
                config_path = args[index + 1]
                del args[index : index + 2]
                return config_path
            if arg.startswith("--config="):
                del args[index]
                return arg[len("--config=") :]
        return None

    @staticmethod
    def _load_config(config_path: str) -> dict:
        """Load the YAML config file at `config_path` as a dict (an empty file yields `{}`)."""
        with open(config_path, encoding="utf-8") as yaml_file:
            config = yaml.safe_load(yaml_file)
        # `yaml.safe_load` returns `None` for an empty document; treat that as "no overrides".
        if config is None:
            return {}
        if not isinstance(config, dict):
            raise ValueError(
                f"The config file {config_path} must contain a YAML mapping at its top level, got "
                f"{type(config).__name__}."
            )
        return config

    def set_defaults_with_config(self, **kwargs) -> list[str]:
        """
        Overrides the parser's default values with those provided via keyword arguments.

        Any argument with an updated default will also be marked as not required
        if it was previously required.

        Returns:
            `list[str]`: The keyword arguments that matched no parser argument, flattened into
            `["--key1", "value1", "--key2", "value2", ...]` with each value converted via `str`.
        """
        for action in self._actions:
            if action.dest in kwargs:
                # `pop` so that only unmatched keys remain in `kwargs` afterwards.
                action.default = kwargs.pop(action.dest)
                action.required = False
        remaining_strings = [item for key, value in kwargs.items() for item in [f"--{key}", str(value)]]
        return remaining_strings
|
|
|
|
|
def get_git_commit_hash(package_name):
    """
    Return the git commit hash of the source checkout that `package_name` is imported from.

    The repository root is taken to be the parent of the package directory (e.g. `<repo>/trl/__init__.py` →
    `<repo>`). `git rev-parse HEAD` is only run when that root contains a `.git` entry.

    Args:
        package_name (`str`):
            Importable name of the package.

    Returns:
        `str` or `None`:
        - The full hash of `HEAD` when the repository root contains `.git`.
        - `None` when it does not.
        - A string of the form `"Error: <message>"` if any step fails (the import fails, `git` is missing, or `git`
          exits with an error).
    """
    try:
        package = importlib.import_module(package_name)
        package_path = os.path.dirname(inspect.getfile(package))

        # The repository root is assumed to be one level above the package directory.
        git_repo_path = os.path.abspath(os.path.join(package_path, ".."))

        # `.git` is a directory in a regular clone but a *file* (containing `gitdir: ...`) in git worktrees and
        # submodules, so test for existence rather than `isdir`.
        if not os.path.exists(os.path.join(git_repo_path, ".git")):
            return None

        # Silence git's own stderr; a failure still surfaces as `CalledProcessError` below.
        output = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=git_repo_path, stderr=subprocess.DEVNULL)
        return output.strip().decode("utf-8")
    except Exception as e:
        # Kept best-effort for backward compatibility: failures are reported in-band instead of raised.
        return f"Error: {str(e)}"
|
|