# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime.model_checkpointing.constants import *
from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
from deepspeed.utils import logger

from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
from .fast_checkpoint_engine import FastCheckpointEngine
from .torch_checkpoint_engine import TorchCheckpointEngine


def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers, optimize_dp_state):
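    """Select a checkpoint engine from the DeepSpeed configuration.

    If a checkpoint writer config is present, returns DecoupledCheckpointEngine
    when CHECKPOINT_WRITER_DECOUPLED is set and FastCheckpointEngine otherwise.
    If the Nebula config is enabled, returns NebulaCheckpointEngine, falling
    back to TorchCheckpointEngine when torch_nebula is not importable.
    In all other cases, returns the default TorchCheckpointEngine.
    """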
    if config_params is not None:
        writer_config = config_params.checkpoint_config[CHECKPOINT_WRITER]
        if writer_config is not None:
            dp_writer_config = create_data_parallel_writer_config(
                groups=groups,
                parallel_unit=writer_config[CHECKPOINT_DATA_PARALLEL],
                zero_stage=zero_stage,
                has_moe_layers=has_moe_layers)
            if writer_config[CHECKPOINT_WRITER_DECOUPLED]:
                return DecoupledCheckpointEngine(config_params, dp_writer_config, optimize_dp_state)
            else:
                return FastCheckpointEngine(config_params, dp_writer_config, optimize_dp_state)

        if config_params.nebula_config.enabled:
            try:
                from .nebula_checkpoint_engine import NebulaCheckpointEngine
            except ImportError as err:
                logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
                return TorchCheckpointEngine(config_params)
            else:
                return NebulaCheckpointEngine(config_params=config_params.nebula_config)

    return TorchCheckpointEngine(config_params)
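

# Example usage (illustrative sketch only; the argument values below are
# assumptions, not taken from this file -- in practice the DeepSpeed engine
# builds them from its runtime configuration):
#
#   from deepspeed.utils import groups
#
#   engine = create_checkpoint_engine(config_params=ds_config,  # hypothetical DeepSpeedConfig instance
#                                     groups=groups,
#                                     zero_stage=1,
#                                     has_moe_layers=False,
#                                     optimize_dp_state=True)
#   engine.save(state_dict, checkpoint_path)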