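"""Consolidates a distributed (FSDP-sharded) checkpoint into a single file loadable with ``torch.load()``.

Example invocation (the script name and paths below are illustrative, not taken from the
original source):

    python consolidate_checkpoint.py lightning_logs/checkpoints/epoch=1.ckpt \
        --output_file consolidated.ckpt

If ``--output_file`` is omitted, the output is written next to the input folder with a
``.consolidated`` suffix appended, e.g. ``epoch=1.ckpt`` becomes ``epoch=1.ckpt.consolidated``.
"""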
import logging
from argparse import ArgumentParser, Namespace
from pathlib import Path

import torch
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning_fabric.utilities.load import _METADATA_FILENAME, _load_distributed_checkpoint

_log = logging.getLogger(__name__)


def _parse_cli_args() -> Namespace:
    parser = ArgumentParser(
        description=(
            "Converts a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`."
            " Only supports FSDP sharded checkpoints at the moment."
        ),
    )
    parser.add_argument(
        "checkpoint_folder",
        type=str,
        help=(
            "Path to a checkpoint folder, containing the sharded checkpoint files saved using the"
            " `torch.distributed.checkpoint` API."
        ),
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help=(
            "Path to the file where the converted checkpoint should be saved. The file should not already exist."
            " If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
            " and a '.consolidated' suffix."
        ),
    )
    return parser.parse_args()


def _process_cli_args(args: Namespace) -> Namespace:
    # Consolidating distributed checkpoints relies on `torch.distributed.checkpoint` features from PyTorch 2.3+.
    if not _TORCH_GREATER_EQUAL_2_3:
        _log.error("Processing distributed checkpoints requires PyTorch >= 2.3.")
        exit(1)

    checkpoint_folder = Path(args.checkpoint_folder)
    if not checkpoint_folder.exists():
        _log.error(f"The provided checkpoint folder does not exist: {checkpoint_folder}")
        exit(1)
    if not checkpoint_folder.is_dir():
        _log.error(
            f"The provided checkpoint path must be a folder, containing the checkpoint shards: {checkpoint_folder}"
        )
        exit(1)
    # The metadata file accompanies FSDP-sharded checkpoints saved with Lightning; without it,
    # the folder is not in a format this script can consolidate.
    if not (checkpoint_folder / _METADATA_FILENAME).is_file():
        _log.error(
            "Only FSDP-sharded checkpoints saved with Lightning are supported for consolidation. The provided folder"
            f" is not in that format: {checkpoint_folder}"
        )
        exit(1)

    # Default output location: next to the input folder, e.g. `epoch=1.ckpt` -> `epoch=1.ckpt.consolidated`.
    if args.output_file is None:
        output_file = checkpoint_folder.with_suffix(checkpoint_folder.suffix + ".consolidated")
    else:
        output_file = Path(args.output_file)
    if output_file.exists():
        _log.error(
            "The path for the converted checkpoint already exists. Choose a different path by providing"
            f" `--output_file` or move/delete the file first: {output_file}"
        )
        exit(1)

    return Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)


if __name__ == "__main__":
    args = _parse_cli_args()
    config = _process_cli_args(args)
    # Load the sharded checkpoint into a single in-memory object, then write it out as one file.
    checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
    torch.save(checkpoint, config.output_file)
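
# A minimal sketch of consuming the consolidated file afterwards. The path and the
# "state_dict" key are illustrative (the dict's contents depend on what was saved);
# `weights_only=False` permits unpickling any non-tensor objects stored in the checkpoint:
#
#     checkpoint = torch.load("epoch=1.ckpt.consolidated", weights_only=False)
#     model.load_state_dict(checkpoint["state_dict"])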