# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.resources as resources
import os
import sys
import warnings

from accelerate.commands.launch import launch_command, launch_command_parser

from .scripts.dpo import make_parser as make_dpo_parser
from .scripts.env import print_env
from .scripts.grpo import make_parser as make_grpo_parser
from .scripts.kto import make_parser as make_kto_parser
from .scripts.sft import make_parser as make_sft_parser
from .scripts.utils import TrlParser
from .scripts.vllm_serve import main as vllm_serve_main
from .scripts.vllm_serve import make_parser as make_vllm_serve_parser
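
# Sketch of typical invocations routed through this entry point (the exact flags each training script accepts
# depend on the installed TRL version; the values in angle brackets are placeholders):
#   trl sft --model_name_or_path <model> --dataset_name <dataset>
#   trl dpo --model_name_or_path <model> --dataset_name <dataset>
#   trl env
#   trl vllm-serve --model <model>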


def main():
    parser = TrlParser(prog="TRL CLI", usage="trl", allow_abbrev=False)

    # Add the subparsers
    subparsers = parser.add_subparsers(help="available commands", dest="command", parser_class=TrlParser)

    # Add the subparsers for every script
    make_dpo_parser(subparsers)
    subparsers.add_parser("env", help="Print the environment information")
    make_grpo_parser(subparsers)
    make_kto_parser(subparsers)
    make_sft_parser(subparsers)
    make_vllm_serve_parser(subparsers)

    # Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser.
    # Duplicates may occur if the same argument is provided in both the config file and CLI.
    # For example: launch_args = `["--num_processes", "4", "--num_processes", "8"]`.
    # Deduplication and precedence (CLI over config) are handled later by launch_command_parser.
    args, launch_args = parser.parse_args_and_config(return_remaining_strings=True)

    # Replace `--accelerate_config foo` with `--config_file trl/accelerate_configs/foo.yaml` if it is present in
    # `launch_args`. This lets users refer to the predefined accelerate configs shipped with the `trl` package by
    # name.
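    # For example, `trl sft --accelerate_config zero2 ...` is rewritten (assuming the installed `trl` package ships
    # a `zero2.yaml` in `trl/accelerate_configs`) into the equivalent of
    # `accelerate launch --config_file <site-packages>/trl/accelerate_configs/zero2.yaml ...`.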
    if "--accelerate_config" in launch_args:
        # Get the index of the '--accelerate_config' argument and the corresponding config name
        config_index = launch_args.index("--accelerate_config")
        config_name = launch_args[config_index + 1]

        # If config_name corresponds to a path in the filesystem, we don't want to override it
        if os.path.isfile(config_name):
            accelerate_config_path = config_name
        elif resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml").exists():
            # Get the predefined accelerate config path from the package resources
            accelerate_config_path = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml")
        else:
            raise ValueError(
                f"Accelerate config {config_name} is neither a file nor a valid config in the `trl` package. "
                "Please provide a valid config name or a path to a config file."
            )

        # Remove '--accelerate_config' and its config name: after the first pop, the name shifts into
        # `config_index`, so popping the same index twice removes both tokens.
        launch_args.pop(config_index)
        launch_args.pop(config_index)

        # Insert '--config_file' and the resolved config path at the front of the list
        launch_args = ["--config_file", str(accelerate_config_path)] + launch_args

    if args.command == "dpo":
        # Get the default args for the launch command
        dpo_training_script = resources.files("trl.scripts").joinpath("dpo.py")
        args = launch_command_parser().parse_args([str(dpo_training_script)])

        # Feed the args to the launch command
        args.training_script_args = sys.argv[2:]  # remove "trl" and "dpo"
        launch_command(args)  # launch training

    elif args.command == "env":
        print_env()

    elif args.command == "grpo":
        # Get the default args for the launch command
        grpo_training_script = resources.files("trl.scripts").joinpath("grpo.py")
        args = launch_command_parser().parse_args([str(grpo_training_script)])

        # Feed the args to the launch command
        args.training_script_args = sys.argv[2:]  # remove "trl" and "grpo"
        launch_command(args)  # launch training

    elif args.command == "kto":
        # Get the default args for the launch command
        kto_training_script = resources.files("trl.scripts").joinpath("kto.py")
        args = launch_command_parser().parse_args([str(kto_training_script)])

        # Feed the args to the launch command
        args.training_script_args = sys.argv[2:]  # remove "trl" and "kto"
        launch_command(args)  # launch training

    elif args.command == "sft":
        # Get the path to the training script
        sft_training_script = resources.files("trl.scripts").joinpath("sft.py")

        # This simulates running: `accelerate launch <launch args> sft.py <training script args>`.
        # Note that the training script args may include launch-related arguments (e.g., `--num_processes`),
        # but we rely on the script to ignore any that don't apply to it.
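        # Unlike the dpo/grpo/kto branches above, `launch_args` (including any `--config_file` injected for
        # `--accelerate_config`) is forwarded to `accelerate launch` here, so launch-level options take effect.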
        training_script_args = sys.argv[2:]  # Remove "trl" and "sft"
        args = launch_command_parser().parse_args(launch_args + [str(sft_training_script)] + training_script_args)
        launch_command(args)  # launch training

    elif args.command == "vllm-serve":
        (script_args,) = parser.parse_args_and_config()

        # Known issue: Using DeepSpeed with tensor_parallel_size=1 and data_parallel_size>1 may cause a crash when
        # launched via the CLI. Suggest running the module directly.
        # More information: https://github.com/vllm-project/vllm/issues/17079
        if script_args.tensor_parallel_size == 1 and script_args.data_parallel_size > 1:
            warnings.warn(
                "Detected configuration: tensor_parallel_size=1 and data_parallel_size>1. This setup is known to "
                "cause a crash when using the `trl vllm-serve` CLI entry point. As a workaround, please run the "
                "server using the module path instead: `python -m trl.scripts.vllm_serve`",
                RuntimeWarning,
            )

        vllm_serve_main(script_args)


if __name__ == "__main__":
    main()