|
|
|
|
|
|
|
|
|
|
|
from deepspeed.runtime.config_utils import get_scalar_param |
|
from .constants import * |
|
|
|
VALID_VALUES = { |
|
CHECKPOINT_TAG_VALIDATION: CHECKPOINT_TAG_VALIDATION_MODES, |
|
CHECKPOINT_WRITER_TYPE: CHECKPOINT_WRITER_TYPES, |
|
CHECKPOINT_DATA_PARALLEL: CHECKPOINT_DATA_PARALLEL_UNITS |
|
} |
|
|
|
CHECKPOINT_DEFAULT_DICT = { |
|
CHECKPOINT_TAG_VALIDATION: CHECKPOINT_TAG_VALIDATION_DEFAULT, |
|
CHECKPOINT_SERIALIZATION: CHECKPOINT_SERIALIZATION_DEFAULT, |
|
CHECKPOINT_WRITER: CHECKPOINT_WRITER_DEFAULT |
|
} |
|
|
|
|
|
def _validate_config_values(config_name, config_dict, valid_values): |
|
for key, value in config_dict.items(): |
|
if value is None: |
|
continue |
|
if key in valid_values.keys(): |
|
assert value in valid_values[key], \ |
|
f"{config_name} contains invalid value {value} for {key}, expecting one of {valid_values[key]}" |
|
|
|
|
|
def _make_upper_case(value): |
|
return value if value is None else value.upper() |
|
|
|
|
|
def get_checkpoint_writer_config(param_dict): |
|
writer_dict = param_dict.get(CHECKPOINT_WRITER, None) |
|
if writer_dict is None: |
|
return CHECKPOINT_WRITER_DEFAULT |
|
|
|
writer_config = { |
|
CHECKPOINT_WRITER_TYPE: |
|
_make_upper_case(get_scalar_param(writer_dict, CHECKPOINT_WRITER_TYPE, CHECKPOINT_WRITER_TYPE_DEFAULT)), |
|
CHECKPOINT_IO_BUFFER_SIZE: |
|
get_scalar_param(writer_dict, CHECKPOINT_IO_BUFFER_SIZE, CHECKPOINT_IO_BUFFER_SIZE_DEFAULT), |
|
CHECKPOINT_IO_BUFFER_DOUBLE: |
|
get_scalar_param(writer_dict, CHECKPOINT_IO_BUFFER_DOUBLE, CHECKPOINT_IO_BUFFER_DOUBLE_DEFAULT), |
|
CHECKPOINT_IO_STATISTICS: |
|
get_scalar_param(writer_dict, CHECKPOINT_IO_STATISTICS, CHECKPOINT_IO_STATISTICS_DEFAULT), |
|
CHECKPOINT_DATA_PARALLEL: |
|
_make_upper_case(get_scalar_param(writer_dict, CHECKPOINT_DATA_PARALLEL, CHECKPOINT_DATA_PARALLEL_DEFAULT)), |
|
CHECKPOINT_WRITER_DECOUPLED: |
|
get_scalar_param(writer_dict, CHECKPOINT_WRITER_DECOUPLED, CHECKPOINT_WRITER_DECOUPLED_DEFAULT), |
|
CHECKPOINT_IO_MULTIPLIER: |
|
get_scalar_param(writer_dict, CHECKPOINT_IO_MULTIPLIER, CHECKPOINT_IO_MULTIPLIER_DEFAULT), |
|
} |
|
_validate_config_values(CHECKPOINT_WRITER, writer_config, VALID_VALUES) |
|
|
|
return writer_config |
|
|
|
|
|
def get_checkpoint_config(param_dict): |
|
checkpoint_dict = param_dict.get(CHECKPOINT, None) |
|
if checkpoint_dict is None: |
|
return CHECKPOINT_DEFAULT_DICT |
|
|
|
checkpoint_config = { |
|
CHECKPOINT_TAG_VALIDATION: |
|
get_scalar_param(checkpoint_dict, CHECKPOINT_TAG_VALIDATION, CHECKPOINT_TAG_VALIDATION_DEFAULT).upper(), |
|
CHECKPOINT_SERIALIZATION: |
|
get_scalar_param(checkpoint_dict, CHECKPOINT_SERIALIZATION, CHECKPOINT_SERIALIZATION_DEFAULT), |
|
CHECKPOINT_WRITER: |
|
get_checkpoint_writer_config(checkpoint_dict) |
|
} |
|
|
|
_validate_config_values(CHECKPOINT, checkpoint_config, VALID_VALUES) |
|
|
|
return checkpoint_config |
|
|