|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import sys |
|
import warnings |
|
|
|
from accelerate.commands.launch import launch_command, launch_command_parser |
|
|
|
from .scripts.chat import main as chat_main |
|
from .scripts.chat import make_parser as make_chat_parser |
|
from .scripts.dpo import make_parser as make_dpo_parser |
|
from .scripts.env import print_env |
|
from .scripts.grpo import make_parser as make_grpo_parser |
|
from .scripts.kto import make_parser as make_kto_parser |
|
from .scripts.sft import make_parser as make_sft_parser |
|
from .scripts.utils import TrlParser |
|
from .scripts.vllm_serve import main as vllm_serve_main |
|
from .scripts.vllm_serve import make_parser as make_vllm_serve_parser |
|
|
|
|
|
# Subcommands that are re-launched through ``accelerate launch``, mapped to the file name of
# the training script inside this package's ``scripts`` directory.
_ACCELERATE_SCRIPTS = {
    "dpo": "dpo.py",
    "grpo": "grpo.py",
    "kto": "kto.py",
    "sft": "sft.py",
}


def _launch_training_script(script_name):
    """Run ``scripts/<script_name>`` (relative to this file) under ``accelerate launch``.

    The accelerate launch options are whatever ``launch_command_parser`` produces when it is
    given only the script path; no extra launch flags are passed here. Every command-line
    argument after the subcommand (``sys.argv[2:]``) is forwarded verbatim to the training
    script.

    Args:
        script_name: File name of the script inside the ``scripts`` directory, e.g. ``"sft.py"``.
    """
    script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", script_name)
    launch_args = launch_command_parser().parse_args([script_path])

    # sys.argv[0] is the program and sys.argv[1] the subcommand (e.g. "sft"). The top-level
    # parser registers no options of its own, so everything after the subcommand belongs to
    # the training script.
    launch_args.training_script_args = sys.argv[2:]
    launch_command(launch_args)


def main():
    """Entry point of the TRL command-line interface.

    Registers one subcommand per script and dispatches on the one selected:

    * ``chat`` and ``vllm-serve`` re-parse the command line with
      ``TrlParser.parse_args_and_config`` and run the corresponding ``main`` in-process.
    * ``dpo``, ``grpo``, ``kto`` and ``sft`` run the matching script under
      ``accelerate launch`` (see ``_launch_training_script``).
    * ``env`` prints environment information.

    When no subcommand is given, the top-level help message is printed.
    """
    parser = TrlParser(prog="TRL CLI", usage="trl", allow_abbrev=False)
    subparsers = parser.add_subparsers(help="available commands", dest="command", parser_class=TrlParser)

    make_chat_parser(subparsers)
    make_dpo_parser(subparsers)
    subparsers.add_parser("env", help="Print the environment information")
    make_grpo_parser(subparsers)
    make_kto_parser(subparsers)
    make_sft_parser(subparsers)
    make_vllm_serve_parser(subparsers)

    # This first pass selects the subcommand (and validates its arguments). Commands that
    # need their dataclass arguments re-parse below with parse_args_and_config.
    args = parser.parse_args()

    if args.command == "chat":
        (chat_args,) = parser.parse_args_and_config()
        chat_main(chat_args)

    elif args.command in _ACCELERATE_SCRIPTS:
        _launch_training_script(_ACCELERATE_SCRIPTS[args.command])

    elif args.command == "env":
        print_env()

    elif args.command == "vllm-serve":
        (script_args,) = parser.parse_args_and_config()

        # Per the warning text below, tensor_parallel_size=1 combined with data_parallel_size>1
        # is known to crash when started through this CLI entry point, while running the
        # module directly works.
        if script_args.tensor_parallel_size == 1 and script_args.data_parallel_size > 1:
            warnings.warn(
                "Detected configuration: tensor_parallel_size=1 and data_parallel_size>1. This setup is known to "
                "cause a crash when using the `trl vllm-serve` CLI entry point. As a workaround, please run the "
                "server using the module path instead: `python -m trl.scripts.vllm_serve`",
                RuntimeWarning,
            )

        vllm_serve_main(script_args)

    elif args.command is None:
        # The subcommand is optional at the argparse level. Without this branch, `trl` with no
        # arguments would exit silently.
        parser.print_help()
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI when this file is executed directly (as a script or via ``python -m``)
    # rather than imported.
    main()
|
|