|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import re |
|
from argparse import Namespace |
|
from typing import Any, Optional |
|
|
|
import torch |
|
from lightning_utilities.core.imports import RequirementCache |
|
from typing_extensions import get_args |
|
|
|
from lightning_fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator |
|
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS |
|
from lightning_fabric.strategies import STRATEGY_REGISTRY |
|
from lightning_fabric.utilities.consolidate_checkpoint import _process_cli_args |
|
from lightning_fabric.utilities.device_parser import _parse_gpu_ids |
|
from lightning_fabric.utilities.distributed import _suggested_max_num_threads |
|
from lightning_fabric.utilities.load import _load_distributed_checkpoint |
|
|
|
_log = logging.getLogger(__name__) |
|
|
|
_CLICK_AVAILABLE = RequirementCache("click") |
|
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk") |
|
|
|
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") |
|
|
|
|
|
def _get_supported_strategies() -> list[str]:
    """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the
    CLI or ones that require further configuration by the user."""
    # Strategies that spawn/fork workers themselves, run on XLA/TPU, or need extra
    # user configuration (e.g. offloading) cannot be launched via torchrun from here.
    unsupported = re.compile(r".*(spawn|fork|notebook|xla|tpu|offload).*")
    supported = []
    for name in STRATEGY_REGISTRY.available_strategies():
        if unsupported.match(name) is None:
            supported.append(name)
    return supported
|
|
|
|
|
if _CLICK_AVAILABLE:
    import click

    @click.group()
    def _main() -> None:
        """Root command group for the Lightning Fabric CLI (`run` and `consolidate`)."""
        pass

    @_main.command(
        "run",
        context_settings={
            # Unknown options are forwarded untouched to the user's script via SCRIPT_ARGS.
            "ignore_unknown_options": True,
        },
    )
    @click.argument(
        "script",
        type=click.Path(exists=True),
    )
    @click.option(
        "--accelerator",
        type=click.Choice(_SUPPORTED_ACCELERATORS),
        default=None,
        help="The hardware accelerator to run on.",
    )
    @click.option(
        "--strategy",
        type=click.Choice(_get_supported_strategies()),
        default=None,
        help="Strategy for how to run across multiple devices.",
    )
    @click.option(
        "--devices",
        type=str,
        default="1",
        help=(
            "Number of devices to run on (``int``), which devices to run on (``list`` or ``str``), or ``'auto'``."
            " The value applies per node."
        ),
    )
    @click.option(
        "--num-nodes",
        "--num_nodes",
        type=int,
        default=1,
        help="Number of machines (nodes) for distributed execution.",
    )
    @click.option(
        "--node-rank",
        "--node_rank",
        type=int,
        default=0,
        help=(
            "The index of the machine (node) this command gets started on. Must be a number in the range"
            " 0, ..., num_nodes - 1."
        ),
    )
    @click.option(
        "--main-address",
        "--main_address",
        type=str,
        default="127.0.0.1",
        help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
    )
    @click.option(
        "--main-port",
        "--main_port",
        type=int,
        default=29400,
        help="The main port to connect to the main machine.",
    )
    @click.option(
        "--precision",
        type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
        default=None,
        help=(
            "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), "
            "half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
        ),
    )
    @click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
    def _run(**kwargs: Any) -> None:
        """Run a Lightning Fabric script.

        SCRIPT is the path to the Python script with the code to run. The script must contain a Fabric object.

        SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed
        there.

        """
        # Separate the passthrough script arguments from the launcher options before
        # handing the latter to `main` as a Namespace.
        script_args = list(kwargs.pop("script_args", []))
        main(args=Namespace(**kwargs), script_args=script_args)

    @_main.command(
        "consolidate",
        context_settings={
            "ignore_unknown_options": True,
        },
    )
    @click.argument(
        "checkpoint_folder",
        type=click.Path(exists=True),
    )
    @click.option(
        "--output_file",
        # FIX: previously `exists=True`, which made click reject any path that does not
        # already exist — contradicting the documented contract below that the output
        # file "should not already exist". The output is created by this command.
        type=click.Path(exists=False),
        default=None,
        help=(
            "Path to the file where the converted checkpoint should be saved. The file should not already exist."
            " If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
            " and a '.consolidated' suffix."
        ),
    )
    def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
        """Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.

        Only supports FSDP sharded checkpoints at the moment.

        """
        # `_process_cli_args` validates the inputs and derives the default output path.
        args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
        config = _process_cli_args(args)
        checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
        torch.save(checkpoint, config.output_file)
|
|
|
|
|
def _set_env_variables(args: Namespace) -> None: |
|
"""Set the environment variables for the new processes. |
|
|
|
The Fabric connector will parse the arguments set here. |
|
|
|
""" |
|
os.environ["LT_CLI_USED"] = "1" |
|
if args.accelerator is not None: |
|
os.environ["LT_ACCELERATOR"] = str(args.accelerator) |
|
if args.strategy is not None: |
|
os.environ["LT_STRATEGY"] = str(args.strategy) |
|
os.environ["LT_DEVICES"] = str(args.devices) |
|
os.environ["LT_NUM_NODES"] = str(args.num_nodes) |
|
if args.precision is not None: |
|
os.environ["LT_PRECISION"] = str(args.precision) |
|
|
|
|
|
def _get_num_processes(accelerator: str, devices: str) -> int: |
|
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine.""" |
|
if accelerator == "gpu": |
|
parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) |
|
elif accelerator == "cuda": |
|
parsed_devices = CUDAAccelerator.parse_devices(devices) |
|
elif accelerator == "mps": |
|
parsed_devices = MPSAccelerator.parse_devices(devices) |
|
elif accelerator == "tpu": |
|
raise ValueError("Launching processes for TPU through the CLI is not supported.") |
|
else: |
|
return CPUAccelerator.parse_devices(devices) |
|
return len(parsed_devices) if parsed_devices is not None else 0 |
|
|
|
|
|
def _torchrun_launch(args: Namespace, script_args: list[str]) -> None:
    """This will invoke `torchrun` programmatically to launch the given script in new processes."""
    import torch.distributed.run as torchrun

    # The "dp" strategy runs a single process that uses all devices itself.
    if args.strategy == "dp":
        num_processes = 1
    else:
        num_processes = _get_num_processes(args.accelerator, args.devices)

    # Assemble the torchrun command line: launcher flags first, then the script
    # path, then the untouched passthrough arguments for the script.
    launch_command = [
        f"--nproc_per_node={num_processes}",
        f"--nnodes={args.num_nodes}",
        f"--node_rank={args.node_rank}",
        f"--master_addr={args.main_address}",
        f"--master_port={args.main_port}",
        args.script,
        *script_args,
    ]

    # Avoid thread oversubscription across the launched workers; respect a
    # user-provided value if already set.
    os.environ.setdefault("OMP_NUM_THREADS", str(_suggested_max_num_threads()))
    torchrun.main(launch_command)
|
|
|
|
|
def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
    """Export launcher settings to the environment, then start the script with torchrun."""
    _set_env_variables(args)
    # Normalize a missing/empty argument list to an empty list for the launcher.
    _torchrun_launch(args, script_args if script_args else [])
|
|
|
|
|
if __name__ == "__main__":
    # The CLI is built on `click`; fail fast with installation instructions if missing.
    if not _CLICK_AVAILABLE:
        _log.error(
            "To use the Lightning Fabric CLI, you must have `click` installed."
            " Install it by running `pip install -U click`."
        )
        raise SystemExit(1)

    # FIX: invoke the command *group* rather than the `run` subcommand directly.
    # Calling `_run()` here made the `consolidate` subcommand unreachable when
    # executing this module as a script.
    _main()
|
|