File size: 2,992 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.runtime.config_utils import get_scalar_param
from .constants import *
# Allowed values for the configuration keys that are restricted to a fixed set.
# Passed to _validate_config_values() by the config getters in this module;
# keys that do not appear here are not validated.
VALID_VALUES = {
    CHECKPOINT_TAG_VALIDATION: CHECKPOINT_TAG_VALIDATION_MODES,
    CHECKPOINT_WRITER_TYPE: CHECKPOINT_WRITER_TYPES,
    CHECKPOINT_DATA_PARALLEL: CHECKPOINT_DATA_PARALLEL_UNITS,
}
# Checkpoint configuration used by get_checkpoint_config() when the
# CHECKPOINT section is absent from the user-supplied parameters.
CHECKPOINT_DEFAULT_DICT = {
    CHECKPOINT_TAG_VALIDATION: CHECKPOINT_TAG_VALIDATION_DEFAULT,
    CHECKPOINT_SERIALIZATION: CHECKPOINT_SERIALIZATION_DEFAULT,
    CHECKPOINT_WRITER: CHECKPOINT_WRITER_DEFAULT,
}
def _validate_config_values(config_name, config_dict, valid_values):
for key, value in config_dict.items():
if value is None:
continue
if key in valid_values.keys():
assert value in valid_values[key], \
f"{config_name} contains invalid value {value} for {key}, expecting one of {valid_values[key]}"
def _make_upper_case(value):
return value if value is None else value.upper()
def get_checkpoint_writer_config(param_dict):
    """Resolve the checkpoint writer settings from ``param_dict``.

    Args:
        param_dict: Mapping that may contain a ``CHECKPOINT_WRITER`` section.

    Returns:
        ``CHECKPOINT_WRITER_DEFAULT`` when the section is missing or ``None``;
        otherwise a dict with one entry per writer option, each read from the
        section with its own default.

    Raises:
        AssertionError: Propagated from ``_validate_config_values`` when a
            resolved value is not among those listed in ``VALID_VALUES``.
    """
    writer_section = param_dict.get(CHECKPOINT_WRITER, None)
    if writer_section is None:
        return CHECKPOINT_WRITER_DEFAULT

    # (key, default, upper-case the value?) in the order the keys appear in
    # the resulting dict.
    option_table = (
        (CHECKPOINT_WRITER_TYPE, CHECKPOINT_WRITER_TYPE_DEFAULT, True),
        (CHECKPOINT_IO_BUFFER_SIZE, CHECKPOINT_IO_BUFFER_SIZE_DEFAULT, False),
        (CHECKPOINT_IO_BUFFER_DOUBLE, CHECKPOINT_IO_BUFFER_DOUBLE_DEFAULT, False),
        (CHECKPOINT_IO_STATISTICS, CHECKPOINT_IO_STATISTICS_DEFAULT, False),
        (CHECKPOINT_DATA_PARALLEL, CHECKPOINT_DATA_PARALLEL_DEFAULT, True),
        (CHECKPOINT_WRITER_DECOUPLED, CHECKPOINT_WRITER_DECOUPLED_DEFAULT, False),
        (CHECKPOINT_IO_MULTIPLIER, CHECKPOINT_IO_MULTIPLIER_DEFAULT, False),
    )

    resolved = {}
    for option, fallback, upper_cased in option_table:
        setting = get_scalar_param(writer_section, option, fallback)
        if upper_cased:
            setting = _make_upper_case(setting)
        resolved[option] = setting

    _validate_config_values(CHECKPOINT_WRITER, resolved, VALID_VALUES)
    return resolved
def get_checkpoint_config(param_dict):
    """Build the checkpoint configuration from the ``CHECKPOINT`` section of ``param_dict``.

    Args:
        param_dict: Mapping that may contain a ``CHECKPOINT`` section.

    Returns:
        dict: A new dictionary keyed by ``CHECKPOINT_TAG_VALIDATION``,
        ``CHECKPOINT_SERIALIZATION`` and ``CHECKPOINT_WRITER``. When the
        section is missing or ``None`` this is a shallow copy of
        ``CHECKPOINT_DEFAULT_DICT``.

    Raises:
        AssertionError: If a resolved value is not among those listed for its
            key in ``VALID_VALUES`` (raised by ``_validate_config_values``).
    """
    checkpoint_dict = param_dict.get(CHECKPOINT, None)
    if checkpoint_dict is None:
        # Return a copy, not the module-level dict itself, so a caller that
        # mutates the result cannot change the defaults seen by later calls.
        return dict(CHECKPOINT_DEFAULT_DICT)

    tag_validation = get_scalar_param(checkpoint_dict, CHECKPOINT_TAG_VALIDATION, CHECKPOINT_TAG_VALIDATION_DEFAULT)
    checkpoint_config = {
        # Upper-cased before validation.
        CHECKPOINT_TAG_VALIDATION: tag_validation.upper(),
        CHECKPOINT_SERIALIZATION:
        get_scalar_param(checkpoint_dict, CHECKPOINT_SERIALIZATION, CHECKPOINT_SERIALIZATION_DEFAULT),
        # The writer sub-section is resolved (and validated) by its own getter.
        CHECKPOINT_WRITER: get_checkpoint_writer_config(checkpoint_dict),
    }

    _validate_config_values(CHECKPOINT, checkpoint_config, VALID_VALUES)
    return checkpoint_config
|