# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
from argparse import Namespace
from typing import Any, Optional
import torch
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import get_args
from lightning_fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning_fabric.strategies import STRATEGY_REGISTRY
from lightning_fabric.utilities.consolidate_checkpoint import _process_cli_args
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
from lightning_fabric.utilities.distributed import _suggested_max_num_threads
from lightning_fabric.utilities.load import _load_distributed_checkpoint
_log = logging.getLogger(__name__)
_CLICK_AVAILABLE = RequirementCache("click")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
def _get_supported_strategies() -> list[str]:
"""Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the
CLI or ones that require further configuration by the user."""
available_strategies = STRATEGY_REGISTRY.available_strategies()
excluded = r".*(spawn|fork|notebook|xla|tpu|offload).*"
return [strategy for strategy in available_strategies if not re.match(excluded, strategy)]
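# Illustrative example (the actual names depend on the installed strategy registry): from a registry
# containing "ddp", "ddp_spawn", "fsdp", and "xla", the filter above keeps "ddp" and "fsdp" and
# drops "ddp_spawn" and "xla".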
if _CLICK_AVAILABLE:
import click
@click.group()
def _main() -> None:
pass
@_main.command(
"run",
context_settings={
"ignore_unknown_options": True,
},
)
@click.argument(
"script",
type=click.Path(exists=True),
)
@click.option(
"--accelerator",
type=click.Choice(_SUPPORTED_ACCELERATORS),
default=None,
help="The hardware accelerator to run on.",
)
@click.option(
"--strategy",
type=click.Choice(_get_supported_strategies()),
default=None,
help="Strategy for how to run across multiple devices.",
)
@click.option(
"--devices",
type=str,
default="1",
help=(
"Number of devices to run on (``int``), which devices to run on (``list`` or ``str``), or ``'auto'``."
" The value applies per node."
),
)
@click.option(
"--num-nodes",
"--num_nodes",
type=int,
default=1,
help="Number of machines (nodes) for distributed execution.",
)
@click.option(
"--node-rank",
"--node_rank",
type=int,
default=0,
help=(
"The index of the machine (node) this command gets started on. Must be a number in the range"
" 0, ..., num_nodes - 1."
),
)
@click.option(
"--main-address",
"--main_address",
type=str,
default="127.0.0.1",
help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
)
@click.option(
"--main-port",
"--main_port",
type=int,
default=29400,
help="The main port to connect to the main machine.",
)
@click.option(
"--precision",
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
default=None,
help=(
"Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), "
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
),
)
@click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
def _run(**kwargs: Any) -> None:
"""Run a Lightning Fabric script.
SCRIPT is the path to the Python script with the code to run. The script must contain a Fabric object.
SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed
there.
"""
script_args = list(kwargs.pop("script_args", []))
main(args=Namespace(**kwargs), script_args=script_args)
@_main.command(
"consolidate",
context_settings={
"ignore_unknown_options": True,
},
)
@click.argument(
"checkpoint_folder",
type=click.Path(exists=True),
)
@click.option(
"--output_file",
type=click.Path(exists=False),
default=None,
help=(
"Path to the file where the converted checkpoint should be saved. The file should not already exist."
" If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
" and a '.consolidated' suffix."
),
)
def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
"""Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.
Only supports FSDP sharded checkpoints at the moment.
"""
args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
config = _process_cli_args(args)
checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
torch.save(checkpoint, config.output_file)
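# Example with a hypothetical path: `fabric consolidate ./ckpt_dir` loads the sharded checkpoint in
# ./ckpt_dir and writes ./ckpt_dir.consolidated, which can then be loaded with `torch.load()`.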
def _set_env_variables(args: Namespace) -> None:
"""Set the environment variables for the new processes.
The Fabric connector will parse the arguments set here.
"""
os.environ["LT_CLI_USED"] = "1"
if args.accelerator is not None:
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
if args.strategy is not None:
os.environ["LT_STRATEGY"] = str(args.strategy)
os.environ["LT_DEVICES"] = str(args.devices)
os.environ["LT_NUM_NODES"] = str(args.num_nodes)
if args.precision is not None:
os.environ["LT_PRECISION"] = str(args.precision)
def _get_num_processes(accelerator: str, devices: str) -> int:
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
if accelerator == "gpu":
parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
elif accelerator == "cuda":
parsed_devices = CUDAAccelerator.parse_devices(devices)
elif accelerator == "mps":
parsed_devices = MPSAccelerator.parse_devices(devices)
elif accelerator == "tpu":
raise ValueError("Launching processes for TPU through the CLI is not supported.")
else:
return CPUAccelerator.parse_devices(devices)
return len(parsed_devices) if parsed_devices is not None else 0
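# Examples: accelerator="cuda" with devices="2" yields 2 processes (GPU indices 0 and 1);
# devices="0,3" also yields 2, one process per listed device index.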
def _torchrun_launch(args: Namespace, script_args: list[str]) -> None:
"""This will invoke `torchrun` programmatically to launch the given script in new processes."""
import torch.distributed.run as torchrun
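# DP (DataParallel) replicates the model within a single process, so only one process is launched.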
num_processes = 1 if args.strategy == "dp" else _get_num_processes(args.accelerator, args.devices)
torchrun_args = [
f"--nproc_per_node={num_processes}",
f"--nnodes={args.num_nodes}",
f"--node_rank={args.node_rank}",
f"--master_addr={args.main_address}",
f"--master_port={args.main_port}",
args.script,
]
torchrun_args.extend(script_args)
# set a good default number of threads for OMP to avoid warnings being emitted to the user
os.environ.setdefault("OMP_NUM_THREADS", str(_suggested_max_num_threads()))
torchrun.main(torchrun_args)
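# With the defaults above this is equivalent to invoking, e.g.:
#   torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=29400 train.py
# where train.py stands in for the user's script.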
def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
_set_env_variables(args)
_torchrun_launch(args, script_args or [])
if __name__ == "__main__":
if not _CLICK_AVAILABLE: # pragma: no cover
_log.error(
"To use the Lightning Fabric CLI, you must have `click` installed."
" Install it by running `pip install -U click`."
)
raise SystemExit(1)
_run()