"""Parsing helpers for the data efficiency section of a DeepSpeed-style config.

These helpers read the data efficiency block of a user config dict and return
a normalized dict covering data sampling (curriculum learning and dynamic
batching) and data routing (random-LTD).
"""

import copy

from ..config_utils import get_scalar_param
from .constants import *
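# Illustrative shape of the config these helpers consume. The literal key
# strings live in .constants; the snake_case spellings below are assumptions
# inferred from the constant names, not guaranteed values:
#
#   {
#       "data_efficiency": {
#           "enabled": True,
#           "seed": 1234,
#           "data_sampling": {"enabled": True, "num_epochs": 1000},
#           "data_routing": {"enabled": False, "random_ltd": {"enabled": False}}
#       }
#   }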
def get_data_efficiency_config(param_dict):
    """Parse the data efficiency section of the config into a single dict."""
    output = {}
    output[DATA_EFFICIENCY_ENABLED] = get_data_efficiency_enabled(param_dict)
    output[DATA_EFFICIENCY_SEED] = get_data_efficiency_seed(param_dict)
    if DATA_EFFICIENCY not in param_dict:
        param_dict[DATA_EFFICIENCY] = {}
    sub_param_dict = param_dict[DATA_EFFICIENCY]
    output[DATA_SAMPLING] = get_data_sampling(sub_param_dict)
    output[DATA_ROUTING] = get_data_routing(sub_param_dict)
    return output


def get_data_efficiency_enabled(param_dict):
    if DATA_EFFICIENCY in param_dict:
        return get_scalar_param(param_dict[DATA_EFFICIENCY], DATA_EFFICIENCY_ENABLED, DATA_EFFICIENCY_ENABLED_DEFAULT)
    else:
        return False


def get_data_efficiency_seed(param_dict):
    if DATA_EFFICIENCY in param_dict:
        return get_scalar_param(param_dict[DATA_EFFICIENCY], DATA_EFFICIENCY_SEED, DATA_EFFICIENCY_SEED_DEFAULT)
    else:
        return DATA_EFFICIENCY_SEED_DEFAULT


def get_data_sampling(param_dict):
    """Parse the data sampling sub-config, including curriculum learning and dynamic batching."""
    sub_param_dict = param_dict.get(DATA_SAMPLING, {})
    # Shallow-copy so user-provided keys are preserved alongside the parsed ones.
    output = copy.copy(sub_param_dict)
    output[DATA_SAMPLING_ENABLED] = get_data_sampling_enabled(param_dict)
    output[DATA_SAMPLING_NUM_EPOCHS] = get_data_sampling_num_epochs(param_dict)
    output[DATA_SAMPLING_NUM_WORKERS] = get_data_sampling_num_workers(param_dict)
    output[DATA_SAMPLING_PIN_MEMORY] = get_data_sampling_pin_memory(param_dict)
    output[CURRICULUM_LEARNING] = get_curriculum_learning(sub_param_dict)
    output[DYNAMIC_BATCHING] = get_dynamic_batching(sub_param_dict)
    return output


def get_data_sampling_enabled(param_dict):
    if DATA_SAMPLING in param_dict:
        return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_ENABLED, DATA_SAMPLING_ENABLED_DEFAULT)
    else:
        return False


def get_data_sampling_num_epochs(param_dict):
    if DATA_SAMPLING in param_dict:
        return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_NUM_EPOCHS, DATA_SAMPLING_NUM_EPOCHS_DEFAULT)
    else:
        return DATA_SAMPLING_NUM_EPOCHS_DEFAULT


def get_data_sampling_num_workers(param_dict):
    if DATA_SAMPLING in param_dict:
        return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_NUM_WORKERS, DATA_SAMPLING_NUM_WORKERS_DEFAULT)
    else:
        return DATA_SAMPLING_NUM_WORKERS_DEFAULT


def get_data_sampling_pin_memory(param_dict):
    if DATA_SAMPLING in param_dict:
        return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_PIN_MEMORY, DATA_SAMPLING_PIN_MEMORY_DEFAULT)
    else:
        return DATA_SAMPLING_PIN_MEMORY_DEFAULT


def get_curriculum_learning(param_dict):
    """Parse the curriculum learning sub-config; metrics must be given when enabled."""
    output = {}
    output[CURRICULUM_LEARNING_ENABLED] = get_curriculum_learning_enabled(param_dict)
    if CURRICULUM_LEARNING not in param_dict:
        param_dict[CURRICULUM_LEARNING] = {}
    sub_param_dict = param_dict[CURRICULUM_LEARNING]
    if output[CURRICULUM_LEARNING_ENABLED]:
        assert CURRICULUM_LEARNING_METRICS in sub_param_dict, \
            f"Curriculum learning is enabled, so {CURRICULUM_LEARNING_METRICS} must be specified"
        output.update(get_curriculum_learning_params(param_dict))
    return output
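# Illustrative dynamic batching sub-config parsed below (literal key spellings
# are assumptions inferred from the constant names):
#
#   "dynamic_batching": {
#       "enabled": True,
#       "lr_scaling_method": "linear",
#       "min_batch_size": 1,
#       "max_batch_size": 100,
#       "sequence_picking_order": "dataloader",
#       "max_tokens": 2048,
#       "verbose": False
#   }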
def get_dynamic_batching(param_dict):
    """Parse the dynamic batching sub-config; max_tokens must be given when enabled."""
    output = copy.copy(param_dict.get(DYNAMIC_BATCHING, {}))
    output[DYNAMIC_BATCHING_ENABLED] = bool(output.get(DYNAMIC_BATCHING_ENABLED, DYNAMIC_BATCHING_ENABLED_DEFAULT))
    output[DYNAMIC_BATCHING_LR_SCALING_METHOD] = str(
        output.get(DYNAMIC_BATCHING_LR_SCALING_METHOD, DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT))
    output[DYNAMIC_BATCHING_MIN_BATCH_SIZE] = int(
        output.get(DYNAMIC_BATCHING_MIN_BATCH_SIZE, DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT))
    # Only cast max_batch_size when the user provided one: the default may not
    # be castable to int (e.g. None, meaning "no upper bound").
    if DYNAMIC_BATCHING_MAX_BATCH_SIZE in output:
        output[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = int(output[DYNAMIC_BATCHING_MAX_BATCH_SIZE])
    else:
        output[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT
    output[DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER] = str(
        output.get(DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER, DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT))
    if output[DYNAMIC_BATCHING_ENABLED]:
        assert DYNAMIC_BATCHING_MAX_TOKENS in output, \
            f"Dynamic batching is enabled, so {DYNAMIC_BATCHING_MAX_TOKENS} must be specified"
        output[DYNAMIC_BATCHING_MAX_TOKENS] = int(output[DYNAMIC_BATCHING_MAX_TOKENS])
    output[DYNAMIC_BATCHING_VERBOSE] = bool(output.get(DYNAMIC_BATCHING_VERBOSE, False))
    return output


def get_curriculum_learning_enabled(param_dict):
    if CURRICULUM_LEARNING in param_dict:
        return get_scalar_param(param_dict[CURRICULUM_LEARNING], CURRICULUM_LEARNING_ENABLED,
                                CURRICULUM_LEARNING_ENABLED_DEFAULT)
    else:
        return False


def get_curriculum_learning_params(param_dict):
    """Return the curriculum learning params with the enabled flag stripped out."""
    if CURRICULUM_LEARNING in param_dict:
        curriculum_learning_params = copy.copy(param_dict[CURRICULUM_LEARNING])
        # Use a default so a missing enabled key does not raise KeyError.
        curriculum_learning_params.pop(CURRICULUM_LEARNING_ENABLED, None)
        return curriculum_learning_params
    else:
        return {}


def get_curriculum_enabled_legacy(param_dict):
    if CURRICULUM_LEARNING_LEGACY in param_dict:
        return get_scalar_param(param_dict[CURRICULUM_LEARNING_LEGACY], CURRICULUM_ENABLED_LEGACY,
                                CURRICULUM_ENABLED_DEFAULT_LEGACY)
    else:
        return False


def get_curriculum_params_legacy(param_dict):
    """Return the legacy curriculum params with the enabled flag stripped out."""
    if CURRICULUM_LEARNING_LEGACY in param_dict:
        curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING_LEGACY])
        curriculum_params.pop(CURRICULUM_ENABLED_LEGACY, None)
        return curriculum_params
    else:
        # Mirror get_curriculum_learning_params: an empty dict, not False.
        return {}


def get_data_routing(param_dict):
    """Parse the data routing sub-config (currently just random-LTD)."""
    output = {}
    output[DATA_ROUTING_ENABLED] = get_data_routing_enabled(param_dict)
    if DATA_ROUTING not in param_dict:
        param_dict[DATA_ROUTING] = {}
    sub_param_dict = param_dict[DATA_ROUTING]
    output[RANDOM_LTD] = get_random_ltd(sub_param_dict)
    return output


def get_data_routing_enabled(param_dict):
    if DATA_ROUTING in param_dict:
        return get_scalar_param(param_dict[DATA_ROUTING], DATA_ROUTING_ENABLED, DATA_ROUTING_ENABLED_DEFAULT)
    else:
        return False
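# Random-LTD (random layerwise token dropping) skips a random subset of tokens
# at selected transformer layers during training. The helpers below parse its
# sub-config, including the layer-token learning rate schedule.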
def get_random_ltd(param_dict):
    """Parse the random-LTD sub-config, starting from schedule defaults."""
    output = {}
    output[RANDOM_LTD_ENABLED] = RANDOM_LTD_ENABLED_DEFAULT
    output[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE] = {
        RANDOM_LTD_LAYER_TOKEN_LR_ENABLED: RANDOM_LTD_LAYER_TOKEN_LR_ENABLED_DEFAULT
    }
    # Resolve the enabled flag once; it is reused below.
    random_ltd_enabled = get_random_ltd_enabled(param_dict)
    if random_ltd_enabled:
        output[RANDOM_LTD_ENABLED] = random_ltd_enabled
        output.update(get_random_ltd_params(param_dict))
    return output


def get_random_ltd_enabled(param_dict):
    if RANDOM_LTD in param_dict:
        return get_scalar_param(param_dict[RANDOM_LTD], RANDOM_LTD_ENABLED, RANDOM_LTD_ENABLED_DEFAULT)
    else:
        return False


def get_random_ltd_params(param_dict):
    """Return the random-LTD params with the enabled flag stripped out."""
    if RANDOM_LTD in param_dict:
        random_ltd_params = copy.copy(param_dict[RANDOM_LTD])
        random_ltd_params.pop(RANDOM_LTD_ENABLED, None)
        return random_ltd_params
    else:
        return {}
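# Minimal usage sketch (hypothetical values; assumes the key strings resolved
# by .constants match the snake_case names shown):
#
#   ds_config = {"data_efficiency": {"enabled": True, "seed": 42}}
#   de_config = get_data_efficiency_config(ds_config)
#   de_config[DATA_EFFICIENCY_ENABLED]   # -> True
#   de_config[DATA_SAMPLING]             # sampling sub-config filled with defaults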