import os |
|
import re |
|
import stat |
|
import torch |
|
import hashlib |
|
from collections import defaultdict, OrderedDict, deque |
|
from shutil import copyfile |
|
import gc |
|
|
|
from torch.nn.modules import Module |
|
from torch.nn.parameter import Parameter |
|
from torch.optim import Optimizer |
|
from torch.optim.lr_scheduler import _LRScheduler |
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors |
|
from contextlib import contextmanager |
|
|
|
from typing import Callable, Dict, Union, Iterable, Container |
|
|
|
import deepspeed |
|
|
|
from deepspeed import comm as dist |
|
from deepspeed.runtime.utils import see_memory_usage, DummyOptim |
|
from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum |
|
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer |
|
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus |
|
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException |
|
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload |
|
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION |
|
|
|
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer |
|
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer |
|
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer |
|
|
|
from deepspeed.linear.optimized_linear import LoRAOptimizedLinear |
|
from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime |
|
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ |
|
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ |
|
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ |
|
MUSGD_OPTIMIZER, LION_OPTIMIZER |
|
|
|
from deepspeed.runtime.model_checkpointing.constants import ValidationMode, \ |
|
CHECKPOINT_TAG_VALIDATION, CHECKPOINT_WRITER, CHECKPOINT_SERIALIZATION |
|
|
|
from deepspeed.runtime.dataloader import DeepSpeedDataLoader |
|
from deepspeed.runtime.constants import \ |
|
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ |
|
PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \ |
|
DATA_PARALLEL_GROUP, GLOBAL_RANK |
|
from deepspeed.runtime.zero.config import ZeroStageEnum |
|
from deepspeed.compression import compression_scheduler |
|
from deepspeed.compression.constants import \ |
|
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \ |
|
WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \ |
|
WEIGHT_QUANTIZE_ENABLED, \ |
|
WEIGHT_QUANTIZE_GROUPS, \ |
|
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \ |
|
WEIGHT_QUANTIZE_CHANGE_RATIO, \ |
|
WEIGHT_QUANTIZE_TYPE, \ |
|
WEIGHT_QUANTIZE_ROUNDING, \ |
|
WEIGHT_QUANTIZE_VERBOSE, \ |
|
WEIGHT_QUANTIZE_KERNEL |
|
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS |
|
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save |
|
from deepspeed.runtime.sparse_tensor import SparseTensor |
|
|
|
from deepspeed.runtime import lr_schedules |
|
from deepspeed.utils import groups |
|
from deepspeed.utils import logger, log_dist, instrument_w_nvtx |
|
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \ |
|
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \ |
|
STEP_MICRO_TIMER, \ |
|
FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \ |
|
STEP_GLOBAL_TIMER |
|
from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names |
|
from deepspeed.monitor.monitor import MonitorMaster |
|
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop |
|
from deepspeed.runtime.utils import clip_grad_norm_, compare_tensors_in_structures |
|
from deepspeed.runtime.eigenvalue import Eigenvalue |
|
from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \ |
|
DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \ |
|
CURRICULUM_LEARNING_ENABLED, DATA_SAMPLING_NUM_WORKERS, RANDOM_LTD, \ |
|
RANDOM_LTD_ENABLED, RANDOM_LTD_LAYER_ID, RANDOM_LTD_LAYER_NUM, \ |
|
RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \ |
|
RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY |
|
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler |
|
from deepspeed.runtime.checkpoint_engine import (create_checkpoint_engine, TorchCheckpointEngine, CheckpointCommitInfo) |
|
|
|
from deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler |
|
from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict |
|
from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop |
|
|
|
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint |
|
|
|
from .pipe.module import PipelineModule |
|
from .utils import get_ma_status |
|
from .compiler import is_compile_supported |
|
from ..ops.adam import FusedAdam |
|
from ..moe.sharded_moe import TopKGate, MOELayer |
|
from ..moe.layer import MoE |
|
from ..moe.utils import is_moe_param, configure_moe_param_groups |
|
from ..git_version_info import version |
|
|
|
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler |
|
from deepspeed.utils.logging import print_json_dist, print_configuration |
|
|
|
from deepspeed.accelerator import get_accelerator |
|
|
|
from deepspeed.runtime.config import DtypeEnum |
|
|
|
from deepspeed.compile.util import is_deepcompile_supported, get_deepcompile_handle, deepcompile_backward_prologue |
|
from deepspeed.compile.backend import register_compile_pass, opt_passes |
|
from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states |
|
from deepspeed.compile.init_z1 import init_z1 |
|
from deepspeed.compile.init_z3 import init_z3 |
|
|
|
MEMORY_OPT_ALLREDUCE_SIZE = 500000000 |
|
|
|
DeepSpeedOptimizerCallable = \ |
|
Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer] |
|
DeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler] |
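# Illustrative sketch (hypothetical client code, not part of this module): instead of an
# instantiated optimizer, deepspeed.initialize() may be given a DeepSpeedOptimizerCallable, e.g.
#   def make_optimizer(params):
#       return torch.optim.AdamW(params, lr=1e-4)
# and, analogously, a DeepSpeedSchedulerCallable maps the resulting Optimizer to an
# _LRScheduler instance.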
|
|
|
try: |
|
import apex |
|
from apex import amp |
|
APEX_INSTALLED = True |
|
except ImportError: |
|
|
|
APEX_INSTALLED = False |
|
|
|
|
|
def split_half_float_double_sparse(tensors): |
|
device_type = get_accelerator().device_name() |
|
supported_types = get_accelerator().supported_dtypes() |
|
|
|
for t in tensors: |
|
assert t.dtype in supported_types, f"attempting to reduce an unsupported grad type: {t.dtype}" |
|
|
|
sparse_tensor_buckets, dense_tensor_buckets = [], [] |
|
for i, dtype in enumerate(supported_types): |
|
sparse_bucket, dense_bucket = [], [] |
|
for t in tensors: |
|
if t.dtype == dtype: |
|
if isinstance(t, SparseTensor): |
|
sparse_bucket.append(t) |
|
else: |
|
dense_bucket.append(t) |
|
if sparse_bucket: |
|
sparse_tensor_buckets.append((dtype, sparse_bucket)) |
|
if dense_bucket: |
|
dense_tensor_buckets.append((dtype, dense_bucket)) |
|
return sparse_tensor_buckets, dense_tensor_buckets |
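# Illustrative example (assumed inputs, not executed): given two dense fp16 gradients and one
# fp16 SparseTensor gradient, split_half_float_double_sparse returns
#   sparse_tensor_buckets == [(torch.float16, [sparse_grad])]
#   dense_tensor_buckets  == [(torch.float16, [dense_grad_0, dense_grad_1])]
# so each bucket can later be reduced with a dtype- and layout-appropriate collective.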
|
|
|
|
|
class EngineTimers(object): |
|
r"""Wallclock timers for DeepSpeedEngine""" |
|
|
|
def __init__(self, enable_micro_timers, enable_global_timers): |
|
self.forward_timers = [] |
|
self.backward_timers = [] |
|
self.backward_inner_timers = [] |
|
self.backward_reduce_timers = [] |
|
self.step_timers = [] |
|
self.global_timers = [] |
|
self.micro_timers = [] |
|
|
|
if enable_micro_timers: |
|
self.forward_timers += [FORWARD_MICRO_TIMER] |
|
self.backward_timers += [BACKWARD_MICRO_TIMER] |
|
self.backward_inner_timers += [BACKWARD_INNER_MICRO_TIMER] |
|
self.backward_reduce_timers += [BACKWARD_REDUCE_MICRO_TIMER] |
|
self.step_timers += [STEP_MICRO_TIMER] |
|
self.micro_timers += [ |
|
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, |
|
STEP_MICRO_TIMER |
|
] |
|
|
|
if enable_global_timers: |
|
self.forward_timers += [FORWARD_GLOBAL_TIMER] |
|
self.backward_timers += [BACKWARD_GLOBAL_TIMER] |
|
self.backward_inner_timers += [BACKWARD_INNER_GLOBAL_TIMER] |
|
self.backward_reduce_timers += [BACKWARD_REDUCE_GLOBAL_TIMER] |
|
self.step_timers += [STEP_GLOBAL_TIMER] |
|
self.global_timers += [ |
|
FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, |
|
STEP_GLOBAL_TIMER |
|
] |
|
|
|
|
|
class DeepSpeedEngine(Module): |
|
r"""DeepSpeed engine for training.""" |
|
|
|
def __init__(self, |
|
args, |
|
model, |
|
optimizer=None, |
|
model_parameters=None, |
|
training_data=None, |
|
lr_scheduler=None, |
|
mpu=None, |
|
dist_init_required=None, |
|
collate_fn=None, |
|
config=None, |
|
config_class=None, |
|
mesh_device=None, |
|
dont_change_device=False): |
|
super(DeepSpeedEngine, self).__init__() |
|
self.dont_change_device = dont_change_device |
|
self.client_optimizer = optimizer |
|
self.client_lr_scheduler = lr_scheduler |
|
self.training_data = training_data |
|
self.collate_fn = collate_fn |
|
self.mpu = mpu |
|
self.all_to_all_group = None |
|
self.data_parallel_group = None |
|
self.global_steps = 0 |
|
self.global_samples = 0 |
|
self.micro_steps = 0 |
|
self.skipped_steps = 0 |
|
self.gradient_average = True |
|
self.warn_unscaled_loss = True |
|
self.config = config |
|
self._config = config_class |
|
self.loaded_checkpoint_mp_world_size = None |
|
self.loaded_checkpoint_dp_world_size = None |
|
self.enable_backward_allreduce = True |
|
self.inside_no_sync_ctxt = False |
|
self.progressive_layer_drop = None |
|
self.eigenvalue = None |
|
self.block_eigenvalue = None |
|
self.gas_boundary_ctr = 0 |
|
self.dist_backend = get_accelerator().communication_backend_name() |
|
self.has_moe_layers = False |
|
self.num_experts = [] |
|
self.gate_modules = [] |
|
self.moe_layers = [] |
|
self._step_applied = False |
|
self._global_grad_norm = None |
|
self.use_ds_comm = False |
|
self.checkpoint_engine = None |
|
|
|
self._is_gradient_accumulation_boundary = None |
|
self.scale_wrt_gas = None |
|
self.losses = None |
|
self.mesh_device = mesh_device |
|
|
|
|
|
debug_extract_module_and_param_names(model) |
|
|
|
if self.mesh_device: |
|
groups.mesh_device = self.mesh_device |
|
|
|
self._do_args_sanity_check(args) |
|
self._configure_with_arguments(args, mpu) |
|
self._do_sanity_check() |
|
if self.autotp_size() > 1: |
|
self._configure_tensor_parallel(model, self.tensor_parallel_config()) |
|
see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown()) |
|
        if mpu is not None and self.elasticity_enabled():
            assert self.is_elastic_model_parallel_supported(), \
                "Elasticity is not currently supported with model parallelism."
|
|
|
self._set_distributed_vars(args) |
|
|
|
dist.configure(self._config) |
|
|
|
self.monitor = MonitorMaster(self._config.monitor_config) |
|
|
|
see_memory_usage( |
|
f"DeepSpeed Engine: Before configure distributed model", |
|
force=self.memory_breakdown(), |
|
) |
|
|
|
self.pipeline_parallelism = isinstance(model, PipelineModule) |
|
|
|
|
|
self._configure_distributed_model(model) |
|
|
|
if not self.is_deepcompile_enabled(): |
|
self.module_forward_pre_hook = self._create_module_forward_pre_hook() |
|
self.module_forward_post_hook = self._create_module_forward_post_hook() |
|
|
|
|
|
self.param_names = {param: name for name, param in model.named_parameters()} |
|
|
|
self._get_model_parameters() |
|
|
|
see_memory_usage(f"DeepSpeed Engine: After configure distributed model") |
|
|
|
|
|
self.timers = SynchronizedWallClockTimer() |
|
|
|
self.tput_timer = ThroughputTimer(self._config.timers_config, |
|
batch_size=self.train_batch_size(), |
|
steps_per_output=self.steps_per_print(), |
|
monitor_memory=False) |
|
|
|
log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0]) |
|
|
|
if self.flops_profiler_enabled(): |
|
self.flops_profiler = FlopsProfiler(self.module, self, self.flops_profiler_recompute_fwd_factor()) |
|
|
|
if training_data: |
|
self.training_dataloader = self.deepspeed_io(training_data) |
|
else: |
|
self.training_dataloader = None |
|
|
|
|
|
self.optimizer = None |
|
self.basic_optimizer = None |
|
self.lr_scheduler = None |
|
has_optimizer = False |
|
|
|
if optimizer or self.optimizer_name(): |
|
has_optimizer = True |
|
|
|
if model_parameters is None: |
|
model_parameters = self.module.parameters() |
|
|
|
|
|
if not isinstance(model_parameters, list): |
|
model_parameters = list(model_parameters) |
|
|
|
if has_optimizer: |
|
self._configure_optimizer(optimizer, model_parameters) |
|
self._configure_lr_scheduler() |
|
self._report_progress(0) |
|
elif self.zero_optimization(): |
|
|
|
self.optimizer = self._configure_zero_optimizer(optimizer=None) |
|
elif self.bfloat16_enabled(): |
|
self.optimizer = self._configure_bf16_optimizer(optimizer=None) |
|
|
|
|
|
if hasattr(model, 'pruners'): |
|
from ..compression.helper import rewrite_optimizer_step |
|
self.optimizer.pruners = model.pruners |
|
rewrite_optimizer_step(self.optimizer) |
|
|
|
|
|
self.sparse_tensor_module_names = set() |
|
|
|
for name, module in self.module.named_modules(): |
|
if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled(): |
|
self.sparse_tensor_module_names.add(name + ".weight") |
|
logger.info("Will convert {} to sparse tensor during training".format(name)) |
|
|
|
self._optimized_linear_offload_setup() |
|
|
|
self.save_non_zero_checkpoint = False |
|
self.save_zero_checkpoint = False |
|
if not isinstance(self.optimizer, DeepSpeedZeRoOffload): |
|
self._configure_checkpointing() |
|
|
|
if self.eigenvalue_enabled(): |
|
self.eigenvalue = self._configure_eigenvalue() |
|
|
|
if self.pld_enabled(): |
|
self.progressive_layer_drop = self._configure_progressive_layer_drop() |
|
|
|
if self.curriculum_enabled_legacy(): |
|
self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy() |
|
|
|
if self.random_ltd_enabled(): |
|
random_ltd_config = self.random_ltd_config() |
|
random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size() |
|
random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu() |
|
self.random_ltd_scheduler = self._configure_random_ltd_scheduler(random_ltd_config) |
|
|
|
|
|
|
|
self.engine_timers = EngineTimers(enable_micro_timers=self.wall_clock_breakdown(), |
|
enable_global_timers=self.wall_clock_breakdown() |
|
or self.flops_profiler_enabled()) |
|
|
|
if self.global_rank == 0: |
|
self._config.print("DeepSpeedEngine configuration") |
|
if self.dump_state(): |
|
print_configuration(self, "DeepSpeedEngine") |
|
|
|
|
|
self.flatten = _flatten_dense_tensors |
|
self.unflatten = _unflatten_dense_tensors |
|
|
|
self._is_compiled = False |
|
if is_deepcompile_supported(): |
|
|
|
self.register_compile_pass(zero3_compile.NAME, zero3_compile.add_z3_gather_release) |
|
self.register_compile_pass(prefetch.NAME, prefetch.schedule_prefetch) |
|
self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather) |
|
self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states) |
|
|
|
def _optimized_linear_offload_setup(self): |
|
self.optimized_linear_base_weight_sharding = False |
|
self.optimized_linear_lora_enabled = False |
|
offload_ratio = None |
|
for _, module in self.module.named_modules(): |
|
if isinstance(module, LoRAOptimizedLinear): |
|
self.optimized_linear_lora_enabled = True |
|
|
if offload_ratio is not None: |
|
assert offload_ratio == module.lora_config.offload_ratio, \ |
|
"all lora_config offload ratios should be the same across the model" |
|
offload_ratio = module.lora_config.offload_ratio |
|
if module.zero_shards > 1: |
|
|
|
self.optimized_linear_base_weight_sharding = True |
|
|
|
if offload_ratio is None: |
|
|
|
return |
|
|
|
total_params = 0 |
|
for _, p in self.module.named_parameters(): |
|
if hasattr(p, 'ds_optim_param'): |
|
total_params += p.numel() |
|
|
|
offload_limit = total_params * offload_ratio |
|
logger.info(f'offloading {offload_ratio*100}% of eligible params, specifically {offload_limit} params') |
|
total_offloaded = 0 |
|
for _, p in self.module.named_parameters(): |
|
if hasattr(p, 'ds_optim_param'): |
|
if total_offloaded < offload_limit: |
|
total_offloaded += p.numel() |
|
p.ds_offload = True |
|
p.offload() |
|
else: |
|
p.ds_offload = False |
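        # Worked example (illustrative numbers): with offload_ratio=0.5 and 10M eligible
        # (ds_optim_param) parameters, offload_limit is 5M; parameters are walked in
        # named_parameters() order and offloaded until the running numel total reaches 5M,
        # after which the remaining eligible parameters stay on the accelerator.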
|
|
|
def _configure_tensor_parallel(self, model, tp_config): |
|
self._configure_tensor_parallel_states(model) |
|
configure_tensor_parallel_runtime(tp_config) |
|
|
|
def _configure_tensor_parallel_states(self, model): |
|
""" |
|
Configures the tensor parallel states for the model. |
|
This includes setting up the tensor parallel groups, initializing the TP mesh, |
|
and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks. |
|
""" |
|
self._set_client_model(model) |
|
|
|
|
|
assert self.zero_optimization_stage( |
|
) <= 2, "Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated" |
|
|
|
self.mpu = groups |
|
self.mpu._init_tp_mesh_device(tensor_model_parallel_size=self.autotp_size()) |
|
|
|
self.first_dataloader_check = None |
|
|
|
def check_dataloader_inputs_same_across_ranks(module, args, kwargs): |
|
|
|
def broadcast_and_check(args, bcast_rank, bcast_group): |
|
if isinstance(args, tuple): |
|
args = list(args) |
|
if len(args) > 0: |
|
if self.mpu.get_tensor_model_parallel_rank() == 0: |
|
_src_args = [args] |
|
dist.broadcast_object_list(object_list=_src_args, |
|
src=bcast_rank, |
|
group=bcast_group, |
|
device=get_accelerator().current_device()) |
|
|
|
is_equal = True |
|
else: |
|
_src_args = [None] |
|
dist.broadcast_object_list(object_list=_src_args, |
|
src=bcast_rank, |
|
group=bcast_group, |
|
device=get_accelerator().current_device()) |
|
|
|
is_equal = compare_tensors_in_structures(args, _src_args[0]) |
|
|
|
equal_tensor = torch.tensor(is_equal, |
|
dtype=self.communication_data_type, |
|
device=get_accelerator().current_device()) |
|
dist.all_reduce(equal_tensor, group=bcast_group) |
|
assert torch.equal( |
|
equal_tensor, |
|
torch.tensor(groups.get_tensor_model_parallel_world_size(), |
|
dtype=self.communication_data_type, |
|
device=get_accelerator().current_device()) |
|
), "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency." |
|
|
|
bcast_rank = self.mpu.get_tensor_model_parallel_src_rank() |
|
bcast_group = self.mpu.get_tensor_model_parallel_group() |
|
|
|
broadcast_and_check(args, bcast_rank, bcast_group) |
|
broadcast_and_check(kwargs, bcast_rank, bcast_group) |
|
|
|
logger.info(f":The Dataloader has passed the TP group consistency check.") |
|
self.first_dataloader_check.remove() |
|
|
|
self.first_dataloader_check = self.module.register_forward_pre_hook(check_dataloader_inputs_same_across_ranks, |
|
prepend=True, |
|
with_kwargs=True) |
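        # How the check works (illustrative): rank 0 of each TP group broadcasts its first batch;
        # every other rank compares the received structure against its own and contributes
        # 1 (equal) or 0 (different) to an all-reduce. With a TP world size of 4 the reduced
        # value must equal 4, otherwise the assertion above fires. The hook then removes itself,
        # so only the first batch is validated.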
|
|
|
def __del__(self): |
|
self.destroy() |
|
|
|
def destroy(self): |
|
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): |
|
self.optimizer.destroy() |
|
if self.is_deepcompile_enabled(): |
|
get_deepcompile_handle().cleanup() |
|
debug_clear_module_and_param_names() |
|
|
|
if self.checkpoint_engine is not None and self.checkpoint_engine.is_decoupled(): |
|
self.checkpoint_engine.cleanup() |
|
|
|
def _get_model_parameters(self): |
|
if self.autotuning_profile_model_info(): |
|
self.autotuning_model_info = {} |
|
num_params = 0 |
|
trainable_num_params = 0 |
|
|
|
for p in self.module.parameters(): |
|
|
|
n = 0 |
|
if hasattr(p, "ds_tensor"): |
|
n += p.ds_numel |
|
else: |
|
n += p.numel() |
|
num_params += n |
|
if p.requires_grad: |
|
trainable_num_params += n |
|
if self.global_rank == 0: |
|
self.autotuning_model_info["num_params"] = num_params * self.mp_world_size |
|
self.autotuning_model_info["trainable_num_params"] = trainable_num_params * self.mp_world_size |
|
|
|
logger.info(f"model parameter = {num_params}") |
|
|
|
def get_batch_info(self): |
|
"""Get all training batch related settings. |
|
Returns: |
|
            train_batch_size (int): The effective training batch size. This is the number of
                data samples processed per model update (optimizer step).
|
train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one |
|
step (without gradient accumulation). |
|
gradient_accumulation_steps (int): Number of training steps to accumulate gradients |
|
before averaging and applying them. |
|
""" |
|
        return (
            self.train_batch_size(),
            self.train_micro_batch_size_per_gpu(),
            self.gradient_accumulation_steps(),
        )
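    # Usage sketch (hypothetical engine instance; batch settings are returned as ints per the
    # docstring above):
    #   train_bs, micro_bs, gas = engine.get_batch_info()
    #   assert train_bs == micro_bs * gas * engine.dp_world_size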
|
|
|
def set_train_batch_size(self, train_batch_size): |
|
"""Adjust the global batch size by increasing or decreasing the number of |
|
micro-batches (i.e., gradient accumulation steps). The size of each micro-batch |
|
(i.e., ``train_micro_batch_size_per_gpu``) is not changed. |
|
Args: |
|
train_batch_size (int): The new global batch size for training. |
|
Raises: |
|
ValueError: if ``train_batch_size`` is not divisible by the |
|
configured micro-batch size and data parallelism. |
|
""" |
|
        if train_batch_size % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) != 0:
            raise ValueError(f'Train batch size {train_batch_size} must be divisible by '
                             f'micro-batch size ({self.train_micro_batch_size_per_gpu()}) x '
                             f'data parallel size ({self.dp_world_size})')
|
new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * self.dp_world_size) |
|
|
|
self._config.train_batch_size = train_batch_size |
|
self._config.gradient_accumulation_steps = new_gas |
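        # Worked example (illustrative numbers): with train_micro_batch_size_per_gpu=2 and
        # dp_world_size=4, set_train_batch_size(32) sets gradient_accumulation_steps to
        # 32 // (2 * 4) = 4 while leaving the micro-batch size unchanged.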
|
|
|
def set_train_micro_batch_size(self, micro_batch_size): |
|
"""Adjust the micro batch size(i.e., the micro batch size in every data parallel group), |
|
while keep the gradient accumulation steps the same. |
|
Args: |
|
micro_batch_size (int): The new micro batch size for training. |
|
""" |
|
|
|
new_global_batch_size = micro_batch_size * self._config.gradient_accumulation_steps * self.dp_world_size |
|
self._config.train_batch_size = new_global_batch_size |
|
self._config.train_micro_batch_size_per_gpu = micro_batch_size |
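        # Worked example (illustrative numbers): with gradient_accumulation_steps=4 and
        # dp_world_size=4, set_train_micro_batch_size(8) updates the global batch size to
        # 8 * 4 * 4 = 128.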
|
|
|
def set_data_post_process_func(self, post_process_func): |
|
if self.training_dataloader is not None: |
|
self.training_dataloader.post_process_func = post_process_func |
|
|
|
def set_custom_curriculum_learning_schedule(self, schedule_func_dict): |
|
if self.training_dataloader is not None and self.curriculum_learning_enabled(): |
|
self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict) |
|
|
|
def get_global_grad_norm(self) -> float: |
|
"""Return the 2-norm of all gradients. If there is model parallelism, |
|
the norm will be global. |
|
The computed norm will be cached and reused until the next step() pass. |
|
.. note:: |
|
In the presence of model parallelism, this is a collective call |
|
and acts as a barrier among ``mpu.get_model_parallel_group()``. |
|
Returns: |
|
float: norm |
|
""" |
|
return self._global_grad_norm |
|
|
|
def __getattr__(self, name): |
|
""" |
|
Pass through attributes defined in the model if they are not overridden by ds-engine. |
|
""" |
|
|
|
_module = {} |
|
if "module" in self.__dict__: |
|
_module = self.__dict__['module'] |
|
if name in dir(self): |
|
return getattr(self, name) |
|
elif name in dir(_module): |
|
return getattr(_module, name) |
|
else: |
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") |
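    # Illustrative behavior (hypothetical model attribute): if the wrapped client model defines
    # `generate`, then `engine.generate(...)` resolves through this __getattr__ to
    # `engine.module.generate(...)`, while attributes defined on the engine itself take priority.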
|
|
|
def checkpoint_serialization_enabled(self): |
|
return self._config.checkpoint_config[CHECKPOINT_SERIALIZATION] |
|
|
|
def checkpoint_writer_enabled(self): |
|
return self._config.checkpoint_config[CHECKPOINT_WRITER] is not None |
|
|
|
def checkpoint_tag_validation_enabled(self): |
|
return self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] != ValidationMode.IGNORE |
|
|
|
def checkpoint_tag_validation_fail(self): |
|
return self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] == ValidationMode.FAIL |
|
|
|
def elasticity_enabled(self): |
|
return self._config.elasticity_enabled |
|
|
|
    def is_elastic_model_parallel_supported(self):
        if self.elasticity_enabled():
            if self._config.num_gpus_per_node % self._config.elastic_model_parallel_size == 0:
                return True
            else:
                return False
        else:
            return False
|
|
|
def pld_enabled(self): |
|
return self._config.pld_enabled |
|
|
|
def pld_params(self): |
|
return self._config.pld_params |
|
|
|
def pld_theta(self): |
|
return self.pld_params()[PLD_THETA] |
|
|
|
def pld_gamma(self): |
|
return self.pld_params()[PLD_GAMMA] |
|
|
|
def eigenvalue_enabled(self): |
|
return self._config.eigenvalue_enabled |
|
|
|
def eigenvalue_verbose(self): |
|
return self._config.eigenvalue_verbose |
|
|
|
def eigenvalue_max_iter(self): |
|
return self._config.eigenvalue_max_iter |
|
|
|
def eigenvalue_tol(self): |
|
return self._config.eigenvalue_tol |
|
|
|
def eigenvalue_stability(self): |
|
return self._config.eigenvalue_stability |
|
|
|
def eigenvalue_gas_boundary_resolution(self): |
|
return self._config.eigenvalue_gas_boundary_resolution |
|
|
|
def eigenvalue_layer_name(self): |
|
return self._config.eigenvalue_layer_name |
|
|
|
def eigenvalue_layer_num(self): |
|
return self._config.eigenvalue_layer_num |
|
|
|
def curriculum_enabled_legacy(self): |
|
return self._config.curriculum_enabled_legacy |
|
|
|
def curriculum_params_legacy(self): |
|
return self._config.curriculum_params_legacy |
|
|
|
def data_efficiency_enabled(self): |
|
return self._config.data_efficiency_enabled |
|
|
|
def data_efficiency_config(self): |
|
return self._config.data_efficiency_config |
|
|
|
def data_sampling_enabled(self): |
|
return self._config.data_efficiency_config[DATA_SAMPLING][DATA_SAMPLING_ENABLED] |
|
|
|
def data_sampling_config(self): |
|
return self._config.data_efficiency_config[DATA_SAMPLING] |
|
|
|
def curriculum_learning_enabled(self): |
|
return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED] |
|
|
|
def curriculum_learning_config(self): |
|
return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING] |
|
|
|
def random_ltd_enabled(self): |
|
return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD][RANDOM_LTD_ENABLED] |
|
|
|
def random_ltd_config(self): |
|
return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD] |
|
|
|
def random_ltd_initialize(self): |
|
assert self.random_ltd_enabled() |
|
random_ltd_config = self.random_ltd_config() |
|
        random_ltd_queue = deque(sorted(random_ltd_config[RANDOM_LTD_LAYER_ID]))
|
count = 0 |
|
for name, layer in self.module.named_modules(): |
|
if isinstance(layer, RandomLayerTokenDrop): |
|
if len(random_ltd_queue) != 0 and str(random_ltd_queue[0]) in name: |
|
layer.init_config(random_ltd_config, self.random_ltd_scheduler, count) |
|
random_ltd_queue.popleft() |
|
count += 1 |
|
|
|
if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count: |
|
            raise ValueError(f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must '
                             f'equal the number of matched random_ltd layers ({count})')
|
|
|
if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]: |
|
assert self.client_lr_scheduler is None |
|
            raise ValueError('The random-LTD layer-token LR schedule is not yet supported')
|
|
|
|
|
def get_sequence_parallel_group(self): |
|
return self.seq_parallel_group |
|
|
|
def wall_clock_breakdown(self): |
|
return self._config.wall_clock_breakdown |
|
|
|
def flops_profiler_enabled(self): |
|
return self._config.flops_profiler_config.enabled or self.autotuning_enabled() |
|
|
|
def flops_profiler_recompute_fwd_factor(self): |
|
return self._config.flops_profiler_config.recompute_fwd_factor |
|
|
|
def flops_profiler_profile_step(self): |
|
step = self._config.flops_profiler_config.profile_step |
|
if self._config.autotuning_config.enabled: |
|
step = self.autotuning_start_profile_step() |
|
return step |
|
|
|
def flops_profiler_module_depth(self): |
|
return self._config.flops_profiler_config.module_depth |
|
|
|
def flops_profiler_top_modules(self): |
|
return self._config.flops_profiler_config.top_modules |
|
|
|
def flops_profiler_detailed(self): |
|
if self._config.autotuning_config.enabled: |
|
return False |
|
return self._config.flops_profiler_config.detailed |
|
|
|
def flops_profiler_output_file(self): |
|
return self._config.flops_profiler_config.output_file |
|
|
|
def memory_breakdown(self): |
|
return self._config.memory_breakdown |
|
|
|
def autotuning_enabled(self): |
|
return self._config.autotuning_config.enabled |
|
|
|
def autotuning_start_profile_step(self): |
|
return self._config.autotuning_config.start_profile_step |
|
|
|
def autotuning_end_profile_step(self): |
|
return self._config.autotuning_config.end_profile_step |
|
|
|
def autotuning_metric_path(self): |
|
path = self._config.autotuning_config.metric_path |
|
if not path: |
|
path = os.path.join(os.getcwd(), "autotuning_metric.json") |
|
return path |
|
|
|
def autotuning_model_info_path(self): |
|
path = self._config.autotuning_config.model_info_path |
|
if not path: |
|
path = os.path.join(os.getcwd(), "autotuning_model_info.json") |
|
return path |
|
|
|
def autotuning_metric(self): |
|
return self._config.autotuning_config.metric |
|
|
|
def autotuning_profile_model_info(self): |
|
return self.autotuning_enabled( |
|
) and self._config.autotuning_config.model_info and self._config.autotuning_config.model_info.get( |
|
"profile", False) |
|
|
|
def sparse_gradients_enabled(self): |
|
return self._config.sparse_gradients_enabled |
|
|
|
def train_batch_size(self): |
|
return self._config.train_batch_size |
|
|
|
def train_micro_batch_size_per_gpu(self): |
|
return self._config.train_micro_batch_size_per_gpu |
|
|
|
def optimizer_name(self): |
|
return (self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name) |
|
|
|
def optimizer_params(self): |
|
return self._config.optimizer_params |
|
|
|
def optimizer_legacy_fusion(self): |
|
return self._config.optimizer_legacy_fusion |
|
|
|
def scheduler_name(self): |
|
return self._config.scheduler_name |
|
|
|
def scheduler_params(self): |
|
return self._config.scheduler_params |
|
|
|
def quantize_training(self): |
|
return ( |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] |
|
[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED], |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED], |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_GROUPS], |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] |
|
[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_CHANGE_RATIO], |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_TYPE], |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ROUNDING], |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_VERBOSE], |
|
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_KERNEL], |
|
) |
|
|
|
def zero_optimization(self): |
|
return self._config.zero_enabled |
|
|
|
def zero_allow_untested_optimizer(self): |
|
return self._config.zero_allow_untested_optimizer |
|
|
|
def zero_force_ds_cpu_optimizer(self): |
|
return self._config.zero_force_ds_cpu_optimizer |
|
|
|
def zero_reduce_scatter(self): |
|
return self._config.zero_config.reduce_scatter |
|
|
|
def zero_overlap_comm(self): |
|
return self._config.zero_config.overlap_comm |
|
|
|
def zero_offload_optimizer(self): |
|
return self._config.zero_config.offload_optimizer |
|
|
|
def zero_offload_param(self): |
|
return self._config.zero_config.offload_param |
|
|
|
def zero_use_cpu_optimizer(self): |
|
if self._config.zero_config.offload_optimizer is not None: |
|
return self._config.zero_config.offload_optimizer.device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme] |
|
return False |
|
|
|
def zero_cpu_offload(self): |
|
if self._config.zero_config.offload_optimizer is not None: |
|
return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu |
|
return False |
|
|
|
def zero_partial_offload(self): |
|
return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0) |
|
|
|
def zero_sub_group_size(self): |
|
return self._config.zero_config.sub_group_size |
|
|
|
def zero_optimization_stage(self): |
|
return self._config.zero_optimization_stage |
|
|
|
def mics_shard_size(self): |
|
return self._config.mics_shard_size |
|
|
|
def zero_reduce_bucket_size(self): |
|
return self._config.zero_config.reduce_bucket_size |
|
|
|
def zero_multi_rank_bucket_allreduce(self): |
|
return self._config.zero_config.use_multi_rank_bucket_allreduce |
|
|
|
def zero_allgather_bucket_size(self): |
|
return self._config.zero_config.allgather_bucket_size |
|
|
|
def zero_optimization_partition_gradients(self): |
|
return self.zero_optimization_stage() >= ZeroStageEnum.gradients |
|
|
|
def zero_optimization_partition_weights(self): |
|
return self.zero_optimization_stage() >= ZeroStageEnum.weights |
|
|
|
def is_first_weights_partition_group(self): |
|
ret = True if self.mics_shard_size() < 0 \ |
|
and self.zero_optimization_partition_weights() else False |
|
if self.mics_shard_size() > 0 and self.global_rank < self.mics_shard_size(): |
|
ret = True |
|
return ret |
|
|
|
def zero_contiguous_gradients(self): |
|
return self._config.zero_config.contiguous_gradients |
|
|
|
def zero_load_from_fp32_weights(self): |
|
return self._config.zero_config.load_from_fp32_weights |
|
|
|
def zero_elastic_checkpoint(self): |
|
return self._config.zero_config.elastic_checkpoint |
|
|
|
def zero_nvme_offload_optimizer(self): |
|
return getattr(self.optimizer, "swap_optimizer", False) |
|
|
|
def zero_max_live_parameters(self): |
|
return self._config.zero_config.max_live_parameters |
|
|
|
def zero_max_reuse_distance(self): |
|
return self._config.zero_config.max_reuse_distance |
|
|
|
def zero_prefetch_bucket_size(self): |
|
return self._config.zero_config.prefetch_bucket_size |
|
|
|
def zero_module_granularity_threshold(self): |
|
return self._config.zero_config.module_granularity_threshold |
|
|
|
def zero_param_persistence_threshold(self): |
|
return self._config.zero_config.param_persistence_threshold |
|
|
|
def zero_model_persistence_threshold(self): |
|
return self._config.zero_config.model_persistence_threshold |
|
|
|
def zero_gather_16bit_weights_on_model_save(self): |
|
return self._config.zero_config.gather_16bit_weights_on_model_save |
|
|
|
def zero_grad_hooks(self): |
|
return self._config.zero_config.grad_hooks |
|
|
|
def zero_legacy_stage1(self): |
|
return self._config.zero_config.legacy_stage1 |
|
|
|
def zero_ignore_unused_parameters(self): |
|
return self._config.zero_config.ignore_unused_parameters |
|
|
|
def tensor_parallel_config(self): |
|
return self._config.tensor_parallel_config |
|
|
|
def autotp_size(self): |
|
return self._config.tensor_parallel_config.autotp_size |
|
|
|
def graph_harvesting(self): |
|
return self._config.graph_harvesting |
|
|
|
def fp16_enabled(self): |
|
return self._config.float16_config.enabled |
|
|
|
def bfloat16_enabled(self): |
|
return self._config.bfloat16_config.enabled |
|
|
|
def fp16_master_weights_and_gradients(self): |
|
return self._config.float16_config.fp16_master_weights_and_grads |
|
|
|
def amp_enabled(self): |
|
return self._config.amp_enabled |
|
|
|
def amp_params(self): |
|
return self._config.amp_params |
|
|
|
def fp16_auto_cast(self): |
|
return self._config.float16_config.auto_cast |
|
|
|
def loss_scale(self): |
|
return self._config.float16_config.loss_scale |
|
|
|
def gradient_accumulation_steps(self): |
|
return self._config.gradient_accumulation_steps |
|
|
|
def use_node_local_storage(self): |
|
return self._config.use_node_local_storage |
|
|
|
def load_universal_checkpoint(self): |
|
return self._config.load_universal_checkpoint |
|
|
|
@property |
|
def communication_data_type(self): |
|
res = self._config.communication_data_type |
|
if res is not None: |
|
return res |
|
|
|
if self.fp16_enabled(): |
|
return torch.float16 |
|
|
|
if self.bfloat16_enabled(): |
|
return torch.bfloat16 |
|
|
|
return torch.float32 |
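    # Resolution example (illustrative): with no explicit `communication_data_type` in the config
    # and bf16 enabled, gradients are communicated in torch.bfloat16; an explicit config value
    # always takes precedence over the dtype-based fallback.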
|
|
|
@communication_data_type.setter |
|
def communication_data_type(self, value): |
|
self._config.communication_data_type = value |
|
|
|
def postscale_gradients(self): |
|
return not self._config.prescale_gradients |
|
|
|
def gradient_predivide_factor(self): |
|
return self._config.gradient_predivide_factor |
|
|
|
def steps_per_print(self): |
|
return self._config.steps_per_print |
|
|
|
def zero_allgather_partitions(self): |
|
return self._config.zero_config.allgather_partitions |
|
|
|
def zero_round_robin_gradients(self): |
|
return self._config.zero_config.round_robin_gradients |
|
|
|
def zero_hpz_partition_size(self): |
|
return self._config.zero_config.zero_hpz_partition_size |
|
|
|
def zero_quantized_weights(self): |
|
return self._config.zero_config.zero_quantized_weights |
|
|
|
def zero_quantized_nontrainable_weights(self): |
|
return self._config.zero_config.zero_quantized_nontrainable_weights |
|
|
|
def zero_quantized_gradients(self): |
|
return self._config.zero_config.zero_quantized_gradients |
|
|
|
def zeropp_loco_param(self): |
|
return self._config.zero_config.zeropp_loco_param |
|
|
|
def zero_log_trace_cache_warnings(self): |
|
return self._config.zero_config.log_trace_cache_warnings |
|
|
|
def dump_state(self): |
|
return self._config.dump_state |
|
|
|
def gradient_clipping(self): |
|
return self._config.gradient_clipping |
|
|
|
def dynamic_loss_scale(self): |
|
return self._config.float16_config.loss_scale == 0 |
|
|
|
def initial_dynamic_scale(self): |
|
return self._config.float16_config.initial_dynamic_scale() |
|
|
|
def dynamic_loss_scale_args(self): |
|
return self._config.float16_config.dynamic_loss_scale_args() |
|
|
|
def swap_tensor_config(self): |
|
return self._config.swap_tensor_config |
|
|
|
def aio_config(self): |
|
return self._config.aio_config |
|
|
|
def get_data_types(self): |
|
model_dtype = torch.float32 |
|
if self.fp16_enabled(): |
|
model_dtype = torch.float16 |
|
elif self.bfloat16_enabled(): |
|
model_dtype = torch.bfloat16 |
|
|
|
if self._config.grad_accum_dtype is None: |
|
if model_dtype == torch.bfloat16 and not self.zero_optimization(): |
|
grad_accum_dtype = torch.float32 |
|
else: |
|
grad_accum_dtype = model_dtype |
|
else: |
|
grad_accum_dtype = DtypeEnum(self._config.grad_accum_dtype).value |
|
|
|
return (model_dtype, grad_accum_dtype) |
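    # Resolution examples (illustrative):
    #   fp16 enabled, no grad_accum_dtype            -> (torch.float16, torch.float16)
    #   bf16 enabled, no ZeRO, no grad_accum_dtype   -> (torch.bfloat16, torch.float32)
    #   bf16 enabled with ZeRO, no grad_accum_dtype  -> (torch.bfloat16, torch.bfloat16)
    #   explicit grad_accum_dtype in the config      -> that dtype is used as-is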
|
|
|
def _optimizer_has_ckpt_event_prologue(self): |
|
return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_prologue') |
|
|
|
def _optimizer_has_ckpt_event_epilogue(self): |
|
return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue') |
|
|
|
def _configure_lr_scheduler(self): |
|
if self.client_lr_scheduler: |
|
if isinstance(self.client_lr_scheduler, Callable): |
|
log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0]) |
|
self.lr_scheduler = self.client_lr_scheduler(self.basic_optimizer) |
|
else: |
|
log_dist('DeepSpeed using client LR scheduler', ranks=[0]) |
|
self.lr_scheduler = self.client_lr_scheduler |
|
else: |
|
|
|
lr_scheduler = self._scheduler_from_config(self.optimizer) |
|
log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0]) |
|
self.lr_scheduler = lr_scheduler |
|
|
|
log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) |
|
|
|
def _configure_checkpointing(self): |
|
|
|
optimize_dp_state = not self.zero_optimization_partition_weights() |
|
self.checkpoint_engine = create_checkpoint_engine(config_params=self._config, |
|
groups=groups, |
|
zero_stage=self.zero_optimization_stage(), |
|
has_moe_layers=self.has_moe_layers, |
|
optimize_dp_state=optimize_dp_state) |
|
|
|
dp_rank = groups._get_sequence_data_parallel_rank() |
|
rank = self.local_rank if self.use_node_local_storage() else dp_rank |
|
|
|
|
|
if self.checkpoint_engine.is_data_parallel_writer(rank) \ |
|
or (self.zero_optimization_partition_weights() and self.is_first_weights_partition_group()): |
|
self.save_non_zero_checkpoint = True |
|
|
|
if self.zero_optimization() or self.bfloat16_enabled(): |
|
param_rank = dist.get_rank(group=self.optimizer.dp_process_group) |
|
|
|
|
|
|
|
self.save_zero_checkpoint = param_rank == dp_rank |
|
|
|
def _scheduler_from_config(self, optimizer): |
|
scheduler_name = self.scheduler_name() |
|
if scheduler_name is not None: |
|
if hasattr(lr_schedules, scheduler_name): |
|
scheduler = getattr(lr_schedules, scheduler_name) |
|
else: |
|
assert hasattr(torch.optim.lr_scheduler, |
|
scheduler_name), f"DeepSpeed does not recognize LR scheduler {scheduler_name}" |
|
|
|
scheduler = getattr(torch.optim.lr_scheduler, scheduler_name) |
|
|
|
scheduler_params = self.scheduler_params() |
|
instantiated_scheduler = scheduler(optimizer, **scheduler_params) |
|
return instantiated_scheduler |
|
else: |
|
return None |
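    # Config-driven example (illustrative ds_config fragment, assuming the WarmupLR schedule
    # shipped in deepspeed.runtime.lr_schedules):
    #   "scheduler": {"type": "WarmupLR", "params": {"warmup_min_lr": 0, "warmup_max_lr": 1e-4}}
    # scheduler_name() then resolves to WarmupLR and the params dict is passed as **kwargs;
    # names not found in lr_schedules fall back to torch.optim.lr_scheduler.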
|
|
|
def _set_distributed_vars(self, args): |
|
device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank |
|
if device_rank >= 0: |
|
get_accelerator().set_device(device_rank) |
|
self.device = torch.device(get_accelerator().device_name(device_rank)) |
|
self.world_size = dist.get_world_size() |
|
self.global_rank = dist.get_rank() |
|
else: |
|
self.world_size = 1 |
|
self.global_rank = 0 |
|
self.device = get_accelerator().device() |
|
|
|
|
|
def _configure_with_arguments(self, args, mpu): |
|
|
|
|
|
|
|
|
|
|
|
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: |
|
ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") |
|
local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank) |
|
assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \ |
|
"not sure how to proceed as we're seeing conflicting local rank info." |
|
os.environ['LOCAL_RANK'] = local_rank |
|
|
|
self.local_rank = int(os.environ['LOCAL_RANK']) |
|
if hasattr(args, 'local_rank'): |
|
args.local_rank = self.local_rank |
|
|
|
|
|
def _do_args_sanity_check(self, args): |
|
assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \ |
|
"variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \ |
|
"different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." |
|
|
|
if hasattr(args, 'local_rank') and args.local_rank is not None: |
|
assert isinstance(args.local_rank, |
|
int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" |
|
if args.local_rank >= 0: |
|
env_local_rank = int(os.environ.get("LOCAL_RANK")) |
|
assert ( |
|
env_local_rank == args.local_rank |
|
), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}." |
|
|
|
def _is_supported_optimizer(self, optimizer_name): |
|
return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None) |
|
|
|
def _supported_optims(self): |
|
FairseqOptimizer = None |
|
try: |
|
from fairseq.optim.fairseq_optimizer import FairseqOptimizer |
|
except ImportError: |
|
pass |
|
|
|
expected_optim_types = [Optimizer] |
|
if FairseqOptimizer: |
|
|
|
expected_optim_types.append(FairseqOptimizer) |
|
return expected_optim_types |
|
|
|
|
|
def _do_sanity_check(self): |
|
if self.fp16_enabled() and not get_accelerator().is_fp16_supported(): |
|
raise ValueError("Type fp16 is not supported on your device.") |
|
|
|
if self.bfloat16_enabled() and not get_accelerator().is_bf16_supported(): |
|
raise ValueError("Type bf16 is not supported on your device.") |
|
|
|
expected_optim_types = self._supported_optims() |
|
expected_optim_types += [type(None), Callable] |
|
assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \ |
|
f'Client Optimizer is of unexpected type {type(self.client_optimizer)}' |
|
|
|
if not self.client_optimizer: |
|
if self.optimizer_name() is not None: |
|
assert self._is_supported_optimizer( |
|
self.optimizer_name()), "{} is not a supported DeepSpeed Optimizer".format(self.optimizer_name()) |
|
|
|
if (self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER): |
|
assert (self.dynamic_loss_scale()), "DeepSpeed {} optimizer requires dynamic loss scaling".format( |
|
self.optimizer_name()) |
|
|
|
|
|
if isinstance(self.client_lr_scheduler, _LRScheduler): |
|
assert isinstance(self.client_optimizer, Optimizer), \ |
|
f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated' |
|
|
|
def _broadcast_model(self): |
|
|
|
def is_replicated(p): |
|
if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE: |
|
return False |
|
elif hasattr(p, 'ds_optim_param'): |
|
|
|
return False |
|
return True |
|
|
|
for n, p in self.module.named_parameters(): |
|
|
|
if is_moe_param(p): |
|
if torch.is_tensor(p) and is_replicated(p): |
|
dist.broadcast(p.data, |
|
groups._get_expert_broadcast_src_rank(p.group_name), |
|
group=self.expert_data_parallel_group[p.group_name]) |
|
else: |
|
if torch.is_tensor(p) and is_replicated(p): |
|
dist.broadcast(p.data, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group) |
|
|
|
@staticmethod |
|
def __check_params(model: Module, dtype: torch.dtype) -> None: |
|
        # NOTE: the dtype check below is currently disabled by this early return and is kept
        # for reference; it can be re-enabled if strict dtype enforcement is needed.
        return
|
if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0: |
|
raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is " |
|
f"not {dtype}: " |
|
f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}") |
|
|
|
def _set_client_model(self, model): |
|
|
|
modules = self.__dict__.get('_modules') |
|
modules['module'] = model |
|
|
|
self.__dict__['module'] = model |
|
|
|
def _configure_distributed_model(self, model): |
|
self._set_client_model(model) |
|
is_zero_init_model = self.zero_optimization_partition_weights() and any( |
|
[hasattr(param, "ds_id") for param in self.module.parameters()]) |
|
|
|
if self.fp16_enabled(): |
|
if is_zero_init_model: |
|
self.__check_params(self.module, torch.half) |
|
self.module.half() |
|
elif self.bfloat16_enabled(): |
|
if is_zero_init_model: |
|
self.__check_params(self.module, torch.bfloat16) |
|
self.module.bfloat16() |
|
else: |
|
self.__check_params(self.module, torch.float) |
|
|
|
|
|
if not (self.dont_change_device or is_zero_init_model): |
|
self.module.to(self.device) |
|
|
|
|
|
for _, module in self.module.named_modules(): |
|
if isinstance(module, MoE): |
|
self.has_moe_layers = True |
|
self.num_experts.append(module.num_experts) |
|
|
|
if self.has_moe_layers: |
|
for _, module in self.module.named_modules(): |
|
if isinstance(module, TopKGate): |
|
self.gate_modules.append(module) |
|
if self.wall_clock_breakdown(): |
|
module.wall_clock_breakdown = True |
|
if isinstance(module, MOELayer): |
|
self.moe_layers.append(module) |
|
if self.wall_clock_breakdown(): |
|
module.wall_clock_breakdown = True |
|
|
|
|
|
if self.mpu is not None: |
|
groups.mpu = self.mpu |
|
|
|
|
|
for _, module in self.module.named_modules(): |
|
if hasattr(module, 'set_deepspeed_parallelism'): |
|
module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_) |
|
|
|
|
|
self.local_all_to_all_group = None |
|
if self.zero_quantized_gradients(): |
|
message = "Using LoCo quantized gradients" if self.zeropp_loco_param() else "Using quantized gradients" |
|
log_dist(message, ranks=[0]) |
|
self.local_all_to_all_group = groups._get_local_all_to_all_group() |
|
self.data_parallel_group = groups._get_data_parallel_group() |
|
self.dp_world_size = groups._get_data_parallel_world_size() |
|
self.seq_data_parallel_group = groups._get_sequence_data_parallel_group() |
|
self.seq_dp_world_size = groups._get_sequence_data_parallel_world_size() |
|
self.mp_world_size = groups._get_model_parallel_world_size() |
|
self.expert_parallel_group = groups._get_expert_parallel_group_dict() |
|
self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict() |
|
self.sequence_parallel_size = groups._get_sequence_parallel_world_size() |
|
if self.sequence_parallel_size > 1: |
|
self.communication_data_type = self._config.seq_parallel_communication_data_type |
|
self.seq_parallel_group = groups._get_sequence_parallel_group() |
|
|
|
if dist.get_rank() == 0: |
|
summary = "********** distributed groups summary **********\n" |
|
summary += f"\t {self.dp_world_size=}\n" |
|
summary += f"\t {self.mp_world_size=}\n" |
|
summary += f"\t {self.seq_dp_world_size=}\n" |
|
summary += f"\t {self.sequence_parallel_size=}\n" |
|
summary += "***********************************************" |
|
logger.info(summary) |
|
|
|
if not (self.amp_enabled() or is_zero_init_model): |
|
self._broadcast_model() |
|
|
|
|
|
def _check_for_duplicates(self, optimizer): |
|
for name, param in self.module.named_parameters(): |
|
param_id = id(param) |
|
|
|
def ids_list(group): |
|
return [id(param) for param in group] |
|
|
|
            occurrence = sum(ids_list(group['params']).count(param_id) for group in optimizer.param_groups)
|
assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behavior." |
|
|
|
def _do_optimizer_sanity_check(self, basic_optimizer): |
|
model_dtype, grad_accum_dtype = self.get_data_types() |
|
zero_enabled = self.zero_optimization() |
|
amp_enabled = self.amp_enabled() |
|
|
|
assert ( |
|
not (amp_enabled and zero_enabled) |
|
), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2" |
|
if zero_enabled: |
|
if not is_zero_supported_optimizer(basic_optimizer): |
|
assert ( |
|
self.zero_allow_untested_optimizer() |
|
), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' |
|
|
|
if self.global_rank == 0: |
|
logger.warning("**** You are using ZeRO with an untested optimizer, proceed with caution *****") |
|
if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage( |
|
) == 1 and not self.zero_cpu_offload(): |
|
return BFLOAT16 |
|
return ZERO_OPTIMIZATION |
|
elif amp_enabled: |
|
if model_dtype != grad_accum_dtype: |
|
raise NotImplementedError( |
|
"Model data type and gradient accumulation data type must be equal to use Amp") |
|
if model_dtype == torch.bfloat16 or model_dtype == torch.float16: |
|
raise NotImplementedError("Cannot enable both amp with (legacy) fp16 or bfloat16 mode") |
|
try: |
|
logger.info("Initializing Apex amp from: {}".format(amp.__path__)) |
|
except NameError: |
|
|
|
raise RuntimeError("Unable to import apex/amp, please make sure it is installed") |
|
return AMP |
|
|
|
elif model_dtype == grad_accum_dtype: |
|
if model_dtype == torch.bfloat16: |
|
if self.pipeline_parallelism: |
|
logger.warning( |
|
"**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****" |
|
) |
|
return BFLOAT16 |
|
else: |
|
raise NotImplementedError( |
|
"Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation" |
|
) |
|
if model_dtype == torch.float16: |
|
return FP16 |
|
|
|
elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32: |
|
return BFLOAT16 |
|
else: |
|
raise NotImplementedError("unsupported mix of model dtype and gradient accumulation type") |
|
|
|
return None |
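    # Selection examples (illustrative): ZeRO enabled -> ZERO_OPTIMIZATION (or BFLOAT16 for the
    # bf16 + fp32-grad-accum + stage-1, no-CPU-offload special case); apex amp -> AMP;
    # fp16 model with fp16 grad accumulation -> FP16; bf16 model with fp32 grad accumulation
    # (no ZeRO) -> BFLOAT16; fp32/fp32 -> None, i.e. the basic optimizer is used unwrapped.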
|
|
|
|
|
def _configure_optimizer(self, client_optimizer, model_parameters): |
|
if client_optimizer is None: |
|
if self.has_moe_layers: |
|
model_parameters = configure_moe_param_groups(model_parameters) |
|
basic_optimizer = self._configure_basic_optimizer(model_parameters) |
|
log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0]) |
|
else: |
|
if isinstance(client_optimizer, tuple(self._supported_optims())): |
|
basic_optimizer = client_optimizer |
|
log_dist('Using client Optimizer as basic optimizer', ranks=[0]) |
|
else: |
|
basic_optimizer = client_optimizer(model_parameters) |
|
log_dist('Using client callable to create basic optimizer', ranks=[0]) |
|
|
|
if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam): |
|
if self.zero_force_ds_cpu_optimizer(): |
|
msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.' |
|
raise ZeRORuntimeException(msg) |
|
|
|
basic_optimizer.param_groups[:] = [pg for pg in basic_optimizer.param_groups if len(pg["params"]) != 0] |
|
log_dist("Removing param_group that has no 'params' in the basic Optimizer", ranks=[0]) |
|
|
|
self._check_for_duplicates(basic_optimizer) |
|
|
|
self.basic_optimizer = basic_optimizer |
|
log_dist(f"DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}", ranks=[0]) |
|
|
|
optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer) |
|
|
|
if optimizer_wrapper == ZERO_OPTIMIZATION: |
|
self.optimizer = self._configure_zero_optimizer(basic_optimizer) |
|
elif optimizer_wrapper == AMP: |
|
amp_params = self.amp_params() |
|
log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0]) |
|
model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params) |
|
self._set_client_model(model) |
|
self._broadcast_model() |
|
|
|
elif optimizer_wrapper == FP16: |
|
self.optimizer = self._configure_fp16_optimizer(basic_optimizer) |
|
elif optimizer_wrapper == BFLOAT16: |
|
self.optimizer = self._configure_bf16_optimizer(basic_optimizer) |
|
else: |
|
self.optimizer = basic_optimizer |
|
|
|
log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer.__class__.__name__), ranks=[0]) |
|
|
|
self.compression_scheduler = self._configure_compression_scheduler() |
|
self.quantizer = self._configure_quantization() |
|
|
|
def _configure_basic_optimizer(self, model_parameters): |
|
optimizer_parameters = self.optimizer_params() |
|
if optimizer_parameters is None: |
|
optimizer_parameters = {} |
|
|
|
if "max_grad_norm" in optimizer_parameters.keys(): |
|
raise ValueError( |
|
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" |
|
) |
|
|
|
if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]: |
|
torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False) |
|
adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT) |
|
|
|
|
|
effective_adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode |
|
|
|
if torch_adam: |
|
if not effective_adam_w_mode: |
|
optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) |
|
else: |
|
optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters) |
|
else: |
|
if self.zero_use_cpu_optimizer(): |
|
from deepspeed.ops.adam import DeepSpeedCPUAdam |
|
optimizer = DeepSpeedCPUAdam(model_parameters, |
|
**optimizer_parameters, |
|
adamw_mode=effective_adam_w_mode) |
|
else: |
|
from deepspeed.ops.adam import FusedAdam |
|
|
|
optimizer = FusedAdam( |
|
model_parameters, |
|
**optimizer_parameters, |
|
adam_w_mode=effective_adam_w_mode, |
|
) |
|
|
|
elif self.optimizer_name() == ADAGRAD_OPTIMIZER: |
|
if self.zero_use_cpu_optimizer(): |
|
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad |
|
optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters) |
|
else: |
|
optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters) |
|
elif self.optimizer_name() == LAMB_OPTIMIZER: |
|
from deepspeed.ops.lamb import FusedLamb |
|
|
|
optimizer = FusedLamb(model_parameters, **optimizer_parameters) |
|
elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER: |
|
assert not self.zero_optimization(), "1bit-Adam is not compatible with ZeRO" |
|
from deepspeed.runtime.fp16.onebit.adam import OnebitAdam |
|
|
|
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters) |
|
if not self.fp16_enabled(): |
|
logger.warning(f"Currently the convergence of 1-bit Adam is only verified under FP16") |
|
elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER: |
|
assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO" |
|
from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam |
|
|
|
optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters) |
|
if not self.fp16_enabled(): |
|
logger.warning(f'Currently the convergence of 0/1 Adam is only verified under FP16') |
|
elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER: |
|
assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO" |
|
from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb |
|
|
|
optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters) |
|
if not self.fp16_enabled(): |
|
logger.warning(f"Currently the convergence of 1-bit Lamb is only verified under FP16") |
|
elif self.optimizer_name() == LION_OPTIMIZER: |
|
if self.zero_use_cpu_optimizer(): |
|
from deepspeed.ops.lion import DeepSpeedCPULion |
|
optimizer = DeepSpeedCPULion(model_parameters, **optimizer_parameters) |
|
else: |
|
from deepspeed.ops.lion import FusedLion |
|
optimizer = FusedLion(model_parameters, **optimizer_parameters) |
|
elif self.optimizer_name() == MUADAM_OPTIMIZER: |
|
try:

    from mup import MuAdam

except ImportError:

    logger.error("Install mup to use MuAdam optimizer")

    raise

optimizer = MuAdam(model_parameters, **optimizer_parameters)
|
elif self.optimizer_name() == MUADAMW_OPTIMIZER: |
|
try:

    from mup import MuAdamW

except ImportError:

    logger.error("Install mup to use MuAdamW optimizer")

    raise

optimizer = MuAdamW(model_parameters, **optimizer_parameters)
|
elif self.optimizer_name() == MUSGD_OPTIMIZER: |
|
try:

    from mup import MuSGD

except ImportError:

    logger.error("Install mup to use MuSGD optimizer")

    raise

optimizer = MuSGD(model_parameters, **optimizer_parameters)
|
else: |
|
torch_optimizer = getattr(torch.optim, self.optimizer_name()) |
|
optimizer = torch_optimizer(model_parameters, **optimizer_parameters) |
|
return optimizer |
|
|
|
def _configure_compression_scheduler(self): |
|
return compression_scheduler(self.module, self._config.compression_config) |
|
|
|
def _configure_random_ltd_scheduler(self, configs): |
|
return RandomLTDScheduler(configs) |
|
|
|
def _configure_quantization(self): |
|
( |
|
quantize_weight_in_forward, |
|
quantize_enabled, |
|
q_groups, |
|
q_mixed_fp16, |
|
q_change_ratio, |
|
q_type, |
|
q_rounding, |
|
q_verbose, |
|
use_quantizer_kernel, |
|
) = self.quantize_training() |
|
if quantize_enabled and not quantize_weight_in_forward: |
|
assert self.fp16_enabled( |
|
), "MoQ (quantize in optimization step) weight quantization is only supported for FP16" |
|
quantizer = None |
|
if quantize_enabled and not quantize_weight_in_forward: |
|
from deepspeed.runtime.quantize import Quantizer |
|
|
|
quantizer = Quantizer( |
|
q_groups, |
|
q_mixed_fp16, |
|
q_change_ratio, |
|
q_type, |
|
q_rounding, |
|
q_verbose, |
|
self.eigenvalue_enabled(), |
|
use_quantizer_kernel, |
|
self.eigenvalue_layer_num() if self.eigenvalue_enabled() else 0, |
|
) |
|
return quantizer |
|
|
|
def _configure_fp16_optimizer(self, optimizer): |
|
initial_dynamic_scale = self.initial_dynamic_scale() |
|
dynamic_loss_args = self.dynamic_loss_scale_args() |
|
clip_grad = self.gradient_clipping() |
|
|
|
if APEX_INSTALLED: |
|
fused_opts = (apex.optimizers.FusedAdam, FusedAdam) |
|
else: |
|
fused_opts = FusedAdam |
|
|
|
if isinstance(optimizer, fused_opts) \ |
|
or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]: |
|
if self.dynamic_loss_scale(): |
|
log_dist(f'Creating fp16 optimizer with dynamic loss scale', ranks=[0]) |
|
timers = self.timers if self.wall_clock_breakdown() else NoopTimer() |
|
optimizer = FP16_Optimizer( |
|
optimizer, |
|
deepspeed=self, |
|
dynamic_loss_scale=True, |
|
initial_dynamic_scale=initial_dynamic_scale, |
|
dynamic_loss_args=dynamic_loss_args, |
|
mpu=self.mpu, |
|
clip_grad=clip_grad, |
|
fused_adam_legacy=self.optimizer_legacy_fusion(), |
|
timers=timers, |
|
has_moe_layers=self.has_moe_layers, |
|
) |
|
else: |
|
log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0]) |
|
timers = self.timers if self.wall_clock_breakdown() else NoopTimer() |
|
optimizer = FP16_Optimizer( |
|
optimizer, |
|
deepspeed=self, |
|
static_loss_scale=self.loss_scale(), |
|
mpu=self.mpu, |
|
clip_grad=clip_grad, |
|
fused_adam_legacy=self.optimizer_legacy_fusion(), |
|
timers=timers, |
|
has_moe_layers=self.has_moe_layers, |
|
) |
|
else: |
|
log_dist(f'Creating fp16 unfused optimizer (dynamic loss scale: {self.dynamic_loss_scale()})', ranks=[0])
|
optimizer = FP16_UnfusedOptimizer( |
|
optimizer, |
|
deepspeed=self, |
|
static_loss_scale=self.loss_scale(), |
|
dynamic_loss_scale=self.dynamic_loss_scale(), |
|
dynamic_loss_args=dynamic_loss_args, |
|
mpu=self.mpu, |
|
clip_grad=clip_grad, |
|
fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER, |
|
) |
|
|
|
return optimizer |
|
|
|
def _configure_bf16_optimizer(self, optimizer): |
|
clip_grad = self.gradient_clipping() |
|
|
|
if optimizer is None: |
|
optimizer = DummyOptim(list(self.module.parameters())) |
|
|
|
log_dist('Creating BF16 optimizer', ranks=[0]) |
|
|
|
timers = self.timers if self.wall_clock_breakdown() else NoopTimer() |
|
optimizer = BF16_Optimizer(optimizer, |
|
self.param_names, |
|
bfloat16_config=self._config.bfloat16_config, |
|
mpu=self.mpu, |
|
clip_grad=clip_grad, |
|
allgather_bucket_size=self.zero_allgather_bucket_size(), |
|
dp_process_group=self.seq_data_parallel_group, |
|
timers=timers, |
|
grad_acc_dtype=self.get_data_types()[1], |
|
graph_harvesting=self.graph_harvesting(), |
|
has_moe_layers=self.has_moe_layers) |
|
|
|
return optimizer |
|
|
|
def _configure_zero_optimizer(self, optimizer): |
|
zero_stage = self.zero_optimization_stage() |
|
|
|
mics_shard_size = self.mics_shard_size() |
|
model_dtype, gradient_accumulation_dtype = self.get_data_types() |
|
|
|
if self.bfloat16_enabled(): |
|
check_grad_overflow = self._config.bfloat16_config.check_grad_overflow |
|
elif self.fp16_enabled(): |
|
check_grad_overflow = True |
|
else: |
|
check_grad_overflow = False |
|
|
|
timers = self.timers if self.wall_clock_breakdown() else NoopTimer() |
|
|
|
if optimizer is None: |
|
optimizer = DummyOptim(list(self.module.parameters())) |
|
|
|
if self.zero_legacy_stage1(): |
|
raise Exception( |
|
"The deprecated version of ZeRO Stage 1 is not supported in deepspeed >= 0.5.9. Please downgrade to a version less than 0.5.9 if you need to use this deprecated version of ZeRO." |
|
) |
|
|
|
if zero_stage <= ZeroStageEnum.gradients: |
|
overlap_comm = self.zero_overlap_comm() |
|
contiguous_gradients = self.zero_contiguous_gradients() |
|
round_robin_gradients = self.zero_round_robin_gradients() |
|
assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage) |
|
|
|
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0]) |
|
|
|
if isinstance(self.module, PipelineModule): |
|
if overlap_comm: |
|
logger.warning("Pipeline parallelism does not support overlapped communication; disabling overlap_comm.")
|
overlap_comm = False |
|
optimizer = DeepSpeedZeroOptimizer( |
|
optimizer, |
|
self.param_names, |
|
timers=timers, |
|
static_loss_scale=self.loss_scale(), |
|
dynamic_loss_scale=self.dynamic_loss_scale(), |
|
dynamic_loss_args=self.dynamic_loss_scale_args(), |
|
clip_grad=self.gradient_clipping(), |
|
contiguous_gradients=contiguous_gradients, |
|
reduce_bucket_size=self.zero_reduce_bucket_size(), |
|
use_multi_rank_bucket_allreduce=self.zero_multi_rank_bucket_allreduce(), |
|
allgather_bucket_size=self.zero_allgather_bucket_size(), |
|
dp_process_group=self.seq_data_parallel_group, |
|
expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None, |
|
expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None, |
|
reduce_scatter=self.zero_reduce_scatter(), |
|
overlap_comm=overlap_comm, |
|
offload_optimizer_config=self.zero_offload_optimizer(), |
|
mpu=self.mpu, |
|
postscale_gradients=self.postscale_gradients(), |
|
gradient_predivide_factor=self.gradient_predivide_factor(), |
|
gradient_accumulation_steps=self.gradient_accumulation_steps(), |
|
ignore_unused_parameters=self.zero_ignore_unused_parameters(), |
|
partition_grads=zero_stage == ZeroStageEnum.gradients, |
|
round_robin_gradients=round_robin_gradients, |
|
has_moe_layers=self.has_moe_layers, |
|
fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(), |
|
gradient_accumulation_dtype=gradient_accumulation_dtype, |
|
communication_data_type=self.communication_data_type, |
|
elastic_checkpoint=self.zero_elastic_checkpoint(), |
|
check_grad_overflow=check_grad_overflow) |
|
|
|
elif zero_stage == ZeroStageEnum.weights: |
|
assert not self.has_moe_layers, "MoE not supported with Stage 3" |
|
if isinstance(optimizer, DummyOptim): |
|
log_dist("Creating ZeRO Offload", ranks=[0]) |
|
zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() |
|
if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None: |
|
self._set_zero_group_parallelism() |
|
zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() |
|
optimizer = DeepSpeedZeRoOffload( |
|
self.module, |
|
timers=timers, |
|
ds_config=self.config, |
|
overlap_comm=self.zero_overlap_comm(), |
|
prefetch_bucket_size=self.zero_prefetch_bucket_size(), |
|
max_reuse_distance=self.zero_max_reuse_distance(), |
|
max_live_parameters=self.zero_max_live_parameters(), |
|
param_persistence_threshold=self.zero_param_persistence_threshold(), |
|
model_persistence_threshold=self.zero_model_persistence_threshold(), |
|
offload_param_config=self.zero_offload_param(), |
|
mpu=self.mpu, |
|
zero_param_parallel_group=zero_param_parallel_group, |
|
zero_quantized_weights=self.zero_quantized_weights(), |
|
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), |
|
zero_module_granularity_threshold=self.zero_module_granularity_threshold(), |
|
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(), |
|
) |
|
else: |
|
log_dist( |
|
f'Creating fp16 ZeRO stage {zero_stage} optimizer,' |
|
f' MiCS is enabled {mics_shard_size>0},' |
|
f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}', |
|
ranks=[0]) |
|
if mics_shard_size > 0: |
|
return self._return_mics_optimizer(optimizer, timers) |
|
|
|
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0]) |
|
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 |
|
optimizer = DeepSpeedZeroOptimizer_Stage3( |
|
self.module, |
|
optimizer, |
|
timers=timers, |
|
ds_config=self.config, |
|
static_loss_scale=self.loss_scale(), |
|
dynamic_loss_scale=self.dynamic_loss_scale(), |
|
dynamic_loss_args=self.dynamic_loss_scale_args(), |
|
clip_grad=self.gradient_clipping(), |
|
contiguous_gradients=self.zero_contiguous_gradients(), |
|
reduce_bucket_size=self.zero_reduce_bucket_size(), |
|
prefetch_bucket_size=self.zero_prefetch_bucket_size(), |
|
max_reuse_distance=self.zero_max_reuse_distance(), |
|
max_live_parameters=self.zero_max_live_parameters(), |
|
param_persistence_threshold=self.zero_param_persistence_threshold(), |
|
model_persistence_threshold=self.zero_model_persistence_threshold(), |
|
dp_process_group=self.seq_data_parallel_group, |
|
all2all_process_group=self.local_all_to_all_group, |
|
reduce_scatter=self.zero_reduce_scatter(), |
|
overlap_comm=self.zero_overlap_comm(), |
|
offload_optimizer_config=self.zero_offload_optimizer(), |
|
offload_param_config=self.zero_offload_param(), |
|
sub_group_size=self.zero_sub_group_size(), |
|
offload_ratio=self.zero_partial_offload(), |
|
mpu=self.mpu, |
|
postscale_gradients=self.postscale_gradients(), |
|
gradient_predivide_factor=self.gradient_predivide_factor(), |
|
gradient_accumulation_steps=self.gradient_accumulation_steps(), |
|
aio_config=self.aio_config(), |
|
gradient_accumulation_dtype=gradient_accumulation_dtype, |
|
communication_data_type=self.communication_data_type, |
|
zero_hpz_partition_size=self.zero_hpz_partition_size(), |
|
zero_quantized_weights=self.zero_quantized_weights(), |
|
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), |
|
zero_module_granularity_threshold=self.zero_module_granularity_threshold(), |
|
zeropp_loco_param=self.zeropp_loco_param(), |
|
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(), |
|
) |
|
|
|
else: |
|
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage)) |
|
|
|
return optimizer |
|
|
|
def _return_mics_optimizer(self, basic_optimizer, timers): |
|
from deepspeed.runtime.zero.mics import MiCS_Optimizer |
|
model_dtype, gradient_accumulation_dtype = self.get_data_types() |
|
optimizer = MiCS_Optimizer(self.module, |
|
basic_optimizer, |
|
timers=timers, |
|
ds_config=self.config, |
|
static_loss_scale=self.loss_scale(), |
|
dynamic_loss_scale=self.dynamic_loss_scale(), |
|
dynamic_loss_args=self.dynamic_loss_scale_args(), |
|
clip_grad=self.gradient_clipping(), |
|
contiguous_gradients=self.zero_contiguous_gradients(), |
|
reduce_bucket_size=self.zero_reduce_bucket_size(), |
|
prefetch_bucket_size=self.zero_prefetch_bucket_size(), |
|
max_reuse_distance=self.zero_max_reuse_distance(), |
|
max_live_parameters=self.zero_max_live_parameters(), |
|
param_persistence_threshold=self.zero_param_persistence_threshold(), |
|
model_persistence_threshold=self.zero_model_persistence_threshold(), |
|
dp_process_group=self.seq_data_parallel_group, |
|
reduce_scatter=self.zero_reduce_scatter(), |
|
overlap_comm=self.zero_overlap_comm(), |
|
offload_optimizer_config=self.zero_offload_optimizer(), |
|
offload_param_config=self.zero_offload_param(), |
|
sub_group_size=self.zero_sub_group_size(), |
|
mpu=self.mpu, |
|
postscale_gradients=self.postscale_gradients(), |
|
gradient_predivide_factor=self.gradient_predivide_factor(), |
|
gradient_accumulation_steps=self.gradient_accumulation_steps(), |
|
aio_config=self.aio_config(), |
|
gradient_accumulation_dtype=gradient_accumulation_dtype, |
|
communication_data_type=self.communication_data_type) |
|
return optimizer |
|
|
|
def _configure_eigenvalue(self): |
|
eigenvalue = Eigenvalue( |
|
verbose=self.eigenvalue_verbose(), |
|
max_iter=self.eigenvalue_max_iter(), |
|
tol=self.eigenvalue_tol(), |
|
stability=self.eigenvalue_stability(), |
|
gas_boundary_resolution=self.eigenvalue_gas_boundary_resolution(), |
|
layer_name=self.eigenvalue_layer_name(), |
|
layer_num=self.eigenvalue_layer_num(), |
|
) |
|
|
|
return eigenvalue |
|
|
|
def _configure_progressive_layer_drop(self): |
|
pld = ProgressiveLayerDrop(theta=self.pld_theta(), gamma=self.pld_gamma()) |
|
|
|
return pld |
|
|
|
def _configure_curriculum_scheduler_legacy(self): |
|
scheduler = CurriculumScheduler(self.curriculum_params_legacy()) |
|
return scheduler |
|
|
|
@staticmethod |
|
def is_map_style_dataset(obj): |
|
return hasattr(obj, "__getitem__") and hasattr(obj, "__len__") |
|
|
|
@staticmethod |
|
def is_iterable_style_dataset(obj): |
|
return isinstance(obj, torch.utils.data.IterableDataset) |
|
|
|
def dataloader_drop_last(self): |
|
return self._config.dataloader_drop_last |
|
|
|
def was_step_applied(self) -> bool:

"""Returns True if the latest ``step()`` resulted in parameter updates.
|
Note that a ``False`` return is not an error condition. Steps are frequently |
|
no-ops, such as between gradient accumulation boundaries or when overflows |
|
occur. |
|
Returns: |
|
bool: Whether the latest ``step()`` modified model parameters. |
|
""" |
|
return self._step_applied |
|
|
|
def deepspeed_io(self, |
|
dataset, |
|
batch_size=None, |
|
route=ROUTE_TRAIN, |
|
pin_memory=True, |
|
data_sampler=None, |
|
collate_fn=None, |
|
num_local_io_workers=None): |
|
if not (self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)): |
|
raise ValueError("Training data must be a torch Dataset") |
|
|
|
if batch_size is None: |
|
batch_size = self.train_micro_batch_size_per_gpu() |
|
|
|
if collate_fn is None: |
|
collate_fn = self.collate_fn |
|
|
|
|
|
deepspeed_io_timer = None |
|
if route == ROUTE_TRAIN: |
|
deepspeed_io_timer = self.tput_timer |
|
|
|
|
|
data_parallel_world_size = self.dp_world_size |
|
data_parallel_rank = self.global_rank |
|
if self.mpu is not None: |
|
data_parallel_world_size = self.mpu.get_data_parallel_world_size() |
|
data_parallel_rank = self.mpu.get_data_parallel_rank() |
|
|
|
if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL): |
|
data_sampler = torch.utils.data.DistributedSampler( |
|
dataset, |
|
num_replicas=data_parallel_world_size, |
|
rank=data_parallel_rank, |
|
shuffle=False, |
|
) |
|
|
|
deepspeed_dataloader_config = {} |
|
if self.curriculum_learning_enabled(): |
|
deepspeed_dataloader_config = { |
|
CURRICULUM_LEARNING: self.curriculum_learning_enabled(), |
|
DATA_EFFICIENCY: self.data_efficiency_config(), |
|
DATA_PARALLEL_GROUP: self.data_parallel_group, |
|
GRADIENT_ACCUMULATION_STEPS: self.gradient_accumulation_steps(), |
|
GLOBAL_RANK: self.global_rank, |
|
DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS] |
|
} |
|
return DeepSpeedDataLoader(dataset=dataset, |
|
batch_size=batch_size, |
|
pin_memory=pin_memory, |
|
collate_fn=collate_fn, |
|
local_rank=self.local_rank, |
|
tput_timer=deepspeed_io_timer, |
|
num_local_io_workers=num_local_io_workers, |
|
data_sampler=data_sampler, |
|
data_parallel_world_size=data_parallel_world_size, |
|
data_parallel_rank=data_parallel_rank, |
|
dataloader_drop_last=self.dataloader_drop_last(), |
|
deepspeed_dataloader_config=deepspeed_dataloader_config) |
|
|
|
def train(self, mode=True):

r"""Set the wrapped module in training (``mode=True``) or evaluation mode via ``self.module.train(mode)``."""
|
|
|
self.warn_unscaled_loss = True |
|
self.module.train(mode) |
|
|
|
def eval(self):

r"""Set the wrapped module in evaluation mode via ``self.module.train(False)``."""
|
|
|
self.warn_unscaled_loss = True |
|
self.module.train(False) |
|
|
|
def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None): |
|
|
|
|
|
scaling_factor = self.gradient_accumulation_steps() if eval_micro_batches is None else eval_micro_batches |
|
if isinstance(prescaled_loss, torch.Tensor): |
|
scaled_loss = prescaled_loss / scaling_factor |
|
elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list): |
|
scaled_loss = [] |
|
for l in prescaled_loss: |
|
if isinstance(l, torch.Tensor): |
|
scaled_loss.append(l / scaling_factor) |
|
else: |
|
scaled_loss.append(l) |
|
else: |
|
scaled_loss = prescaled_loss |
|
if self.warn_unscaled_loss: |
|
logger.warning(f"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}") |
|
self.warn_unscaled_loss = False |
|
|
|
return scaled_loss |
|
|
|
def _create_module_forward_pre_hook(self): |
|
|
|
def _module_forward_pre_hook(module, inputs, kwargs): |
|
return self._forward_prologue(inputs, kwargs) |
|
|
|
return self.module.register_forward_pre_hook(_module_forward_pre_hook, prepend=False, with_kwargs=True) |
|
|
|
def _create_module_forward_post_hook(self): |
|
|
|
def _module_forward_post_hook(module, input, output): |
|
self._forward_epilogue() |
|
|
|
return self.module.register_forward_hook(_module_forward_post_hook) |
|
|
|
def _forward_prologue(self, inputs, kwargs): |
|
return_modified = False |
|
|
|
if not self.autotuning_profile_model_info(): |
|
see_memory_usage("Engine before forward", force=self.memory_breakdown()) |
|
|
|
flops_profiler_active = (self.flops_profiler_enabled() |
|
and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0) |
|
|
|
|
|
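# On the very first global step, advance the compression scheduler and, if MoQ weight quantization is active, quantize the optimizer's 16-bit weight groups.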
if self.global_steps == 0 and hasattr(self, "compression_scheduler"): |
|
self.compression_scheduler.step(step_zero_check=True) |
|
if self.quantizer: |
|
tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage( |
|
) == 2 else self.optimizer.fp16_groups |
|
if self.compression_scheduler.weight_quantization_enabled: |
|
self.quantizer.quantize( |
|
tensor_to_quantize, |
|
(self.optimizer.overflow if self.fp16_enabled() else False), |
|
self.eigenvalue_enabled(), |
|
None, |
|
) |
|
return_modified = True |
|
|
|
if flops_profiler_active: |
|
self.flops_profiler.start_profile(ignore_list=None) |
|
|
|
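# During training, progressive layer drop state is injected into the module's kwargs.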
if kwargs is not None: |
|
if self.module.training: |
|
if self.progressive_layer_drop: |
|
kwargs.update(self.progressive_layer_drop.get_state()) |
|
|
|
if self.__class__.__name__ != "PipelineEngine": |
|
|
|
|
|
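# Legacy curriculum learning: update the difficulty each step and, for seqlen curricula, inject curriculum_seqlen through kwargs.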
if self.module.training and self.curriculum_enabled_legacy(): |
|
self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1) |
|
if self.curriculum_params_legacy()["curriculum_type"] == "seqlen": |
|
kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()}) |
|
return_modified = True |
|
|
|
if self.module.training and self.random_ltd_enabled(): |
|
self.random_ltd_scheduler.update_seq(self.global_steps) |
|
|
|
if self.training_dataloader is None: |
|
self.tput_timer.start() |
|
|
|
self._start_timers(self.engine_timers.forward_timers) |
|
|
|
if self.zero_optimization_partition_weights(): |
|
|
|
|
|
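# Signal that we are inside forward so ZeRO-3 can discover and gather externally accessed parameters.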
for module in self.module.modules(): |
|
module._parameters._in_forward = True |
|
|
|
if self.fp16_auto_cast(): |
|
inputs = self._cast_inputs_half(inputs) |
|
return_modified = True |
|
|
|
if return_modified: |
|
return inputs, kwargs |
|
|
|
def _forward_epilogue(self): |
|
if self.zero_optimization_partition_weights(): |
|
|
|
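# Forward is done; clear the in-forward flag set in _forward_prologue.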
for module in self.module.modules(): |
|
module._parameters._in_forward = False |
|
|
|
self._stop_timers(self.engine_timers.forward_timers) |
|
|
|
flops_profiler_active = (self.flops_profiler_enabled() |
|
and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0) |
|
|
|
if flops_profiler_active: |
|
self.flops_profiler.stop_profile() |
|
|
|
if not self.autotuning_profile_model_info(): |
|
see_memory_usage("Engine after forward", force=self.memory_breakdown()) |
|
|
|
@instrument_w_nvtx |
|
def forward(self, *inputs, **kwargs): |
|
r"""Execute forward propagation |
|
Arguments: |
|
*inputs: Variable length input list |
|
**kwargs: Variable length keyword arguments
|
""" |
|
if self.autotuning_profile_model_info(): |
|
ma = get_ma_status() |
|
|
|
if self.is_deepcompile_enabled() and hasattr(self, "launch_compile_passes"): |
|
|
|
self.launch_compile_passes(self.global_steps) |
|
|
|
loss = self.module(*inputs, **kwargs) |
|
|
|
if self.autotuning_profile_model_info(): |
|
activation_mem = get_ma_status() - ma |
|
self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem |
|
print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path()) |
|
exit() |
|
|
|
return loss |
|
|
|
def _cast_inputs_half(self, inputs): |
|
if isinstance(inputs, (list, tuple)): |
|
new_inputs = [] |
|
for v in inputs: |
|
new_inputs.append(self._cast_inputs_half(v)) |
|
return inputs.__class__(new_inputs) |
|
elif isinstance(inputs, dict): |
|
new_inputs = {} |
|
for k, v in inputs.items(): |
|
new_inputs[k] = self._cast_inputs_half(v) |
|
return new_inputs |
|
elif hasattr(inputs, 'half') and inputs.is_floating_point(): |
|
return inputs.half() |
|
else: |
|
return inputs |
|
|
|
def print_forward_breakdown(self, fwd_time): |
|
gate_time = 0.0 |
|
moe_time = 0.0 |
|
falltoall = 0.0 |
|
salltoall = 0.0 |
|
|
|
for gate in self.gate_modules: |
|
|
|
gate_time += gate.gate_time |
|
|
|
for l in self.moe_layers: |
|
|
|
moe_time += l.time_moe |
|
falltoall += l.time_falltoall |
|
salltoall += l.time_salltoall |
|
|
|
|
|
|
|
|
|
log_dist( |
|
f"time (ms) | fwd: {fwd_time:.2f} (fwd_moe: {moe_time:.2f}, 1st_a2a: {falltoall:.2f}, 2nd_a2a: {salltoall:.2f}, top_k: {gate_time:.2f})", |
|
ranks=[0]) |
|
|
|
@instrument_w_nvtx |
|
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): |
|
|
|
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() |
|
|
|
if self.zero_optimization_partition_gradients(): |
|
self.optimizer.overlapping_partition_gradients_reduce_epilogue() |
|
|
|
|
|
elif self.is_gradient_accumulation_boundary(): |
|
if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr( |
|
self.optimizer, 'reduce_gradients'): |
|
self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism) |
|
else: |
|
grads = None |
|
self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size) |
|
|
|
def _backward_prologue(self, loss, scale_wrt_gas=True): |
|
see_memory_usage("Engine before backward", force=self.memory_breakdown()) |
|
if self.scale_wrt_gas is not None: |
|
scale_wrt_gas = self.scale_wrt_gas |
|
|
|
|
|
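# Reduce gradients only when backward allreduce is enabled, we are outside no_sync(), and DeepCompile is not active.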
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt and not self.is_deepcompile_enabled( |
|
) |
|
if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas: |
|
loss = self._scale_loss_by_gas(loss.float()) |
|
|
|
|
|
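# Track the detached mean loss across micro-batches for monitoring.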
mean_loss = loss.mean().detach() |
|
self.losses = mean_loss if self.losses is None else self.losses + mean_loss |
|
if self.monitor.enabled: |
|
if self.is_gradient_accumulation_boundary(): |
|
if self.global_rank == 0: |
|
self.summary_events = [( |
|
f"Train/Samples/train_loss", |
|
self.losses.item(), |
|
self.global_samples, |
|
)] |
|
self.monitor.write_events(self.summary_events) |
|
|
|
if self.is_deepcompile_enabled(): |
|
deepcompile_backward_prologue(self.is_gradient_accumulation_boundary()) |
|
|
|
return loss |
|
|
|
def _backward_epilogue(self): |
|
self._start_timers(self.engine_timers.backward_reduce_timers) |
|
if self.enable_backward_allreduce and not self.inside_no_sync_ctxt: |
|
|
|
self.allreduce_gradients() |
|
|
|
self._stop_timers(self.engine_timers.backward_reduce_timers) |
|
see_memory_usage("Engine after backward", force=self.memory_breakdown()) |
|
|
|
def _do_optimizer_backward(self, loss, retain_graph): |
|
self._start_timers(self.engine_timers.backward_inner_timers) |
|
if self.zero_optimization(): |
|
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() |
|
self.optimizer.backward(loss, retain_graph=retain_graph) |
|
elif self.amp_enabled(): |
|
|
|
|
|
delay_unscale = not self.is_gradient_accumulation_boundary() |
|
with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: |
|
scaled_loss.backward(retain_graph=retain_graph) |
|
elif self.fp16_enabled(): |
|
if self.eigenvalue_enabled(): |
|
self.optimizer.backward(loss, create_graph=True, retain_graph=True) |
|
else: |
|
self.optimizer.backward(loss, retain_graph=retain_graph) |
|
elif self.bfloat16_enabled(): |
|
self.optimizer.backward(loss, retain_graph=retain_graph) |
|
else: |
|
if self.eigenvalue_enabled(): |
|
loss.backward(create_graph=True, retain_graph=True) |
|
else: |
|
loss.backward(retain_graph=retain_graph) |
|
self._stop_timers(self.engine_timers.backward_inner_timers) |
|
|
|
@contextmanager |
|
def no_sync(self): |
|
r""" |
|
Context manager to disable gradient reduction during backward pass. |
|
This context manager has the following effects on other DeepSpeed features: |
|
1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning. |
|
2. It is illegal to call engine.step() within the context manager. |
|
3. Tracking of gradient accumulation steps is disabled. |
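
Example (a minimal sketch; ``engine``, ``data_loader``, and ``accumulation_steps`` are assumed to come from client code):

.. code-block:: python

    data_iter = iter(data_loader)
    with engine.no_sync():
        for _ in range(accumulation_steps - 1):
            loss = engine(next(data_iter))
            engine.backward(loss)  # gradients accumulate locally, no reduction

    # outside the context, gradient reduction is re-enabled (subject to the usual
    # gradient accumulation boundary logic) and engine.step() is legal again
    loss = engine(next(data_iter))
    engine.backward(loss)
    engine.step()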
|
""" |
|
assert not self.zero_optimization_partition_gradients(), \ |
|
f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}" |
|
|
|
assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported" |
|
|
|
self.inside_no_sync_ctxt = True |
|
try: |
|
yield |
|
finally: |
|
self.inside_no_sync_ctxt = False |
|
|
|
@instrument_w_nvtx |
|
def backward(self, loss, retain_graph=False, scale_wrt_gas=True): |
|
r"""Execute backward pass on the loss |
|
Arguments: |
|
loss: Torch tensor on which to execute backward propagation |
|
retain_graph: bool, default: False

Whether to retain the autograd graph; passed through to the underlying ``backward()`` call.
|
""" |
|
assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ |
|
"must provide optimizer during init in order to use backward" |
|
|
|
self._start_timers(self.engine_timers.backward_timers) |
|
loss = self._backward_prologue(loss, scale_wrt_gas) |
|
self._do_optimizer_backward(loss, retain_graph) |
|
self._backward_epilogue() |
|
self._stop_timers(self.engine_timers.backward_timers) |
|
|
|
return loss |
|
|
|
def is_gradient_accumulation_boundary(self): |
|
""" |
|
Query whether the current micro-batch is at the boundary of |
|
gradient accumulation, and thus will trigger gradient reductions and |
|
an optimizer step. |
|
|
|
Returns: |
|
bool: if the current step is a gradient accumulation boundary. |
|
|
|
""" |
|
if self._is_gradient_accumulation_boundary is None: |
|
return (self.micro_steps + 1) % \ |
|
self.gradient_accumulation_steps() == 0 |
|
else: |
|
return self._is_gradient_accumulation_boundary |
|
|
|
def set_gradient_accumulation_boundary(self, is_boundary): |
|
""" |
|
Manually overrides the DeepSpeed engine's gradient accumulation boundary state. This is an optional

feature and should be used with care. The state should be set to the intended

value before each forward/backward pass. The final forward/backward pass should have the

boundary state set to True. This style allows client code to call engine.step() only once, after all

the gradient accumulation passes are complete. See the example below:
|
.. code-block:: python |
|
engine.set_gradient_accumulation_boundary(False) |
|
for _ in range(gradient_accumulation_steps - 1): |
|
micro_batch = next(data_loader) |
|
loss = engine(micro_batch) |
|
engine.backward(loss) |
|
engine.set_gradient_accumulation_boundary(True) |
|
micro_batch = next(data_loader) |
|
loss = engine(micro_batch) |
|
engine.backward(loss) |
|
engine.step() |
|
Arguments: |
|
is_boundary (bool): are we at a gradient accumulation boundary or not? |
|
""" |
|
self._is_gradient_accumulation_boundary = is_boundary |
|
self.optimizer.is_gradient_accumulation_boundary = is_boundary |
|
|
|
def zero_grad(self): |
|
""" |
|
Zero parameter grads. |
|
""" |
|
for param_name, param in self.module.named_parameters(): |
|
param.grad = None |
|
|
|
def clip_fp32_gradients(self): |
|
clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu) |
|
|
|
def _take_model_step(self, lr_kwargs, block_eigenvalue={}): |
|
if self.gradient_clipping() > 0.0: |
|
if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()): |
|
self.clip_fp32_gradients() |
|
elif self.amp_enabled(): |
|
|
|
|
|
master_params = amp.master_params(self.optimizer) |
|
clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu) |
|
self.optimizer.step() |
|
|
|
if hasattr(self.optimizer, '_global_grad_norm'): |
|
self._global_grad_norm = self.optimizer._global_grad_norm |
|
|
|
|
|
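# If MoQ weight quantization is enabled, re-quantize the 16-bit weight groups after the optimizer step.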
if self.quantizer: |
|
tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage( |
|
) == 2 else self.optimizer.fp16_groups |
|
if self.compression_scheduler.weight_quantization_enabled: |
|
self.quantizer.quantize( |
|
tensor_to_quantize, |
|
(self.optimizer.overflow if self.fp16_enabled() else False), |
|
self.eigenvalue_enabled(), |
|
block_eigenvalue, |
|
) |
|
|
|
|
|
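# Clear gradients after the step: wrapped (bf16/ZeRO/fp16/AMP) optimizers manage their own gradient buffers; otherwise zero the module grads directly.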
if self.bfloat16_enabled(): |
|
|
|
if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"): |
|
self.optimizer.zero_grad() |
|
else: |
|
pass |
|
elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled(): |
|
self.optimizer.zero_grad() |
|
else: |
|
self.zero_grad() |
|
|
|
|
|
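# Loss-scale overflow means the update was skipped; record whether the step was actually applied.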
overflow = False |
|
if hasattr(self.optimizer, "overflow"): |
|
overflow = self.optimizer.overflow |
|
self._step_applied = not overflow |
|
|
|
if overflow: |
|
self.skipped_steps += 1 |
|
else: |
|
self.compression_scheduler.step() |
|
if self.lr_scheduler is not None: |
|
try: |
|
self.lr_scheduler.step(**(lr_kwargs or {})) |
|
except TypeError: |
|
|
|
|
|
|
|
self.lr_scheduler.step(self.train_batch_size()) |
|
|
|
if self.steps_per_print() is not None: |
|
report_progress = self.global_rank == 0
|
if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: |
|
self._report_progress(self.global_steps + 1) |
|
|
|
self.losses = None |
|
self.global_steps += 1 |
|
self.global_samples += self.train_batch_size() |
|
|
|
def step(self, lr_kwargs=None): |
|
r"""Execute the weight update step after forward and backward propagation |
|
on effective_train_batch. |
|
""" |
|
assert not self.inside_no_sync_ctxt, \ |
|
"It is illegal to call Engine.step() inside no_sync context manager" |
|
|
|
see_memory_usage("Engine before step", force=self.memory_breakdown()) |
|
|
|
|
|
|
|
flops_profiler_active = self.flops_profiler_enabled( |
|
) and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0 |
|
|
|
self._start_timers(self.engine_timers.step_timers) |
|
|
|
assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ |
|
"must provide optimizer during init in order to use step" |
|
|
|
report_progress = False |
|
|
|
self._step_applied = False |
|
|
|
|
|
if self.is_gradient_accumulation_boundary(): |
|
self.gas_boundary_ctr += 1 |
|
|
|
if self.checkpoint_engine.is_decoupled(): |
|
self._commit_decoupled_checkpoint() |
|
|
|
if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0) |
|
and self.quantizer.any_precision_switch()): |
|
log_dist(f"computing eigenvalue...", ranks=[0]) |
|
self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device, |
|
self.optimizer.cur_scale) |
|
|
|
if self.progressive_layer_drop: |
|
self.progressive_layer_drop.update_state(self.global_steps) |
|
|
|
if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() |
|
and self.quantizer.any_precision_switch()): |
|
self._take_model_step(lr_kwargs, self.block_eigenvalue) |
|
else: |
|
self._take_model_step(lr_kwargs) |
|
|
|
report_progress = self.global_rank == 0
|
|
|
self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress) |
|
|
|
self._stop_timers(self.engine_timers.step_timers) |
|
|
|
|
|
if self.monitor.enabled: |
|
if self.is_gradient_accumulation_boundary(): |
|
if self.global_rank == 0: |
|
self.summary_events = [(f"Train/Samples/lr", self.get_lr()[0], self.global_samples)] |
|
|
|
if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"): |
|
self.summary_events.append(( |
|
f"Train/Samples/loss_scale", |
|
self.optimizer.cur_scale, |
|
self.global_samples, |
|
)) |
|
|
|
if (self.eigenvalue_enabled() |
|
and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()): |
|
ev_values = list(self.block_eigenvalue.values())
|
for i in range(len(ev_values)): |
|
self.summary_events.append(( |
|
f"Train/Eigenvalues/ModelBlockParam_{i}", |
|
ev_values[i][0],
|
self.global_samples, |
|
)) |
|
self.monitor.write_events(self.summary_events) |
|
|
|
|
|
if flops_profiler_active: |
|
if self.autotuning_enabled(): |
|
self.flops = self.flops_profiler.get_total_flops() * 3 |
|
self.fwd_duration = self.flops_profiler.get_total_duration() |
|
else: |
|
self.flops_profiler.print_model_profile( |
|
profile_step=self.global_steps, |
|
module_depth=self.flops_profiler_module_depth(), |
|
top_modules=self.flops_profiler_top_modules(), |
|
detailed=self.flops_profiler_detailed(), |
|
output_file=self.flops_profiler_output_file(), |
|
) |
|
self.flops_profiler.end_profile() |
|
|
|
if self.autotuning_enabled() and self.global_steps == (self.autotuning_end_profile_step() + 1): |
|
self._autotuning_exit() |
|
|
|
if self.wall_clock_breakdown(): |
|
|
|
self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown()) |
|
|
|
if self.wall_clock_breakdown() or self.flops_profiler_enabled(): |
|
|
|
if self.is_gradient_accumulation_boundary(): |
|
if self.monitor.enabled: |
|
self._write_monitor() |
|
|
|
if self.has_moe_layers: |
|
fwd_time = self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False) |
|
self.print_forward_breakdown(fwd_time=fwd_time) |
|
|
|
self.timers.log(self.engine_timers.global_timers) |
|
|
|
self.micro_steps += 1 |
|
see_memory_usage("Engine after step", force=self.memory_breakdown()) |
|
|
|
def _start_timers(self, timer_names): |
|
for name in timer_names: |
|
self.timers(name).start() |
|
|
|
def _stop_timers(self, timer_names): |
|
record = self.is_gradient_accumulation_boundary() and \ |
|
self.flops_profiler_enabled() and \ |
|
(self.global_steps >= self.flops_profiler_profile_step()) |
|
for name in timer_names: |
|
self.timers(name).stop(record=record) |
|
|
|
def _autotuning_exit(self): |
|
if self.global_rank == 0: |
|
msg = self.timers.get_mean([ |
|
FORWARD_GLOBAL_TIMER, |
|
BACKWARD_GLOBAL_TIMER, |
|
STEP_GLOBAL_TIMER, |
|
], reset=False) |
|
titer = 0.0 |
|
titer += msg[FORWARD_GLOBAL_TIMER] if FORWARD_GLOBAL_TIMER in msg else 0 |
|
titer += msg[BACKWARD_GLOBAL_TIMER] if BACKWARD_GLOBAL_TIMER in msg else 0 |
|
titer += msg[STEP_GLOBAL_TIMER] if STEP_GLOBAL_TIMER in msg else 0 |
|
titer *= self.gradient_accumulation_steps() |
|
msg["latency"] = titer |
|
msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer |
|
msg["throughput"] = self.train_batch_size() * 1_000_000 / \ |
|
msg["latency"] |
|
print_json_dist(msg, [0], path=self.autotuning_metric_path()) |
|
log_dist( |
|
f"Wrote metrics to {self.autotuning_metric_path()}, {os.path.abspath(self.autotuning_metric_path())}", |
|
ranks=[0]) |
|
import atexit |
|
atexit.register(print, "Autotuning: done with running current ds config.") |
|
exit() |
|
|
|
def _write_monitor(self): |
|
if self.global_rank == 0: |
|
self.summary_events = [ |
|
( |
|
f"Train/Samples/elapsed_time_ms_forward", |
|
self.timers(FORWARD_GLOBAL_TIMER).elapsed(reset=False), |
|
self.global_samples, |
|
), |
|
( |
|
f"Train/Samples/elapsed_time_ms_backward", |
|
self.timers(BACKWARD_GLOBAL_TIMER).elapsed(reset=False), |
|
self.global_samples, |
|
), |
|
( |
|
f"Train/Samples/elapsed_time_ms_backward_inner", |
|
self.timers(BACKWARD_INNER_GLOBAL_TIMER).elapsed(reset=False), |
|
self.global_samples, |
|
), |
|
( |
|
f"Train/Samples/elapsed_time_ms_backward_allreduce", |
|
self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).elapsed(reset=False), |
|
self.global_samples, |
|
), |
|
( |
|
f"Train/Samples/elapsed_time_ms_step", |
|
self.timers(STEP_GLOBAL_TIMER).elapsed(reset=False), |
|
self.global_samples, |
|
), |
|
] |
|
self.monitor.write_events(self.summary_events) |
|
|
|
def _get_optimizer_param(self, param_name): |
|
result = [] |
|
if not self.optimizer: |
|
return result |
|
for group in self.optimizer.param_groups: |
|
if param_name in group: |
|
result.append(group[param_name]) |
|
else: |
|
result.append(0.0) |
|
return result |
|
|
|
def get_lr(self): |
|
return self._get_optimizer_param("lr") |
|
|
|
def get_type(self): |
|
return self._get_optimizer_param("type") |
|
|
|
def get_mom(self): |
|
if self.optimizer_name() in ["SGD", "RMSprop"]: |
|
return self._get_optimizer_param("momentum") |
|
else: |
|
return self._get_optimizer_param("betas") |
|
|
|
def get_pld_theta(self): |
|
if self.progressive_layer_drop: |
|
return self.progressive_layer_drop.get_theta() |
|
else: |
|
return None |
|
|
|
def _report_progress(self, step): |
|
lr = self.get_lr() |
|
mom = self.get_mom() |
|
log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0]) |
|
|
|
def allreduce_bucket(self, bucket, dp_group, dp_world_size=None): |
|
tensor = self.flatten(bucket) |
|
|
|
tensor_to_allreduce = tensor |
|
|
|
if self.communication_data_type != tensor.dtype: |
|
tensor_to_allreduce = tensor.to(self.communication_data_type) |
|
|
|
if dp_world_size is None: |
|
dp_world_size = dist.get_world_size(group=dp_group) |
|
if self.postscale_gradients(): |
|
if self.gradient_predivide_factor() != 1.0: |
|
tensor_to_allreduce.mul_(1.0 / self.gradient_predivide_factor()) |
|
|
|
dist.all_reduce(tensor_to_allreduce, group=dp_group) |
|
if self.gradient_average: |
|
if self.gradient_predivide_factor() != dp_world_size: |
|
tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dp_world_size) |
|
else: |
|
tensor_to_allreduce.mul_(1. / dp_world_size) |
|
dist.all_reduce(tensor_to_allreduce, group=dp_group) |
|
|
|
if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: |
|
tensor.copy_(tensor_to_allreduce) |
|
|
|
return tensor |
|
|
|
def allreduce_and_copy(self, small_bucket, dp_group, dp_world_size=None): |
|
allreduced = self.allreduce_bucket(small_bucket, dp_group, dp_world_size) |
|
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): |
|
buf.copy_(synced) |
|
|
|
def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000, dp_world_size=None): |
|
small_bucket = [] |
|
numel = 0 |
|
for tensor in bucket: |
|
small_bucket.append(tensor) |
|
numel = numel + tensor.numel() |
|
if numel > numel_per_bucket: |
|
self.allreduce_and_copy(small_bucket, dp_group, dp_world_size) |
|
small_bucket = [] |
|
numel = 0 |
|
if len(small_bucket) > 0: |
|
self.allreduce_and_copy(small_bucket, dp_group, dp_world_size) |
|
|
|
def _get_gradients_for_reduction(self): |
|
non_expert_grads = [] |
|
expert_grads = {} |
|
if self.has_moe_layers: |
|
for key in self.expert_data_parallel_group.keys(): |
|
expert_grads[key] = [] |
|
|
|
for param_name, param in self.module.named_parameters(): |
|
if not param.requires_grad: |
|
continue |
|
|
|
if param.grad is None: |
|
|
|
|
|
|
|
|
|
|
|
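# Ranks may have params that never received gradients; create zero-filled grads so every rank reduces buffers of the same size.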
param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device) |
|
|
|
grad_data = param.grad.data |
|
if param_name in self.sparse_tensor_module_names or grad_data.is_sparse: |
|
|
|
grad_data = SparseTensor(param.grad) |
|
|
|
if is_moe_param(param): |
|
expert_grads[param.group_name].append(grad_data) |
|
else: |
|
non_expert_grads.append(grad_data) |
|
|
|
return non_expert_grads, expert_grads |
|
|
|
def _reduce_non_expert_gradients(self, grads, elements_per_buffer): |
|
split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(grads) |
|
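# Pick the reduction group: the plain data-parallel group under pipeline parallelism, otherwise the sequence-data-parallel group (world size scaled by the sequence parallel size).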
if self.pipeline_parallelism: |
|
dp_group = self.mpu.get_data_parallel_group() |
|
dp_world_size = dist.get_world_size(dp_group) |
|
else: |
|
dp_group = groups._get_sequence_data_parallel_group() |
|
dp_world_size = dist.get_world_size(dp_group) / float(self.sequence_parallel_size) |
|
for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): |
|
if sparse_bucket_tuple: |
|
bucket_type, sparse_bucket = sparse_bucket_tuple |
|
self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size) |
|
|
|
for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): |
|
if dense_bucket_tuple: |
|
bucket_type, dense_bucket = dense_bucket_tuple |
|
self.allreduce_no_retain(dense_bucket, |
|
dp_group=dp_group, |
|
numel_per_bucket=elements_per_buffer, |
|
dp_world_size=dp_world_size) |
|
|
|
def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): |
|
|
|
|
|
dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) |
|
for ep_name, expert_grads_group in expert_grads.items(): |
|
ep_dp_group = groups._get_expert_data_parallel_group(ep_name) |
|
split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse( |
|
expert_grads_group) |
|
|
|
for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): |
|
if sparse_bucket_tuple: |
|
bucket_type, sparse_bucket = sparse_bucket_tuple |
|
self.sparse_allreduce_no_retain(sparse_bucket, dp_group=ep_dp_group, dp_world_size=dp_world_size) |
|
|
|
for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): |
|
if dense_bucket_tuple: |
|
bucket_type, dense_bucket = dense_bucket_tuple |
|
|
|
self.allreduce_no_retain(dense_bucket, |
|
dp_group=ep_dp_group, |
|
numel_per_bucket=elements_per_buffer, |
|
dp_world_size=dp_world_size) |
|
|
|
def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000): |
|
if grads is None: |
|
if hasattr(self.optimizer, "get_grads_for_reduction"): |
|
|
|
non_expert_grads, expert_grads = self.optimizer.get_grads_for_reduction() |
|
else: |
|
non_expert_grads, expert_grads = self._get_gradients_for_reduction() |
|
else: |
|
assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE" |
|
non_expert_grads = grads |
|
|
|
self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer) |
|
|
|
if self.has_moe_layers: |
|
self._reduce_expert_gradients(expert_grads, elements_per_buffer) |
|
|
|
def sparse_allreduce_no_retain(self, bucket, dp_group, dp_world_size=None): |
|
allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group, dp_world_size) |
|
|
|
for tensor in allreduced_sparses: |
|
if tensor.is_sparse: |
|
tensor.orig_dense_tensor.data = tensor.to_coo_tensor() |
|
else: |
|
tensor.orig_dense_tensor.copy_(tensor.to_dense()) |
|
|
|
def sparse_allreduce_bucket(self, bucket, dp_group, dp_world_size=None): |
|
sparse_list = [] |
|
for sparse in bucket: |
|
sparse_list.append(self.sparse_allreduce(sparse, dp_group, dp_world_size)) |
|
return sparse_list |
|
|
|
def sparse_allreduce(self, sparse, dp_group, dp_world_size=None): |
|
original_data_type = sparse.values.dtype |
|
if self.communication_data_type != sparse.values.dtype: |
|
if self.communication_data_type in (torch.float16, torch.bfloat16): |
|
indices = sparse.indices.to(torch.int32) |
|
else: |
|
indices = sparse.indices |
|
values = sparse.values.to(self.communication_data_type) |
|
else: |
|
indices = sparse.indices |
|
values = sparse.values |
|
|
|
if dp_world_size is None: |
|
dp_world_size = dist.get_world_size(group=dp_group) |
|
if self.postscale_gradients(): |
|
if self.gradient_average: |
|
values.mul_(self.gradient_predivide_factor() / (dp_world_size)) |
|
else: |
|
values.mul_(1. / (dp_world_size)) |
|
|
|
indices_device_list = self.sparse_all_gather(indices, dp_group) |
|
values_device_list = self.sparse_all_gather(values, dp_group) |
|
|
|
sparse.indices = torch.cat(indices_device_list).to(torch.long) |
|
sparse.values = torch.cat(values_device_list).to(original_data_type) |
|
return sparse |
|
|
|
def sparse_all_gather(self, value, dp_group): |
|
my_size = torch.LongTensor([value.size()[0]]).to(self.device) |
|
all_sizes = self.all_gather_scalar(my_size, dp_group) |
|
max_size = torch.cat(all_sizes).max() |
|
fill_size = max_size - my_size |
|
|
|
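# Pad the local tensor to the largest first-dimension size across ranks so all_gather operates on uniformly shaped tensors.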
assert value.dim() in [1, 2] |
|
if value.dim() == 1: |
|
if fill_size > 0: |
|
value = torch.cat([value, value.new_empty(fill_size)]) |
|
tensor_list = [value.new_empty(max_size) for _ in range(dist.get_world_size(group=dp_group))] |
|
else: |
|
if fill_size > 0: |
|
value = torch.cat([value, value.new_empty(fill_size, value.size()[1])]) |
|
tensor_list = [ |
|
value.new_empty(max_size, |
|
value.size()[1]) for _ in range(dist.get_world_size(group=dp_group)) |
|
] |
|
|
|
dist.all_gather(tensor_list, value, group=dp_group) |
|
tensors = [] |
|
for dev_idx, t in enumerate(tensor_list): |
|
size = all_sizes[dev_idx][0] |
|
tensors.append(t.index_select(0, torch.arange(size, dtype=torch.long, device=self.device))) |
|
|
|
return tensors |
|
|
|
def all_gather_scalar(self, value, dp_group): |
|
tensor_list = [value.new_zeros(value.size()) for _ in range(dist.get_world_size(group=dp_group))] |
|
dist.all_gather(tensor_list, value, group=dp_group) |
|
return tensor_list |
|
|
|
def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False): |
|
sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) |
|
|
|
|
|
if exclude_frozen_parameters: |
|
for n, p in self.module.named_parameters(): |
|
if not p.requires_grad and n in sd: |
|
del sd[n] |
|
|
|
if self.random_ltd_enabled(): |
|
sd = remove_random_ltd_state_dict(sd) |
|
return sd |
|
|
|
@staticmethod |
|
def load_moe_state_dict(checkpoint_path, |
|
tag, |
|
state_dict, |
|
old_moe_load, |
|
model=None, |
|
mpu=None, |
|
num_experts=1, |
|
checkpoint_engine=TorchCheckpointEngine()): |
|
if old_moe_load: |
|
expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name()) |
|
|
|
num_local_experts = max(num_experts) // groups._get_expert_parallel_world_size( |
|
groups._get_max_expert_size_name()) |
|
for local_expert_id in range(num_local_experts): |
|
global_expert_id = expp_rank * num_local_experts + local_expert_id |
|
expert_state_dict = checkpoint_engine.load( |
|
DeepSpeedEngine._get_expert_ckpt_name( |
|
checkpoint_path, |
|
-1, |
|
global_expert_id, |
|
tag, |
|
mpu), |
|
map_location=torch.device('cpu')) |
|
|
|
|
|
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' |
|
for key in list(expert_state_dict.keys()): |
|
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', |
|
f'{moe_str_prefix}{local_expert_id}') |
|
expert_state_dict[local_key] = expert_state_dict.pop(key) |
|
state_dict.update(expert_state_dict) |
|
|
|
else: |
|
moe_layer_id = 0 |
|
for n_module, module in model.named_modules(): |
|
if isinstance(module, MoE): |
|
group_name = module.expert_group_name |
|
num_local_experts = module.num_local_experts |
|
expp_rank = groups._get_expert_parallel_rank(group_name) |
|
|
|
for local_expert_id in range(num_local_experts): |
|
global_expert_id = expp_rank * num_local_experts + local_expert_id |
|
expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name( |
|
checkpoint_path, moe_layer_id, global_expert_id, tag, mpu), |
|
map_location=torch.device('cpu')) |
|
|
|
|
|
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' |
|
for key in list(expert_state_dict.keys()): |
|
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', |
|
f'{moe_str_prefix}{local_expert_id}') |
|
expert_state_dict[local_key] = expert_state_dict.pop(key) |
|
state_dict.update(expert_state_dict) |
|
moe_layer_id += 1 |
|
|
|
def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False): |
|
if fetch_z3_params: |
|
params_to_fetch = [ |
|
p for p in self.module.parameters() |
|
if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE |
|
] |
|
else: |
|
params_to_fetch = [] |
|
|
|
with deepspeed.zero.GatheredParameters(params_to_fetch, modifier_rank=0): |
|
module_state_dict = checkpoint['module'] |
|
if custom_load_fn: |
|
custom_load_fn(src=module_state_dict, dst=self.module) |
|
else: |
|
self.module.load_state_dict( |
|
module_state_dict, |
|
strict=strict) |
|
|
|
if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None: |
|
saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS] |
|
for param in self.module.parameters(): |
|
if param.requires_grad: |
|
continue |
|
if param not in self.param_names: |
|
raise ValueError(f"failed to find frozen {param} in named params") |
|
name = self.param_names[param] |
|
if hasattr(param, 'ds_id'): |
|
param.ds_tensor.data.copy_(saved_frozen_params[name].data) |
|
else: |
|
param.data.copy_(saved_frozen_params[name].data) |
|
|
|
def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode): |
|
return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}' |
|
|
|
def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode): |
|
file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode) |
|
zero_ckpt_name = os.path.join( |
|
checkpoints_path, |
|
str(tag), |
|
f"{file_prefix}_mp_rank_{mp_rank:02d}_optim_states.pt", |
|
) |
|
return zero_ckpt_name |
|
|
|
def _get_zero_ckpt_name(self, checkpoints_path, tag): |
|
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() |
|
pp_rank = dist.get_rank(group=self.optimizer.dp_process_group) |
|
bf16_mode = self.bfloat16_enabled() |
|
return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode) |
|
|
|
def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): |
|
if mp_placeholder is not None: |
|
mp_rank_str = mp_placeholder |
|
else: |
|
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() |
|
mp_rank_str = f"{mp_rank:02d}" |
|
|
|
if self.zero_optimization_partition_weights(): |
|
if self.load_universal_checkpoint(): |
|
filename = "zero_pp_rank_0" |
|
else: |
|
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group)) |
|
ckpt_name = os.path.join( |
|
checkpoints_path, |
|
str(tag), |
|
f"{filename}_mp_rank_{mp_rank_str}_model_states.pt", |
|
) |
|
else: |
|
ckpt_name = os.path.join( |
|
checkpoints_path, |
|
str(tag), |
|
"mp_rank_" + mp_rank_str + "_model_states.pt", |
|
) |
|
return ckpt_name |
|
|
|
def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank): |
|
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() |
|
ckpt_name = os.path.join(checkpoints_path, str(tag), |
|
f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt') |
|
return ckpt_name |
|
|
|
@staticmethod |
|
def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None): |
|
mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank() |
|
if layer_id <= -1: |
|
|
|
ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag), |
|
f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt') |
|
else: |
|
|
|
ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag), |
|
f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt') |
|
return ckpt_name |
|
|
|
def _get_all_ckpt_names(self, checkpoints_path, tag): |
|
|
|
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*") |
|
import glob |
|
|
|
ckpt_files = glob.glob(ckpt_file_pattern) |
|
ckpt_files.sort() |
|
return ckpt_files |
|
|
|
def load_checkpoint(self, |
|
load_dir, |
|
tag=None, |
|
load_module_strict=True, |
|
load_optimizer_states=True, |
|
load_lr_scheduler_states=True, |
|
load_module_only=False, |
|
custom_load_fn=None): |
|
""" |
|
Load training checkpoint |
|
|
|
Arguments: |
|
load_dir: Required. Directory to load the checkpoint from |
|
tag: Checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the 'latest' file

load_module_strict: Optional. Boolean to strictly enforce that the keys in the state_dict of the module and the checkpoint match.

load_optimizer_states: Optional. Boolean to load the training optimizer states from the checkpoint, e.g. Adam's momentum and variance.

load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from the checkpoint.
|
load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting. |
|
custom_load_fn: Optional. Custom model load function. |
|
|
|
Returns: |
|
A tuple of ``load_path`` and ``client_state``. |
|
*``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed. |
|
*``client_state``: State dictionary used for loading required training states in the client code. |
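
Example (a minimal sketch; ``engine`` is assumed to have been created by ``deepspeed.initialize``):

.. code-block:: python

    load_path, client_state = engine.load_checkpoint("./checkpoints")
    if load_path is None:
        # no usable checkpoint or tag was found; training starts from scratch
        pass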
|
|
|
Important: under ZeRO-3, one cannot load a checkpoint with ``engine.load_checkpoint()`` right

after ``engine.save_checkpoint()``, because ``engine.module`` is partitioned and

``load_checkpoint()`` expects a pristine model. If you must do so, reinitialize the engine

before calling ``load_checkpoint()``.
|
|
|
""" |
|
|
|
if tag is None: |
|
latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest" |
|
latest_path = os.path.join(load_dir, latest_tag) |
|
if os.path.isfile(latest_path): |
|
with open(latest_path, "r") as fd: |
|
tag = fd.read().strip() |
|
else: |
|
if self.load_universal_checkpoint(): |
|
raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist') |
|
else: |
|
logger.warning(

f"Unable to find latest file at {latest_path}. If trying to load the latest "

"checkpoint, please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
|
) |
|
return None, None |
|
|
|
if self._optimizer_has_ckpt_event_prologue(): |
|
|
|
self.optimizer.checkpoint_event_prologue() |
|
|
|
load_path, client_states = self._load_checkpoint(load_dir, |
|
tag, |
|
load_module_strict=load_module_strict, |
|
load_optimizer_states=load_optimizer_states, |
|
load_lr_scheduler_states=load_lr_scheduler_states, |
|
load_module_only=load_module_only, |
|
custom_load_fn=custom_load_fn) |
|
|
|
load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) |
|
if load_zero_checkpoint: |
|
if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): |
|
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) |
|
else: |
|
success = False |
|
if not success: |
|
self.optimizer._restore_from_bit16_weights() |
|
|
|
if self.zero_nvme_offload_optimizer(): |
|
from shutil import copytree, disk_usage |
|
offload_dir = self.optimizer.optimizer_swapper.swap_folder |
|
offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors") |
|
_, _, free = disk_usage(offload_dir) |
|
logger.info( |
|
f"Copying NVMe offload checkpoint from {offload_ckpt_dir} to {offload_dir}, {free / 1e9:,.2f} GB free on target filesystem..." |
|
) |
|
copytree(offload_ckpt_dir, offload_dir, dirs_exist_ok=True) |
|
_, _, free = disk_usage(offload_dir) |
|
logger.info(f"Copying complete! {free / 1e9:,.2f} GB free on target filesystem") |
|
self.optimizer.reset_swap_buffers() |
|
|
|
if self._optimizer_has_ckpt_event_epilogue(): |
|
self.optimizer.checkpoint_event_epilogue() |
|
|
|
if self.load_universal_checkpoint() and not self.zero_optimization_partition_weights(): |
|
self.optimizer.update_lp_params() |
|
|
|
return load_path, client_states |
|
|
|
def _load_checkpoint(self, |
|
load_dir, |
|
tag, |
|
load_module_strict=True, |
|
load_optimizer_states=True, |
|
load_lr_scheduler_states=True, |
|
load_module_only=False, |
|
custom_load_fn=None): |
|
|
|
from deepspeed.runtime.state_dict_factory import SDLoaderFactory |
|
|
|
ckpt_list = self._get_all_ckpt_names(load_dir, tag) |
|
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine) |
|
|
|
is_pipe_parallel = isinstance(self.module, PipelineModule) |
|
|
|
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() |
|
load_path, checkpoint, _ = sd_loader.load(self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel) |
|
|
|
if checkpoint is None: |
|
return None, None |
|
|
|
fetch_z3_params = False |
|
if self.zero_optimization_partition_weights() and not load_optimizer_states: |
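# For a weights-only load under ZeRO-3, reconstruct a consolidated fp32 state dict from the ZeRO
# checkpoint; load_module_state_dict will then fetch/partition the ZeRO-3 params as it loads.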
|
checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir) |
|
fetch_z3_params = True |
|
|
|
if is_pipe_parallel: |
|
|
|
self._curr_ckpt_path = os.path.join(load_dir, tag) |
|
|
|
if self.has_moe_layers: |
|
|
|
old_moe_load = False |
|
if not isinstance(checkpoint['num_experts'], list): |
|
old_moe_load = True |
|
DeepSpeedEngine.load_moe_state_dict(load_dir, |
|
tag, |
|
state_dict=checkpoint['module'], |
|
old_moe_load=old_moe_load, |
|
model=self.module, |
|
mpu=self.mpu, |
|
num_experts=self.num_experts, |
|
checkpoint_engine=self.checkpoint_engine) |
|
if not self.load_universal_checkpoint(): |
|
self.load_module_state_dict(checkpoint=checkpoint, |
|
strict=load_module_strict, |
|
custom_load_fn=custom_load_fn, |
|
fetch_z3_params=fetch_z3_params) |
|
|
|
self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] |
|
|
|
optim_checkpoint = None |
|
if load_module_only: |
|
deepspeed_states = ['module'] |
|
if self.optimizer is not None and hasattr(self.optimizer, 'refresh_fp32_params'): |
|
self.optimizer.refresh_fp32_params() |
|
else: |
|
has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() |
|
if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state: |
|
if self.has_moe_layers: |
|
largest_group_name = groups._get_max_expert_size_name() |
|
expp_rank = groups._get_expert_parallel_rank(largest_group_name) |
|
optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) |
|
optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu')) |
|
else: |
|
optim_checkpoint = checkpoint |
|
|
|
if self.fp16_enabled() or self.bfloat16_enabled(): |
|
self.optimizer.load_state_dict(optim_checkpoint['optimizer'], |
|
load_optimizer_states=load_optimizer_states) |
|
else: |
|
optim_checkpoint = checkpoint |
|
|
|
self.optimizer.load_state_dict(optim_checkpoint['optimizer']) |
|
|
|
if load_lr_scheduler_states and self.lr_scheduler is not None: |
|
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) |
|
|
|
if self.random_ltd_enabled() and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint: |
|
self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd']) |
|
|
|
if self.training_dataloader is not None and self.curriculum_learning_enabled( |
|
) and 'data_sampler' in checkpoint: |
|
self.training_dataloader.data_sampler.load_state_dict(checkpoint['data_sampler']) |
|
|
|
def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters, loaded_parameters): |
|
result = set() |
|
|
|
for name in original_set: |
|
if name in loaded_parameters and name not in loaded_set: |
|
continue |
|
result.add(name) |
|
|
|
for name in loaded_set: |
|
if name in original_parameters: |
|
result.add(name) |
|
|
|
return result |
|
|
|
if 'sparse_tensor_module_names' in checkpoint: |
|
sparse_tensor_module_names = checkpoint['sparse_tensor_module_names'] |
|
elif 'csr_tensor_module_names' in checkpoint: |
|
sparse_tensor_module_names = checkpoint['csr_tensor_module_names'] |
|
else: |
|
sparse_tensor_module_names = None |
|
if sparse_tensor_module_names is not None: |
|
if load_module_strict: |
|
self.sparse_tensor_module_names = sparse_tensor_module_names |
|
else: |
|
self.sparse_tensor_module_names = get_sparse_tensor_module_names( |
|
self.sparse_tensor_module_names, sparse_tensor_module_names, |
|
dict(self.module.named_parameters()), checkpoint["module"]) |
|
|
|
self.global_steps = checkpoint['global_steps'] |
|
self.global_samples = checkpoint.get('global_samples', self.global_steps * self.train_batch_size()) |
|
self.skipped_steps = checkpoint['skipped_steps'] |
|
self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] |
|
deepspeed_states = [ |
|
'module', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'dp_world_size', |
|
'mp_world_size', 'data_sampler', 'random_ltd' |
|
] |
|
client_state = {} |
|
|
|
if load_lr_scheduler_states: |
|
deepspeed_states.append('lr_scheduler') |
|
if load_optimizer_states: |
|
deepspeed_states.append('optimizer') |
|
|
|
client_state = {key: value for key, value in checkpoint.items() if key not in deepspeed_states}
|
|
|
if optim_checkpoint is not None: |
|
client_state['optimizer'] = optim_checkpoint['optimizer'] |
|
|
|
return load_path, client_state |
|
|
|
def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): |
|
|
|
load_serial = None |
|
|
|
|
|
if self._config.zero_config.pipeline_loading_checkpoint: |
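# Pipelined (serialized) ZeRO checkpoint loading: every local rank except 0 blocks on a token from
# the previous rank, so partitions on a node are read one rank at a time.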
|
assert self.zero_optimization_stage( |
|
) == ZeroStageEnum.weights, "Only ZeRO stage 3 supports pipeline checkpoint loading"
|
load_serial = torch.zeros(1).to(self.device) |
|
if dist.get_local_rank() != 0: |
|
dist.recv(tensor=load_serial, src=dist.get_rank() - 1) |
|
if self.load_universal_checkpoint(): |
|
zero_sd_list = None |
|
checkpoint_folder = f'{os.path.join(load_dir, tag)}' |
|
else: |
|
if load_optimizer_states and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size: |
|
raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ |
|
f"world size of {self.loaded_checkpoint_dp_world_size} but the " \ |
|
f"current world size is {self.seq_dp_world_size}. Automatic adjustment " \ |
|
"of ZeRO's optimizer state partitioning with a new world size is not " \ |
|
"currently supported.") |
|
checkpoint_folder = None |
|
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) |
|
if zero_sd_list is None: |
|
return False |
|
|
|
param_shapes = self._get_zero_param_shapes() |
|
self.optimizer.load_state_dict(state_dict_list=zero_sd_list, |
|
load_optimizer_states=load_optimizer_states, |
|
load_from_fp32_weights=self.zero_load_from_fp32_weights(), |
|
checkpoint_folder=checkpoint_folder, |
|
load_serial=load_serial, |
|
param_shapes=param_shapes) |
|
|
|
if self.load_universal_checkpoint(): |
|
logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}') |
|
else: |
|
logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}") |
|
return True |
|
|
|
def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode): |
|
zero_ckpt_names = [] |
|
for dp_rank in range(dp_world_size): |
|
ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir, |
|
tag=tag, |
|
mp_rank=mp_rank, |
|
dp_rank=dp_rank, |
|
bf16_mode=bf16_mode) |
|
zero_ckpt_names.append(ckpt_name) |
|
|
|
return zero_ckpt_names |
|
|
|
def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode): |
|
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() |
|
zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(load_dir=load_dir, |
|
tag=tag, |
|
mp_rank=mp_rank, |
|
dp_world_size=self.loaded_checkpoint_dp_world_size, |
|
bf16_mode=bf16_mode) |
|
for i, ckpt_name in enumerate(zero_ckpt_names): |
|
if not os.path.exists(ckpt_name): |
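# transparently handle an older file-name pattern for optim_states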
|
|
|
if "optim_states.pt" in ckpt_name: |
|
ckpt_name_try = ckpt_name.replace("_optim_states.pt", "optim_states.pt") |
|
if os.path.exists(ckpt_name_try): |
|
zero_ckpt_names[i] = ckpt_name_try |
|
continue |
|
|
|
return zero_ckpt_names |
|
|
|
def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): |
|
zero_sd_list = [] |
|
for i, ckpt_name in enumerate(zero_ckpt_names): |
|
_state = None |
|
if ckpt_name is None: |
|
_state = {OPTIMIZER_STATE_DICT: None} |
|
|
|
elif self.zero_elastic_checkpoint() or dist.get_rank(group=self.optimizer.dp_process_group) == i: |
|
_state = self.checkpoint_engine.load( |
|
ckpt_name, |
|
map_location='cpu', |
|
) |
|
else: |
|
_state = {OPTIMIZER_STATE_DICT: None} |
|
zero_sd_list.append(_state) |
|
|
|
zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list] |
|
logger.info(f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}") |
|
return zero_optimizer_sd |
|
|
|
def _get_all_zero_checkpoints(self, load_dir, tag): |
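# Try ZeRO checkpoints matching the engine's current bit16 mode first, then the other mode
# (e.g. fp16-trained ZeRO states loaded into a bf16 engine), warning when the modes differ.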
|
for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]: |
|
zero_ckpt_names = self._get_all_zero_checkpoint_names(load_dir, tag, bf16_mode) |
|
if zero_ckpt_names is not None: |
|
|
|
if bf16_mode is not self.bfloat16_enabled(): |
|
checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16 |
|
engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16 |
|
logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine') |
|
return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) |
|
|
|
return None |
|
|
|
def _checkpoint_tag_validation(self, tag): |
|
if self.checkpoint_tag_validation_enabled(): |
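# Hash the tag and all-reduce the min/max of the digest across ranks; if every rank produced
# the same tag, min == max == the local hash.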
|
s_hash = hashlib.sha1(tag.encode()) |
|
bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device) |
|
max_bhash = bhash.clone() |
|
min_bhash = bhash.clone() |
|
dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX) |
|
dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN) |
|
valid = all(min_bhash == bhash) and all(max_bhash == bhash) |
|
msg = (f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across " |
|
"all ranks. Including rank unique information in checkpoint tag could cause issues when " |
|
"restoring with different world sizes.") |
|
if self.checkpoint_tag_validation_fail(): |
|
assert valid, msg |
|
elif not valid: |
|
logger.warning(msg) |
|
|
|
def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False): |
|
"""Save training checkpoint |
|
|
|
Arguments: |
|
save_dir: Required. Directory for saving the checkpoint |
|
tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint; the global step is
    used if not provided. The tag name must be the same across all ranks.
|
client_state: Optional. State dictionary used for saving required training states in the client code. |
|
save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint. |
|
exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state. |
|
Important: all processes must call this method, not just the process with rank 0, because each
process needs to save its master weights and scheduler+optimizer states. This method will hang,
waiting to synchronize with the other processes, if it is called only for the rank 0 process.
|
|
|
""" |
|
if self._optimizer_has_ckpt_event_prologue(): |
|
|
|
self.optimizer.checkpoint_event_prologue() |
|
|
|
rank = self.local_rank if self.use_node_local_storage() else self.global_rank |
|
|
|
|
|
|
|
|
|
|
|
if rank == 0: |
|
self.checkpoint_engine.makedirs(save_dir, exist_ok=True) |
|
dist.barrier() |
|
|
|
if tag is None: |
|
tag = f"global_step{self.global_steps}" |
|
|
|
|
|
tag = str(tag) |
|
commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=save_latest) |
|
|
|
self.checkpoint_engine.create(commit_info) |
|
|
|
|
|
self._checkpoint_tag_validation(tag) |
|
|
|
if self.has_moe_layers: |
|
self.save_non_zero_checkpoint = False |
|
self._create_checkpoint_file(save_dir, tag, False) |
|
self._save_moe_checkpoint(save_dir, |
|
tag, |
|
client_state=client_state, |
|
exclude_frozen_parameters=exclude_frozen_parameters) |
|
|
|
|
|
|
|
|
|
|
|
if not self.has_moe_layers: |
|
self._create_checkpoint_file(save_dir, tag, False) |
|
self._save_checkpoint(save_dir, |
|
tag, |
|
client_state=client_state, |
|
exclude_frozen_parameters=exclude_frozen_parameters) |
|
|
|
if self.save_zero_checkpoint: |
|
self._create_zero_checkpoint_files(save_dir, tag) |
|
self._save_zero_checkpoint(save_dir, tag) |
|
|
|
if self.zero_nvme_offload_optimizer(): |
|
from shutil import copytree, disk_usage |
|
offload_dir = self.optimizer.optimizer_swapper.swap_folder |
|
offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors") |
|
_, _, free = disk_usage(save_dir) |
|
logger.info( |
|
f"Copying NVMe offload files from {offload_dir} to {offload_ckpt_dir}, {free / 1e9:,.2f} GB free on target filesystem..." |
|
) |
|
copytree(offload_dir, |
|
offload_ckpt_dir, |
|
ignore=lambda _, dir_list: list(filter(lambda x: 'gradient' in x, dir_list)), |
|
dirs_exist_ok=False) |
|
_, _, free = disk_usage(save_dir) |
|
logger.info(f"Copying complete! {free / 1e9:,.2f} GB free on target filesystem") |
|
|
|
if self._optimizer_has_ckpt_event_epilogue(): |
|
self.optimizer.checkpoint_event_epilogue() |
|
|
|
|
|
if not self.checkpoint_engine.is_decoupled(): |
|
self.checkpoint_engine.commit(tag) |
|
if save_latest and self.global_rank == 0: |
|
with open(os.path.join(save_dir, 'latest'), 'w') as fd: |
|
fd.write(tag) |
|
|
|
dist.barrier() |
|
|
|
return True |
|
|
|
def _commit_decoupled_checkpoint(self): |
|
assert self.checkpoint_engine.is_decoupled(), \ |
|
f'{self.checkpoint_engine} is not a Decoupled Checkpoint Engine' |
|
|
|
commit_info = self.checkpoint_engine.get_commit_info() |
|
if commit_info is None: |
|
return |
|
|
|
self.checkpoint_engine.commit(commit_info) |
|
|
|
if self.global_rank == 0 and commit_info.save_latest: |
|
with open(os.path.join(commit_info.save_dir, 'latest'), 'w') as fd: |
|
fd.write(commit_info.tag) |
|
|
|
dist.barrier() |
|
|
|
def _get_non_moe_state_dict(self, full_state_dict): |
|
""" |
|
Return the state dict restricted to non-MoE layers: expert parameters are removed, gate weights are kept.
|
""" |
|
for key in list(full_state_dict.keys()): |
|
if 'expert' in key and 'moe.gate.wg.weight' not in key: |
|
full_state_dict.pop(key) |
|
|
|
return full_state_dict |
|
|
|
def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): |
|
save_path = self._get_ckpt_name(save_dir, tag) |
|
|
|
|
|
|
|
|
|
|
|
|
|
moe_layer_id = 0 |
|
for n_module, module in self.module.named_modules(): |
|
if isinstance(module, MoE): |
|
group_name = module.expert_group_name |
|
num_local_experts = module.num_local_experts |
|
expp_rank = groups._get_expert_parallel_rank(group_name) |
|
exp_dp_rank = groups._get_expert_data_parallel_rank(group_name) |
|
|
|
|
|
if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): |
|
moe_layer_id += 1 |
|
continue |
|
|
|
|
|
moe_state_dict = {} |
|
for n, p in module.state_dict().items(): |
|
if 'expert' in n and 'moe.gate.wg.weight' not in n: |
|
moe_state_dict[n_module + '.' + n] = p |
|
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' |
|
|
|
|
|
experts_state_dict = defaultdict(dict) |
|
for key in list(moe_state_dict.keys()): |
|
m = re.match(f".*{moe_str_prefix}([0-9]+).*", key) |
|
|
|
if not m:
    logger.warning(f'No expert found in key {key}, skipping.')
    continue
local_expert_id = m.group(1)
|
|
|
global_expert_id = expp_rank * \ |
|
num_local_experts + int(local_expert_id) |
|
expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}', |
|
f'{moe_str_prefix}{global_expert_id}') |
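# clone + detach so the saved tensor owns its own storage instead of aliasing the live
# parameter's (potentially shared) storage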
|
|
|
truncated = moe_state_dict.pop(key).clone().detach() |
|
experts_state_dict[str(global_expert_id)][expert_key] = truncated |
|
|
|
|
|
for global_expert_id, expert_state_dict in experts_state_dict.items(): |
|
|
|
moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu) |
|
if self.random_ltd_enabled(): |
|
expert_state_dict = remove_random_ltd_state_dict(expert_state_dict) |
|
saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict) |
|
self.checkpoint_engine.save(saveable_state_dict, moe_save_path) |
|
moe_layer_id += 1 |
|
|
|
self._curr_ckpt_path = os.path.join(save_dir, tag) |
|
|
|
largest_group_name = groups._get_max_expert_size_name() |
|
expp_rank = groups._get_expert_parallel_rank(largest_group_name) |
|
exp_dp_rank = groups._get_expert_data_parallel_rank(largest_group_name) |
|
|
|
|
|
|
|
|
|
if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): |
|
return |
|
|
|
|
|
optimizer_state = { |
|
'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None |
|
} |
|
|
|
file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) |
|
saveable_state_dict = clone_tensors_for_torch_save(optimizer_state) |
|
self.checkpoint_engine.save(saveable_state_dict, file_path) |
|
|
|
|
|
if groups._get_data_parallel_rank() == 0: |
|
|
|
|
|
|
|
|
|
model_state_dict = self._get_non_moe_state_dict( |
|
DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters)) |
|
|
|
|
|
state = { |
|
'module': |
|
model_state_dict, |
|
'lr_scheduler': |
|
self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, |
|
'data_sampler': |
|
self.training_dataloader.data_sampler.state_dict() if |
|
(self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, |
|
'random_ltd': |
|
self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, |
|
'sparse_tensor_module_names': |
|
self.sparse_tensor_module_names, |
|
'skipped_steps': |
|
self.skipped_steps, |
|
'global_steps': |
|
self.global_steps, |
|
'global_samples': |
|
self.global_samples, |
|
'dp_world_size': |
|
self.dp_world_size, |
|
'mp_world_size': |
|
self.mp_world_size, |
|
'num_experts': |
|
self.num_experts |
|
} |
|
state.update(client_state) |
|
logger.info(f'Saving model checkpoint: {save_path}') |
|
saveable_state_dict = clone_tensors_for_torch_save(state) |
|
self.checkpoint_engine.save(saveable_state_dict, save_path) |
|
|
|
def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): |
|
name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name) |
|
try: |
|
checkpoint_name = name_function(save_dir, tag) |
|
path = os.path.dirname(checkpoint_name) |
|
self.checkpoint_engine.makedirs(path, exist_ok=True) |
|
except Exception:
|
logger.error(f"Failed saving model checkpoint to {save_dir} with tag {tag}") |
|
return False |
|
|
|
return True |
|
|
|
def _create_zero_checkpoint_files(self, save_dir, tag): |
|
success = True |
|
|
|
for rank in range(dist.get_world_size(self.optimizer.dp_process_group)): |
|
if rank == self.global_rank: |
|
success = self._create_checkpoint_file(save_dir, tag, True) |
|
|
|
return success |
|
|
|
def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): |
|
|
|
save_path = self._get_ckpt_name(save_dir, tag) |
|
|
|
zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() |
|
|
|
save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters |
|
|
|
|
|
|
|
|
|
|
|
self._curr_ckpt_path = os.path.join(save_dir, tag) |
|
module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters) |
|
self._curr_ckpt_path = None |
|
|
|
state = dict(module=module, |
|
buffer_names=self._get_buffer_names(), |
|
optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None, |
|
param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None, |
|
frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func) |
|
if save_frozen_param else None, |
|
shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None, |
|
frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func) |
|
if save_frozen_param else None, |
|
lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, |
|
data_sampler=self.training_dataloader.data_sampler.state_dict() if |
|
(self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, |
|
random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, |
|
sparse_tensor_module_names=self.sparse_tensor_module_names, |
|
skipped_steps=self.skipped_steps, |
|
global_steps=self.global_steps, |
|
global_samples=self.global_samples, |
|
dp_world_size=self.seq_dp_world_size, |
|
mp_world_size=self.mp_world_size, |
|
ds_config=self.config, |
|
ds_version=version) |
|
state.update(client_state) |
|
log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0]) |
|
|
|
if self.save_non_zero_checkpoint: |
|
self.checkpoint_engine.save(state_dict=state, path=save_path) |
|
|
|
def _get_buffer_names(self): |
|
buffer_names = [] |
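# Walk the module tree and record every persistent buffer under its fully qualified name;
# non-persistent buffers are skipped, mirroring nn.Module.state_dict().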
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_layer_named_buffers(module, prefix=""): |
|
for name, buf in module.named_buffers(recurse=False): |
|
if buf is not None and name not in module._non_persistent_buffers_set: |
|
buffer_names.append(prefix + name) |
|
|
|
for name, child in module.named_children(): |
|
if child is not None: |
|
get_layer_named_buffers(child, prefix + name + ".") |
|
|
|
get_layer_named_buffers(self.module, prefix="") |
|
|
|
return buffer_names |
|
|
|
def _get_param_shape_func(self, param): |
|
return param.ds_shape if hasattr(param, 'ds_id') else param.shape |
|
|
|
def _get_param_fragment_func(self, param): |
|
return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu() |
|
|
|
def _get_zero_frozen_param_attributes(self, attr_func): |
|
frozen_param_fragments = OrderedDict() |
|
|
|
for param in self.module.parameters(): |
|
if param.requires_grad: |
|
continue |
|
if param not in self.param_names: |
|
raise ValueError(f"failed to find frozen {param} in named params") |
|
name = self.param_names[param] |
|
frozen_param_fragments[name] = attr_func(param) |
|
|
|
return frozen_param_fragments |
|
|
|
def _get_zero_param_shapes(self): |
|
"""Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the |
|
optimizer. the names are exactly as in state_dict. The order is absolutely important, since |
|
the saved data is just flattened data with no identifiers and requires reconstruction in the |
|
same order it was saved. |
|
We can't rely on self.module.named_parameters() to get the saved tensors, as some params |
|
will be missing and others unsaved and then it'd be impossible to reconstruct state_dict |
|
from the flattened weights. |
|
optimizer.bit16_groups seems to be the easiest to use as it's in all zeroX versions. |
|
""" |
|
param_group_shapes = [] |
|
cnt = 0 |
|
numel = 0 |
|
|
|
|
|
|
|
if hasattr(self.optimizer, "round_robin_bit16_groups"): |
|
bit16_groups = self.optimizer.round_robin_bit16_groups |
|
elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"): |
|
bit16_groups = self.optimizer.bf16_groups |
|
else: |
|
bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage( |
|
) == 2 else self.optimizer.fp16_groups |
|
|
|
for bit16_group in bit16_groups: |
|
param_shapes = OrderedDict() |
|
for param in bit16_group: |
|
cnt += 1 |
|
numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel() |
|
shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape |
|
if param not in self.param_names: |
|
raise ValueError(f"failed to find optimizer param in named params") |
|
name = self.param_names[param] |
|
param_shapes[name] = shape |
|
|
|
|
|
|
|
param_group_shapes.append(param_shapes) |
|
|
|
|
|
return param_group_shapes |
|
|
|
def _get_shared_params(self): |
|
""" |
|
Returns a dict of shared params, which can later be used to reconstruct the original state dict, |
|
e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name |
|
of the variable that isn't stored and the value is the actual param holding data. |
|
""" |
|
shared_index = {} |
|
shared_params_by_full_name = {} |
|
|
|
is_zero3_model = (self.zero_optimization_partition_weights() |
|
and any(hasattr(param, "ds_id") for param in self.module.parameters())) |
|
|
|
def get_layer_state_dict(module, prefix=""): |
|
|
|
for name, param in module.named_parameters(recurse=False): |
|
if param is None or (is_zero3_model and not hasattr(param, "ds_id")): |
|
continue |
|
key = prefix + name |
|
|
|
|
|
|
|
|
|
param_id = param.ds_id if is_zero3_model else param.data_ptr() |
|
|
|
if param_id in shared_index: |
|
|
|
|
|
shared_params_by_full_name[key] = shared_index[param_id] |
|
else: |
|
shared_index[param_id] = key |
|
|
|
for name, child in module.named_children(): |
|
if child is not None: |
|
get_layer_state_dict(child, prefix + name + ".") |
|
|
|
if dist.get_rank() == 0: |
|
get_layer_state_dict(self.module, prefix="") |
|
|
|
return shared_params_by_full_name |
|
|
|
def _copy_recovery_script(self, save_path): |
|
base_dir = os.path.dirname(os.path.dirname(__file__)) |
|
script = "zero_to_fp32.py" |
|
src = os.path.join(base_dir, "utils", script) |
|
dst = os.path.join(save_path, script) |
|
|
|
copyfile(src, dst) |
|
self._change_recovery_script_permissions(dst) |
|
|
|
def _change_recovery_script_permissions(self, dst): |
|
|
|
try: |
|
os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC) |
|
except (FileNotFoundError, PermissionError) as e: |
|
|
|
logger.warning(
    f'Could not change permissions for {dst} due to error: {e}. Continuing without changing permissions.'
)
|
|
|
def _save_zero_checkpoint(self, save_path, tag): |
|
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) |
|
zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version) |
|
self.checkpoint_engine.save(zero_sd, zero_checkpoint_name) |
|
|
|
if self.global_rank == 0: |
|
self._copy_recovery_script(save_path) |
|
ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero'
logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')
|
|
|
|
|
def _replace_module_consolidated_state_dict(self): |
|
""" |
|
Get a full non-partitioned state_dict with fp16 weights on cpu. |
|
Important: this function must be called on all ranks and not just rank 0. |
|
This is similar to nn.Module.state_dict (modelled after _save_to_state_dict) |
|
This method is used for tensor parallel training. |
|
|
|
Returns: |
|
OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None. |
|
""" |
|
|
|
|
|
state_dict = OrderedDict() if dist.get_rank() == 0 else None |
|
|
|
def get_layer_state_dict(module, prefix=""): |
|
with GatherReplacedLayerParams(list(module.parameters(recurse=False)), module, enabled=True): |
|
for name, param in module.named_parameters(recurse=False): |
|
if param is None: |
|
continue |
|
key = prefix + name |
|
if dist.get_rank() == 0:
|
state_dict[key] = param.detach().cpu() |
|
|
|
|
|
for name, child in module.named_children(): |
|
if child is not None: |
|
get_layer_state_dict(child, prefix + name + ".") |
|
|
|
get_layer_state_dict(self.module, prefix="") |
|
|
|
|
|
get_accelerator().synchronize() |
|
return state_dict |
|
|
|
def _consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): |
|
""" |
|
Consolidate the 16-bit state dictionary. |
|
""" |
|
if self.zero_optimization_stage() == ZeroStageEnum.weights: |
|
return self._zero3_consolidated_16bit_state_dict(exclude_frozen_parameters) |
|
elif self.autotp_size() > 1: |
|
return self._replace_module_consolidated_state_dict() |
|
|
|
raise ValueError("consolidated_16bit_state_dict is only applicable to cases where weights are partitioned, " |
|
"including Zero Stage 3 and tensor parallelism.") |
|
|
|
def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): |
|
""" |
|
Get a full non-partitioned state_dict with fp16 weights on cpu. |
|
Important: this function must be called on all ranks and not just rank 0. |
|
This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but: |
|
1. consolidates the weights from different partitions on gpu0 |
|
2. works on one layer at a time to require as little gpu0 memory as possible, by |
|
moving the already consolidated weights to cpu |
|
3. takes care to keep the shared params shared when gradually copying the params to cpu |
|
Returns: |
|
a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks |
|
""" |
|
if not self.zero_optimization_partition_weights(): |
|
raise ValueError("this function requires ZeRO-3 mode") |
|
|
|
state_dict = OrderedDict() if dist.get_rank() == 0 else None |
|
shared_params = {} |
|
|
|
def get_layer_state_dict(module, prefix=""): |
|
|
|
|
|
|
|
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): |
|
if dist.get_rank() == 0: |
|
|
|
for name, param in module.named_parameters(recurse=False): |
|
if param is None or (exclude_frozen_parameters and not param.requires_grad): |
|
continue |
|
key = prefix + name |
|
|
|
|
|
|
|
if param.ds_id in shared_params: |
|
|
|
|
|
state_dict[key] = state_dict[shared_params[param.ds_id]] |
|
else: |
|
state_dict[key] = param.detach().cpu() |
|
shared_params[param.ds_id] = key |
|
|
|
|
|
|
|
for name, buf in module.named_buffers(recurse=False): |
|
if (buf is not None and name not in module._non_persistent_buffers_set): |
|
state_dict[prefix + name] = buf.detach().cpu() |
|
|
|
|
|
for name, child in module.named_children(): |
|
if child is not None: |
|
get_layer_state_dict(child, prefix + name + ".") |
|
|
|
|
|
if self._optimizer_has_ckpt_event_prologue(): |
|
self.optimizer.checkpoint_event_prologue() |
|
|
|
see_memory_usage("before get_layer_state_dict", force=False) |
|
get_layer_state_dict(self.module, prefix="") |
|
see_memory_usage("after get_layer_state_dict", force=False) |
|
|
|
if self._optimizer_has_ckpt_event_epilogue(): |
|
self.optimizer.checkpoint_event_epilogue() |
|
|
|
return state_dict |
|
|
|
def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): |
|
"""has been renamed to save_16bit_model, keeping this around for backwards |
|
compatibility""" |
|
return self.save_16bit_model(save_dir, save_filename) |
|
|
|
def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", exclude_frozen_parameters=False): |
|
""" |
|
Save 16bit model weights |
|
|
|
This method saves the 16bit model weights at the desired destination. |
|
|
|
Arguments: |
|
save_dir: Required. Directory for saving the model |
|
save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin`` |
|
exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state. |
|
|
|
Returns: |
|
``True`` when a model has been saved, ``False`` otherwise. It will not be saved if |
|
stage3_gather_16bit_weights_on_model_save is ``False``. |
|
|
|
Important: all processes must call this method and not just the process with rank 0. It is |
|
because the processes need to work in sync to gather the weights. This method will hang |
|
waiting to synchronize with other processes if it's called just for the process with rank 0. |
|
|
|
""" |
|
|
|
path = os.path.join(save_dir, save_filename) |
|
|
|
if self.zero_optimization_partition_weights(): |
|
if self.zero_gather_16bit_weights_on_model_save(): |
|
|
|
state_dict = self._zero3_consolidated_16bit_state_dict( |
|
exclude_frozen_parameters=exclude_frozen_parameters) |
|
else: |
|
|
|
logger.info( |
|
f"Did not save the model {path} because stage3_gather_16bit_weights_on_model_save is False") |
|
return False |
|
else: |
|
state_dict = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters) |
|
|
|
tag = f"global_step{self.global_steps}" |
|
tag = str(tag) |
|
commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=False) |
|
self.checkpoint_engine.create(commit_info) |
|
|
|
if dist.get_rank() == 0: |
|
self.checkpoint_engine.makedirs(save_dir, exist_ok=True) |
|
logger.info(f"Saving model weights to {path}, tag: {tag}") |
|
self.checkpoint_engine.save(state_dict, path) |
|
|
|
self.checkpoint_engine.commit(tag) |
|
|
|
return True |
|
|
|
def empty_partition_cache(self): |
|
""" |
|
Release GPU memory consumed by offloaded model parameters. |
|
""" |
|
if hasattr(self.optimizer, 'empty_partition_cache'): |
|
self.optimizer.empty_partition_cache() |
|
gc.collect() |
|
get_accelerator().empty_cache() |
|
|
|
def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None: |
|
"""Compile the module using the specified backend and kwargs. |
|
If a compiler_fn is set, it will be used instead of torch.compile(). |
|
""" |
|
|
|
deepspeed.utils.nvtx.enable_nvtx = False |
|
|
|
if not is_compile_supported(): |
|
raise RuntimeError("compile is not supported in your version of PyTorch.") |
|
|
|
if self.is_compiled: |
|
return |
|
|
|
if 'backend' in compile_kwargs: |
|
logger.warning("The `backend` in `compile_kwargs` will be overridden. Use the `backend` argument instead.") |
|
|
|
logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}") |
|
|
|
enable_deepcompile = self.is_deepcompile_enabled() |
|
if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ |
|
and self.zero_optimization_stage() != ZeroStageEnum.weights: |
|
logger.info( |
|
f"Currently DeepCompile supports ZeRO stage 1 or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." |
|
) |
|
enable_deepcompile = False |
|
|
|
if enable_deepcompile: |
|
|
|
if schedule is not None: |
|
|
|
def passes_name_to_fn(passes): |
|
for p in passes: |
|
assert callable(p) or p in opt_passes, f"Unknown pass {p}" |
|
return [p if callable(p) else opt_passes[p] for p in passes] |
|
|
|
schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] |
|
|
|
assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." |
|
|
|
compile_config = self._config.compile_config |
|
if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] |
|
and "offload_param" in self.config["zero_optimization"]) |
|
and self._config.zero_config.offload_param.device == "cpu" |
|
and self._config.zero_config.offload_optimizer.device == "cpu"): |
|
compile_config.offload_parameters = True |
|
if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: |
|
backend = init_z1(self, backend, compile_config, compile_kwargs, schedule) |
|
elif self.zero_optimization_stage() == ZeroStageEnum.weights: |
|
backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) |
|
|
|
|
|
self.module.compile(**{**compile_kwargs, 'backend': backend}) |
|
|
|
self._is_compiled = True |
|
|
|
def get_compile_time(self): |
|
from deepspeed.compile.backend import opt_pass_times |
|
return opt_pass_times |
|
|
|
def register_compile_pass(self, pass_name: str, pass_fn: Callable) -> None: |
|
register_compile_pass(pass_name, pass_fn) |
|
|
|
def is_deepcompile_enabled(self): |
|
return self._config.compile_config.deepcompile |
|
|
|
@property |
|
def is_compiled(self) -> bool: |
|
return self._is_compiled |
|
|
|
def offload_states(self, |
|
include: Container[OffloadStateTypeEnum] = None, |
|
device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, |
|
pin_memory: bool = True, |
|
non_blocking: bool = False) -> None: |
|
"""Offload the engine's states to the specified device. |
|
|
|
Arguments: |
|
include: Optional. The set of states to offload. If not provided, all states are offloaded. |
|
device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported. |
|
pin_memory: Optional. Whether to pin the memory of the offloaded states. |
|
non_blocking: Optional. Whether to offload the states asynchronously. |
|
""" |
|
assert self.zero_optimization_stage( |
|
) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3." |
|
|
|
opt_offload_config = self.zero_offload_optimizer() |
|
assert opt_offload_config is None or opt_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded optimizer states." |
|
param_offload_config = self.zero_offload_param() |
|
assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters." |
|
|
|
assert not isinstance( |
|
self.optimizer, |
|
DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer." |
|
|
|
if device == OffloadDeviceEnum.none: |
|
logger.warning("No device specified for offloading states.") |
|
return |
|
|
|
if device == OffloadDeviceEnum.nvme: |
|
raise ValueError("NVMe offload is not supported for offloading states.") |
|
|
|
self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking) |
|
|
|
def reload_states(self, non_blocking: bool = False) -> None: |
|
"""Reload the engine states to the original device. |
|
|
|
Arguments: |
|
non_blocking: Optional. Whether to reload the states asynchronously.
|
""" |
|
assert self.zero_optimization_stage( |
|
) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3." |
|
|
|
assert not isinstance( |
|
self.optimizer, |
|
DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer." |
|
|
|
self.optimizer.reload_states(non_blocking=non_blocking) |
|
|