import sys
import torch
from collections import OrderedDict
from deepspeed.utils import z3_leaf_module, set_z3_leaf_module
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.utils import apply_to_tensors_only, is_zero_param
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import _init_external_params
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params
from deepspeed.accelerator import get_accelerator
from deepspeed import utils

FWD_MODULE_STACK = list()


def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs):
    """Recursively walk ``outputs``, call ``forward_function`` on every tensor found and
    register ``backward_function`` as a gradient hook on tensors that require grad."""
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_forward_and_backward_to_tensors_only(module, forward_function,
                                                                         backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        forward_function(outputs)
        if outputs.requires_grad:
            outputs.register_hook(backward_function)
        return outputs
    else:
        return outputs
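
# A minimal usage sketch (not executed): this helper is what
# ``_alternate_post_backward_module_hook`` further below relies on to attach
# per-tensor bookkeeping. The callbacks here are hypothetical placeholders.
#
#   def _count(tensor):
#       print("saw tensor", tensor.shape)
#
#   def _on_grad(grad):
#       print("grad ready", grad.shape)
#
#   outputs = _apply_forward_and_backward_to_tensors_only(module, _count, _on_grad, outputs)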


class ZeROOrderedDict(OrderedDict):

    def __init__(self, parent_module, *args, **kwargs):
        """A replacement for ``collections.OrderedDict`` that detects external ZeRO params.

        Args:
            parent_module (``torch.nn.Module``): the module whose ``_parameters``
                collection is being replaced by this dict.
        """
        super().__init__(*args, **kwargs)
        self._parent_module = parent_module
        self._in_forward = False

    def __reduce__(self):
        r0, _, *r2 = super().__reduce__()
        return (r0, (self._parent_module, )) + tuple(r2)

    def __getitem__(self, key):
        param = super().__getitem__(key)

        # Parameters can be registered as None (e.g. an optional bias).
        if param is None:
            return param

        if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if self._parent_module._parameters._in_forward:
                register_external_parameter(FWD_MODULE_STACK[-1], param)
                param.all_gather()
                print_rank_0(f'Registering external parameter from getter {key} ds_id = {param.ds_id}', force=False)

        return param
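
# Why this override exists, as a sketch (hypothetical module, not executed): with
# ZeRO-3, a module that reads another module's weight inside its own ``forward``
# (e.g. tied embeddings) would otherwise see an unavailable, partitioned parameter.
#
#   class TiedHead(torch.nn.Module):              # illustrative module only
#       def __init__(self, embedding):
#           super().__init__()
#           self.embedding = embedding
#
#       def forward(self, hidden):
#           # Accessing ``self.embedding.weight`` goes through the ZeROOrderedDict
#           # getter above, which gathers the parameter and registers it as an
#           # external parameter of the module currently running forward.
#           return torch.nn.functional.linear(hidden, self.embedding.weight)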


def _inject_parameters(module, cls):
    """Replace the ``_parameters`` dict of every submodule with an instance of ``cls``."""
    for sub_module in module.modules():
        sub_module._original_parameters = sub_module._parameters

        if cls == ZeROOrderedDict:
            new_param = cls(parent_module=sub_module)
        else:
            new_param = cls()

        for key, param in sub_module._parameters.items():
            new_param[key] = param

        sub_module._parameters = new_param
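
# A minimal sketch of how the injection is applied (it mirrors the call made in
# ``DeepSpeedZeRoOffload.__init__`` below; ``model`` is a placeholder name):
#
#   model = MyModel()                          # any torch.nn.Module
#   _inject_parameters(model, ZeROOrderedDict)
#   # every submodule's ``_parameters`` is now a ZeROOrderedDict that can lazily
#   # gather externally accessed ZeRO-3 parameters during forward.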


class DeepSpeedZeRoOffload(object):
    """Manages ZeRO stage 3 parameter partitioning, optional parameter offload, and the
    forward/backward hooks that gather and release parameters around each submodule via
    a ``PartitionedParameterCoordinator``."""

    def __init__(
        self,
        module,
        timers,
        ds_config,
        overlap_comm=True,
        prefetch_bucket_size=50000000,
        max_reuse_distance=1000000000,
        max_live_parameters=1000000000,
        param_persistence_threshold=100000,
        model_persistence_threshold=sys.maxsize,
        dp_process_group=None,
        offload_param_config=None,
        mpu=None,
        zero_param_parallel_group=None,
        zero_quantized_weights=False,
        zero_quantized_nontrainable_weights=False,
        zero_module_granularity_threshold=0,
        log_trace_cache_warnings=False,
    ):

        see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)

        print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False)

        self.module = module
        self.timers = timers
        self.dtype = list(module.parameters())[0].dtype
        self.dp_process_group = dp_process_group
        self.offload_device = None
        self.offload_param_pin_memory = False
        self.zero_param_parallel_group = zero_param_parallel_group
        self.zero_quantized_weights = zero_quantized_weights
        self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
        self.log_trace_cache_warnings = log_trace_cache_warnings

        if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
            self.offload_device = offload_param_config.device
            self.offload_param_pin_memory = offload_param_config.pin_memory

        self._convert_to_zero_parameters(ds_config, module, mpu)

        for m in module.modules():
            _init_external_params(m)

        _inject_parameters(module, ZeROOrderedDict)

        self.param_numel_persistence_threshold = int(param_persistence_threshold)
        self.model_persistence_threshold = int(model_persistence_threshold)
        self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold,
                                                                     self.model_persistence_threshold)

        self._prefetch_bucket_sz = int(prefetch_bucket_size)
        self._max_reuse_distance_in_numel = int(max_reuse_distance)
        self._max_available_parameters_in_numel = int(max_live_parameters)
        if get_accelerator().is_synchronized_device():
            self.__allgather_stream = None
        elif overlap_comm:
            self.__allgather_stream = get_accelerator().Stream()
        else:
            self.__allgather_stream = get_accelerator().default_stream()

        if not hasattr(module, "ds_inflight_param_registry"):
            module.ds_inflight_param_registry = InflightParamRegistry()
        self.__inflight_param_registry = module.ds_inflight_param_registry

        self.fast_sharding_for_leaf_module = False

        if zero_module_granularity_threshold > 0:
            self.min_granularity_value = sys.maxsize
            self.min_granularity_layer = None
            self.granularity_info = set()
            self.z3_leaf_layers = []
            self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
            self.fast_sharding_for_leaf_module = True

        self.param_coordinator = PartitionedParameterCoordinator(
            prefetch_bucket_sz=self._prefetch_bucket_sz,
            max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
            max_available_parameters_in_numel=self._max_available_parameters_in_numel,
            allgather_stream=self.__allgather_stream,
            inflight_param_registry=self.__inflight_param_registry,
            prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
            timers=self.timers,
            zero_quantized_weights=self.zero_quantized_weights,
            zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
            fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module,
            log_trace_cache_warnings=self.log_trace_cache_warnings,
        )

        self.forward_hooks = []
        self.backward_hooks = []

        self.setup_zero_stage3_hooks()
        print_rank_0(
            f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
            force=False)

        see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True)
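
    # For orientation, a sketch of the DeepSpeed JSON config fields that commonly feed
    # these constructor arguments (key names per the ZeRO-3 config; values illustrative):
    #
    #   "zero_optimization": {
    #       "stage": 3,
    #       "stage3_prefetch_bucket_size": 50000000,        # -> prefetch_bucket_size
    #       "stage3_max_reuse_distance": 1000000000,        # -> max_reuse_distance
    #       "stage3_max_live_parameters": 1000000000,       # -> max_live_parameters
    #       "stage3_param_persistence_threshold": 100000,   # -> param_persistence_threshold
    #       "offload_param": {"device": "cpu", "pin_memory": true}
    #   }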

    @instrument_w_nvtx
    def partition_all_parameters(self):
        """Partition any parameters that are still gathered. This typically happens for
        modules whose inputs do not require gradients: their post-backward hook never
        fires, so their parameters would otherwise remain unpartitioned."""
        self.get_param_coordinator().release_and_reset_all(self.module)
        for param in iter_params(self.module, recurse=True):
            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                raise RuntimeError(f"{param.ds_summary()} expected to be released")

    def get_param_coordinator(self):
        return self.param_coordinator

    def empty_partition_cache(self):
        self.partition_all_parameters()
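
    # Usage sketch (hedged; ``model``, ``timers`` and ``ds_config`` are placeholders):
    # after a forward/backward that left parameters gathered, they can all be
    # released through this thin wrapper.
    #
    #   offload = DeepSpeedZeRoOffload(module=model, timers=timers, ds_config=ds_config)
    #   ...
    #   offload.empty_partition_cache()   # -> partition_all_parameters()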

    def _convert_to_zero_parameters(self, ds_config, module, mpu):
        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
        if non_zero_params:
            zero_params = [p for p in module.parameters() if is_zero_param(p)]
            if zero_params:
                zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
            else:
                group = None
                if mpu and hasattr(mpu, "get_data_parallel_group"):
                    group = mpu.get_data_parallel_group()

                Init(module=module,
                     data_parallel_group=group,
                     dtype=self.dtype,
                     config_dict_or_path=ds_config,
                     remote_device=self.offload_device,
                     pin_memory=self.offload_param_pin_memory,
                     mpu=mpu,
                     zero_param_parallel_group=self.zero_param_parallel_group,
                     zero_quantized_weights=self.zero_quantized_weights,
                     zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights)

    def destroy(self):
        self._remove_module_hooks()

    def _remove_module_hooks(self):
        num_forward_hooks = len(self.forward_hooks)
        num_backward_hooks = len(self.backward_hooks)

        for hook in self.forward_hooks:
            hook.remove()

        for hook in self.backward_hooks:
            hook.remove()

        self.fwd_pre_hook.remove()

        print_rank_0(f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
                     force=False)

    def setup_zero_stage3_hooks(self):
        self.hierarchy = 0

        # Reset the parameter coordinator's step bookkeeping at the start of every
        # top-level forward pass.
        @instrument_w_nvtx
        def _start_of_forward_hook(module, *args):
            self.get_param_coordinator().reset_step()

        self.fwd_pre_hook = self.module.register_forward_pre_hook(_start_of_forward_hook)

        # Register pre/post forward and backward hooks on every submodule.
        self._register_deepspeed_module(self.module)

        # Seed the forward-module stack with the root module so FWD_MODULE_STACK[-1]
        # is always defined when external parameters need to be registered.
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(self.module)

    def mark_persistent_parameters(self, param_threshold, model_threshold):
        """Mark parameters whose numel is at most ``param_threshold`` as persistent (kept
        gathered), until the total persistent numel would exceed ``model_threshold``."""
        persistent_params = []
        total_persistent_parameters = 0
        params_count = 0
        for name, param in self.module.named_parameters(recurse=True):
            if param.ds_numel + total_persistent_parameters > model_threshold:
                continue

            if param.ds_numel <= param_threshold:
                params_count += 1
                param.ds_persist = True
                persistent_params.append(param)
                total_persistent_parameters += param.ds_numel

        print_rank_0(
            f"Parameter Offload - Persistent parameters statistics: param_count = {params_count}, numel = {total_persistent_parameters}",
            force=True)

        return persistent_params
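
    # A small worked example of the two thresholds (illustrative numbers only): with
    # param_threshold=100_000 and model_threshold=500_000, a 4_096-element bias is
    # marked persistent (4_096 <= 100_000), a 16_777_216-element weight is not, and
    # any parameter whose numel would push the running persistent total past 500_000
    # is skipped.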

    def _register_deepspeed_module(self, module, count=[0]):
        # Note: the mutable default ``count`` is used deliberately as a counter shared
        # across the recursive calls so every submodule gets a unique ds_id.
        my_count = count[0]
        module.ds_id = my_count

        # ZeRO-3 leaf modules are fetched/released as a single unit, so we tag their
        # parameters and do not recurse into their children.
        if z3_leaf_module(module):
            for param in module.parameters():
                param.ds_z3_leaf_module = module
        else:
            for child in module.children():
                count[0] = count[0] + 1
                self._register_deepspeed_module(child, count=count)

        @torch.compiler.disable
        def _pre_forward_module_hook(module, *args):
            self.pre_sub_module_forward_function(module)

        @instrument_w_nvtx
        def _post_forward_module_hook(module, input, output):

            global FWD_MODULE_STACK
            FWD_MODULE_STACK.pop()
            if output is None:
                output = []
            elif not isinstance(output, (list, tuple)):
                if torch.is_tensor(output):
                    output = [output]
                else:
                    # Collect tensors from dict-like outputs, or from arbitrary output
                    # objects via vars(), keeping only their tensor fields.
                    outputs = []
                    output = output if isinstance(output, dict) else vars(output)
                    for name, val in output.items():
                        if not name.startswith('__') and torch.is_tensor(val):
                            outputs.append(val)
                    output = outputs

            for item in filter(lambda item: is_zero_param(item) or hasattr(item, 'ds_param_alias'), output):
                key = id(item) if hasattr(item, 'ds_id') else id(item.ds_param_alias)
                actual_external_param = item if hasattr(item, 'ds_id') else item.ds_param_alias

                if not any(key in m._external_params for m in FWD_MODULE_STACK):
                    actual_external_param.is_external_param = True
                    module_to_register = FWD_MODULE_STACK[-1]
                    register_external_parameter(module_to_register, actual_external_param)
                    print_rank_0(
                        f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {actual_external_param.ds_id}.',
                        force=False)

                    # If the parameter was already registered as external to the module
                    # that just finished, drop that registration; the outer module on
                    # the stack now owns it.
                    if key in module._external_params:
                        print_rank_0(
                            f' Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {actual_external_param.ds_id}',
                            force=False)
                        unregister_external_parameter(module, actual_external_param)

                    actual_external_param.all_gather()

            self.post_sub_module_forward_function(module)

        def _bwd_hook_unexpected_inputs_msg(value):
            return f"A module has unknown inputs or outputs type ({type(value)}) and the tensors embedded in it cannot be detected. " \
                   "The ZeRO-3 hooks designed to trigger before or after the backward pass of the module rely on knowing the input and " \
                   "output tensors and therefore may not get triggered properly."

        def _pre_backward_module_hook(module, inputs, output):
            return apply_to_tensors_only(module.pre_bwd_fn.apply,
                                         output,
                                         warning_msg_fn=_bwd_hook_unexpected_inputs_msg)

        # This alternative post-backward hook counts gradients per tensor input and
        # fires the module's post-backward function once all of them have arrived.
        def _alternate_post_backward_module_hook(module, inputs):
            module.ds_grads_remaining = 0

            def _run_after_backward_hook(*unused):
                module.ds_grads_remaining = module.ds_grads_remaining - 1
                if module.ds_grads_remaining == 0:
                    self.post_sub_module_backward_function(module)

            def _run_before_forward_function(input):
                if input.requires_grad:
                    module.ds_grads_remaining += 1

            return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function,
                                                               _run_after_backward_hook, inputs)

        @torch.compiler.disable
        def _post_backward_module_hook(module, inputs):
            module.ds_grads_remaining = 0

            return apply_to_tensors_only(module.post_bwd_fn.apply,
                                         inputs,
                                         warning_msg_fn=_bwd_hook_unexpected_inputs_msg)

        # Pre-forward: fetch this submodule's partitioned parameters.
        self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))

        # Post-forward: release parameters and register any dangling external params.
        self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

        if not hasattr(module, "pre_bwd_fn"):

            @instrument_w_nvtx
            def _run_before_backward_function(sub_module):
                # A submodule may run forward several times before its backward (e.g.
                # when called in a loop), so a reference count tracks how many
                # pre-backward fetches are still owed.
                if sub_module.applied_pre_backward_ref_cnt > 0:
                    self.pre_sub_module_backward_function(sub_module)
                    sub_module.applied_pre_backward_ref_cnt -= 1

            class PreBackwardFunctionForModule(torch.autograd.Function):

                @staticmethod
                def forward(ctx, outputs):
                    # Capture ``module`` and the pre-backward callback for the backward pass.
                    ctx.module = module
                    ctx.pre_backward_function = _run_before_backward_function
                    if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
                        ctx.module.applied_pre_backward_ref_cnt = 0
                    ctx.module.applied_pre_backward_ref_cnt += 1
                    outputs = outputs.detach()
                    return outputs

                @staticmethod
                def backward(ctx, *args):
                    ctx.pre_backward_function(ctx.module)
                    return args

            module.pre_bwd_fn = PreBackwardFunctionForModule

        self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))

        if not hasattr(module, "post_bwd_fn"):

            @instrument_w_nvtx
            def _run_after_backward_function(sub_module):
                if sub_module.ds_grads_remaining == 0:
                    self.post_sub_module_backward_function(sub_module)

            class PostBackwardFunctionModule(torch.autograd.Function):

                @staticmethod
                def forward(ctx, output):
                    ctx.module = module
                    if output.requires_grad:
                        # Only inputs that require grad produce a backward call, so
                        # only they are counted toward ds_grads_remaining.
                        module.ds_grads_remaining += 1
                        ctx.post_backward_function = _run_after_backward_function
                    output = output.detach()
                    return output

                @staticmethod
                def backward(ctx, *args):
                    ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
                    if ctx.module.ds_grads_remaining == 0:
                        ctx.post_backward_function(ctx.module)
                    return args

            module.post_bwd_fn = PostBackwardFunctionModule

        self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
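
        # Per-submodule hook lifecycle set up above (summary derived from the code):
        #   1. forward pre-hook  -> pre_sub_module_forward_function: fetch/prefetch params
        #   2. forward hook      -> post_sub_module_forward_function: release params and
        #                           register dangling external params
        #   3. forward hook      -> pre_bwd_fn on outputs: autograd node whose backward
        #                           triggers pre_sub_module_backward_function (re-fetch)
        #   4. forward pre-hook  -> post_bwd_fn on inputs: autograd node whose backward
        #                           triggers post_sub_module_backward_function (release)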

    @torch.no_grad()
    def pre_sub_module_forward_function(self, sub_module):
        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(sub_module)

        param_coordinator = self.get_param_coordinator()
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module, forward=True)

        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)

    @torch.no_grad()
    def post_sub_module_forward_function(self, sub_module):
        see_memory_usage(
            f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
            force=False)

        param_coordinator = self.get_param_coordinator()
        param_coordinator.release_sub_module(sub_module, forward=True)

        see_memory_usage(
            f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
            force=False)

    @torch.no_grad()
    def pre_sub_module_backward_function(self, sub_module):
        param_coordinator = self.get_param_coordinator()
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module, forward=False)

    @torch.no_grad()
    def post_sub_module_backward_function(self, sub_module):
        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
            force=False)

        self.get_param_coordinator().release_sub_module(sub_module, forward=False)

        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
            force=False)

    def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):
        self._get_granularity_recursively(module)
        print_rank_0(f"{'MODULE NAME'.ljust(30)}|{'GRANULARITY VALUE'.rjust(20)}", force=True)
        for granularity in self.granularity_info:
            print_rank_0(granularity, force=True)

        if self.min_granularity_value <= zero_module_granularity_threshold:
            self._set_leaf_by_threshold_preorder(module, zero_module_granularity_threshold)
            utils.logger.info(
                f"z3_leaf_module was set by stage3_module_granularity_threshold:{zero_module_granularity_threshold}")
            for layer in self.z3_leaf_layers:
                print_rank_0(f"{layer.__class__.__name__}:{layer.ds_model_granularity}", force=True)
        else:
            utils.logger.warning(
                f"The smallest module granularity is [{self.min_granularity_layer}:{self.min_granularity_value}]. "\
                f"To make stage3_module_granularity_threshold effective, you need to set stage3_module_granularity_threshold >= {self.min_granularity_value}. "\
                f"Current Value:{zero_module_granularity_threshold}"
            )
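
    # Sketch of the manual alternative (hedged): instead of the threshold scan above,
    # specific modules can be pinned as ZeRO-3 leaf modules directly with the helper
    # imported at the top of this file. ``model`` and the class name are illustrative.
    #
    #   for m in model.modules():
    #       if m.__class__.__name__ == "MixtralSparseMoeBlock":   # hypothetical target
    #           set_z3_leaf_module(m, True)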

    def _get_granularity_recursively(self, module):
        """Recursively compute ``ds_model_granularity`` (parameters per layer) for each
        module and track the smallest granularity seen across the model."""

        # Modules with more than this many parameters are never leaf candidates.
        Z3_MAX_LEAF_SIZE = 1e9

        if not list(module.parameters()):
            # A module without parameters cannot become a leaf module and does not
            # count as a layer.
            module.ds_model_granularity = sys.maxsize
            return 0, 0

        num_layers = 0
        num_params = 0
        num_params += sum(p.ds_numel for p in module.parameters(recurse=False))
        if not any(module.children()):
            # A module with no children counts as a single layer but is not itself a
            # leaf candidate.
            module.ds_model_granularity = sys.maxsize
            return 1, num_params

        for child in module.children():
            layers_in_child, params_in_child = self._get_granularity_recursively(child)
            num_layers += layers_in_child
            num_params += params_in_child

        if module.__class__.__name__ in torch.nn.modules.container.__all__:
            # Bare containers (Sequential, ModuleList, ...) are not leaf candidates.
            module.ds_model_granularity = sys.maxsize
            return num_layers, num_params

        num_layers += 1
        ds_model_granularity = (num_params // num_layers) if num_params <= Z3_MAX_LEAF_SIZE else sys.maxsize
        module.ds_model_granularity = ds_model_granularity

        if self.min_granularity_value > ds_model_granularity:
            self.min_granularity_value = ds_model_granularity
            self.min_granularity_layer = module.__class__.__name__
        self.granularity_info.add(f"{module.__class__.__name__.ljust(30)}|{str(ds_model_granularity).rjust(20)}")

        return num_layers, num_params
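
    # Worked example of the granularity arithmetic above (illustrative numbers): a
    # block containing 3 child layers with 1_000_000, 1_000_000 and 2_000_000
    # parameters has num_layers = 3 + 1 = 4 and num_params = 4_000_000, so
    # ds_model_granularity = 4_000_000 // 4 = 1_000_000; it becomes a ZeRO-3 leaf
    # module if stage3_module_granularity_threshold >= 1_000_000.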

    def _set_leaf_by_threshold_preorder(self, module, granularity_threshold):
        '''Set modules as leaf modules based on the threshold, prioritizing parent nodes.'''

        num_params = sum(p.ds_numel for p in module.parameters())
        if num_params == 0:
            # A module without parameters cannot be a leaf module.
            return
        if module.ds_model_granularity <= granularity_threshold:
            set_z3_leaf_module(module, True)
            self.z3_leaf_layers.append(module)
            return

        for sub_module in module.children():
            self._set_leaf_by_threshold_preorder(sub_module, granularity_threshold)
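

# A hedged configuration sketch: the threshold-based leaf detection above is driven by
# the ZeRO-3 config entry named in the log messages (the value is illustrative only):
#
#   "zero_optimization": {
#       "stage": 3,
#       "stage3_module_granularity_threshold": 500000
#   }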