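# Factory for constructing the checkpoint engine selected by the DeepSpeed
# checkpoint configuration (decoupled/fast writer, Nebula, or plain torch).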
from deepspeed.runtime.model_checkpointing.constants import *
from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
from deepspeed.utils import logger

from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
from .fast_checkpoint_engine import FastCheckpointEngine
from .torch_checkpoint_engine import TorchCheckpointEngine


def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers, optimize_dp_state):
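    """Select and construct a checkpoint engine from the DeepSpeed configuration.

    Preference order: a configured checkpoint writer (decoupled or fast),
    then Nebula (if enabled and importable), otherwise the default
    torch.save/torch.load based engine.
    """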
    if config_params is not None:
        if config_params.checkpoint_config[CHECKPOINT_WRITER] is not None:
            writer_config = config_params.checkpoint_config[CHECKPOINT_WRITER]
            dp_writer_config = create_data_parallel_writer_config(
                groups=groups,
                parallel_unit=writer_config[CHECKPOINT_DATA_PARALLEL],
                zero_stage=zero_stage,
                has_moe_layers=has_moe_layers)
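            # Pick the decoupled or the fast writer engine based on the writer config.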
            if writer_config[CHECKPOINT_WRITER_DECOUPLED]:
                return DecoupledCheckpointEngine(config_params, dp_writer_config, optimize_dp_state)
            else:
                return FastCheckpointEngine(config_params, dp_writer_config, optimize_dp_state)
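    # Use Nebula checkpointing when enabled and torch_nebula is importable.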
    if config_params is not None and config_params.nebula_config.enabled:
        try:
            from .nebula_checkpoint_engine import NebulaCheckpointEngine
        except ImportError as err:
            logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
            return TorchCheckpointEngine(config_params)
        else:
            return NebulaCheckpointEngine(config_params=config_params.nebula_config)
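    # Default: TorchCheckpointEngine wraps torch.save/torch.load.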
    return TorchCheckpointEngine(config_params)
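# Illustrative usage (a sketch, not executed here): callers inside DeepSpeed
# pass the parsed config plus runtime state; the concrete argument values
# below are hypothetical.
#
#   engine = create_checkpoint_engine(config_params=ds_config,   # parsed DeepSpeed config
#                                     groups=groups,              # process-group helper
#                                     zero_stage=3,
#                                     has_moe_layers=False,
#                                     optimize_dp_state=True)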