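"""Consolidate a distributed (sharded) model checkpoint into a single file.

Converts a folder of checkpoint shards saved with the ``torch.distributed.checkpoint`` API into a single file that
can be loaded with ``torch.load()``. Only FSDP-sharded checkpoints saved with Lightning are supported. Example
invocation (the script path is illustrative):

    python consolidate_checkpoint.py path/to/checkpoint.ckpt --output_file consolidated.ckpt

"""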
import logging
import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path

import torch

from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning_fabric.utilities.load import _METADATA_FILENAME, _load_distributed_checkpoint

_log = logging.getLogger(__name__)


def _parse_cli_args() -> Namespace:
    parser = ArgumentParser(
        description=(
            "Converts a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`."
            " Only supports FSDP sharded checkpoints at the moment."
        ),
    )
    parser.add_argument(
        "checkpoint_folder",
        type=str,
        help=(
            "Path to a checkpoint folder, containing the sharded checkpoint files saved using the"
            " `torch.distributed.checkpoint` API."
        ),
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help=(
            "Path to the file where the converted checkpoint should be saved. The file should not already exist."
            " If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
            " and a '.consolidated' suffix."
        ),
    )
    return parser.parse_args()


def _process_cli_args(args: Namespace) -> Namespace:
    if not _TORCH_GREATER_EQUAL_2_3:
        _log.error("Processing distributed checkpoints requires PyTorch >= 2.3.")
        sys.exit(1)

    checkpoint_folder = Path(args.checkpoint_folder)
    if not checkpoint_folder.exists():
        _log.error(f"The provided checkpoint folder does not exist: {checkpoint_folder}")
        sys.exit(1)
    if not checkpoint_folder.is_dir():
        _log.error(
            f"The provided checkpoint path must be a folder, containing the checkpoint shards: {checkpoint_folder}"
        )
        sys.exit(1)
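    # Lightning writes a metadata file alongside the shards; if it is missing, the folder was not
    # produced by Lightning's FSDP checkpointing and cannot be consolidated here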
    if not (checkpoint_folder / _METADATA_FILENAME).is_file():
        _log.error(
            "Only FSDP-sharded checkpoints saved with Lightning are supported for consolidation. The provided folder"
            f" is not in that format: {checkpoint_folder}"
        )
        sys.exit(1)

    if args.output_file is None:
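        # Default: same name as the input folder with an extra ".consolidated" suffix,
        # e.g. "epoch=1.ckpt" -> "epoch=1.ckpt.consolidated"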
        output_file = checkpoint_folder.with_suffix(checkpoint_folder.suffix + ".consolidated")
    else:
        output_file = Path(args.output_file)
    if output_file.exists():
        _log.error(
            "The path for the converted checkpoint already exists. Choose a different path by providing"
            f" `--output_file` or move/delete the file first: {output_file}"
        )
        sys.exit(1)

    return Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)


if __name__ == "__main__":
    args = _parse_cli_args()
    config = _process_cli_args(args)
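    # Load every shard and merge them into a single, unsharded checkpoint in memory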
    checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
    torch.save(checkpoint, config.output_file)
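
# A minimal sketch of loading the consolidated file afterwards (file name illustrative; the
# key layout depends on how the original checkpoint was saved):
#
#   checkpoint = torch.load("consolidated.ckpt", map_location="cpu")
#   model.load_state_dict(checkpoint["state_dict"])  # assumes weights are stored under "state_dict"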