# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload

from .passes import zero3_compile, prefetch, selective_gather, offload_parameters
from .backend import make_backend, launch_compile_passes, init_schedule
from .patch_fake_tensor import patch_fake_tensor
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor
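# Step index at which the prefetch and selective-gather passes are enabled in the
# default compile-pass schedule built by init_z3 below.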
WARMUP = 5
def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):
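    """Prepare a ZeRO-3 engine for DeepCompile and return a torch.compile backend.

    Releases ZeRO-3's Python-side hooks and buffers, registers the partitioned
    parameters with the native DeepCompile handle, installs a default compile-pass
    schedule when none is supplied, and returns the backend built by make_backend.
    """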
    optimizer = engine.optimizer
    use_opt = not isinstance(optimizer, DeepSpeedZeRoOffload)
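    # Release ZeRO-3's flat IPG bucket buffer and empty the accelerator cache so the
    # memory is available during compilation.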
    if use_opt and hasattr(optimizer, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer'):
        optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer = None
        get_accelerator().empty_cache()
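    # Initialize the native DeepCompile handle with the data-parallel group, the
    # ZeRO reduce bucket size, and the double-buffer / symmetric-memory / sync
    # flags from the compile config.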
    dc = get_deepcompile_handle()
    dc.init(engine.data_parallel_group,
            engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
            is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce,
            compile_config.sync_before_allgather, compile_config.sync_after_allgather)
    # Unset hooks
    for m in engine.module.modules():
        m._parameters = m._original_parameters
    if use_opt:
        optimizer.parameter_offload._remove_module_hooks()
        for hook in optimizer._grad_acc_hooks:
            hook.remove()
        optimizer._grad_acc_hooks.clear()
    # Unpatch linear: restore the original torch.nn.functional.linear saved in linear_bk
    if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
        torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk
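    # Optionally enable symmetric memory for the data-parallel process group.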
    if compile_config.symmetric_memory:
        group_name = engine.data_parallel_group.group_name
        dist.enable_symm_mem_for_group(group_name)
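    # Register every ZeRO-3 partitioned parameter with the DeepCompile handle:
    # its ds_id, full shape, partitioned tensor, and, when a ZeRO-3 optimizer is
    # present, its persistent gradient partition buffer.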
    for p in engine.module.parameters():
        grad_buffer = torch.Tensor()
        if use_opt:
            grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id]

        # Disable persistent param
        p.ds_persist = False
        dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)
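    # Build a default pass schedule when none was given: ZeRO-3 gather/release from
    # step 0 (optionally with parameter offloading), then prefetch and selective
    # gather starting at step WARMUP.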
    if schedule is None:
        schedule = []

        if compile_config.offload_parameters:
            schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd]))
        else:
            schedule.append((0, [zero3_compile.add_z3_gather_release]))

        schedule.append(
            (WARMUP,
             [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))
    init_schedule(schedule)
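    # Before each backward pass, point the optimizer's averaged_gradients at the
    # persistent per-parameter gradient partitions (zeros for params that do not
    # require gradients).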
    if use_opt:

        def set_grad_buffer():
            for i, sub_group in enumerate(optimizer.fp16_groups):
                optimizer.averaged_gradients[i] = [
                    optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
                    if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group
                ]

        add_pre_backward_hook(set_grad_buffer)
        # Offloading optimizer states needs additional setup
        from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states

        for _, passes in schedule:
            if move_opt_states in passes or move_opt_states_sync in passes:
                init_offload_opt_states(optimizer, dc)
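    # Expose the pass launcher on the engine, patch FakeTensor handling for tracing,
    # and disable Inductor size assertions before building the backend. Activation
    # freeing is skipped when the backend is Inductor.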
    engine.launch_compile_passes = launch_compile_passes

    patch_fake_tensor()

    free_activation = compile_config.free_activation and not is_backend_inductor(backend)

    torch._inductor.config.size_asserts = False

    return make_backend(backend,
                        compile_kwargs=compile_kwargs,
                        free_activation=free_activation,
                        debug_log=compile_config.debug_log)