|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
from deepspeed import comm as dist |
|
from deepspeed.accelerator import get_accelerator |
|
from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses |
|
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload |
|
|
|
from .passes import zero3_compile, prefetch, selective_gather, offload_parameters |
|
from .backend import make_backend, launch_compile_passes, init_schedule |
|
from .patch_fake_tensor import patch_fake_tensor |
|
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor |
|
|
|
# First element of the second entry in the default schedule built by init_z3:
# the entry whose pass list adds prefetch scheduling and selective gather on
# top of the gather/release pass.
# NOTE(review): presumably a count of warm-up steps -- confirm against init_schedule.
WARMUP = 5
|
|
|
|
|
def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):
    """Prepare a ZeRO stage-3 engine for DeepCompile and build the compile backend.

    The steps below run in a fixed order; each one mutates engine, optimizer
    or process-global state:

    1. Clear the stage-3 optimizer's IPG bucket flat buffer (when the attribute
       exists) and empty the accelerator cache.
    2. Initialise the DeepCompile handle with the data-parallel group, the
       reduce bucket size and the buffering / synchronisation flags read from
       ``compile_config``.
    3. Rebind every submodule's ``_parameters`` to its ``_original_parameters``
       and, when an optimizer is in use, remove its module hooks and its
       gradient-accumulation hooks.
    4. Restore ``torch.nn.functional.linear`` from the saved ``linear_bk``
       attribute, if one is present.
    5. Enable symmetric memory for the data-parallel group when configured.
    6. Register every module parameter with the handle, forcing
       ``ds_persist`` to ``False``.
    7. Install the pass schedule, building a default one when none is given.
    8. When an optimizer is in use, add a pre-backward hook that refreshes
       ``averaged_gradients`` and set up optimizer-state offloading for each
       schedule entry that contains an offload pass.
    9. Attach ``launch_compile_passes`` to the engine, patch fake tensors and
       turn off inductor size asserts.

    Args:
        engine: Engine exposing ``optimizer``, ``module``,
            ``data_parallel_group`` and ``zero_reduce_bucket_size()``. Its
            ``launch_compile_passes`` attribute is overwritten.
        backend: Backend identifier, forwarded to ``is_backend_inductor`` and
            ``make_backend``.
        compile_config: Configuration object; the fields read here are
            ``double_buffer``, ``symmetric_memory``, ``sync_before_reduce``,
            ``sync_after_reduce``, ``sync_before_allgather``,
            ``sync_after_allgather``, ``offload_parameters``,
            ``free_activation`` and ``debug_log``.
        compile_kwargs: Forwarded unchanged to ``make_backend``.
        schedule: Optional iterable of ``(position, passes)`` pairs handed to
            ``init_schedule``. When ``None`` a default schedule is built.
            NOTE(review): ``position`` is presumably a step index -- confirm
            against ``init_schedule``.

    Returns:
        Whatever ``make_backend`` returns for the given backend.
    """
    zero_optimizer = engine.optimizer
    # "Has an optimizer" here means: engine.optimizer is anything other than a
    # DeepSpeedZeRoOffload instance.
    has_optimizer = not isinstance(zero_optimizer, DeepSpeedZeRoOffload)

    # Name-mangled private attribute of the stage-3 optimizer class.
    ipg_buffer_attr = '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer'
    if has_optimizer and hasattr(zero_optimizer, ipg_buffer_attr):
        setattr(zero_optimizer, ipg_buffer_attr, None)
        get_accelerator().empty_cache()

    handle = get_deepcompile_handle()
    handle.init(
        engine.data_parallel_group,
        engine.zero_reduce_bucket_size(),
        compile_config.double_buffer,
        compile_config.symmetric_memory,
        is_backend_inductor(backend),
        compile_config.sync_before_reduce,
        compile_config.sync_after_reduce,
        compile_config.sync_before_allgather,
        compile_config.sync_after_allgather,
    )

    # Point each submodule back at its saved parameter dict.
    for submodule in engine.module.modules():
        submodule._parameters = submodule._original_parameters

    if has_optimizer:
        zero_optimizer.parameter_offload._remove_module_hooks()

        grad_acc_hooks = zero_optimizer._grad_acc_hooks
        for registered_hook in grad_acc_hooks:
            registered_hook.remove()
        grad_acc_hooks.clear()

    # Undo the linear patch only if a backup of the original was stored.
    patched_cls = InsertPostInitMethodToModuleSubClasses
    if hasattr(patched_cls, "linear_bk"):
        torch.nn.functional.linear = patched_cls.linear_bk

    if compile_config.symmetric_memory:
        dist.enable_symm_mem_for_group(engine.data_parallel_group.group_name)

    for param in engine.module.parameters():
        if has_optimizer:
            grad_buffer = zero_optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
        else:
            # Without an optimizer an empty tensor is registered as the gradient buffer.
            grad_buffer = torch.Tensor()

        # Persistence is switched off before the parameter is registered.
        param.ds_persist = False
        handle.register_z3_param(param.ds_id, param.ds_shape, param.ds_tensor, grad_buffer, param.ds_persist)

    if schedule is None:
        # Default schedule: gather/release (plus parameter offload when
        # configured) at position 0, then the prefetch and selective-gather
        # passes from position WARMUP.
        first_entry_passes = [zero3_compile.add_z3_gather_release]
        if compile_config.offload_parameters:
            first_entry_passes.append(offload_parameters.offload_parameter_fwd)
        schedule = [
            (0, first_entry_passes),
            (WARMUP, [
                zero3_compile.add_z3_gather_release,
                prefetch.schedule_prefetch,
                selective_gather.selective_gather,
            ]),
        ]

    init_schedule(schedule)

    if has_optimizer:

        def set_grad_buffer():
            # Rebuild averaged_gradients per fp16 group: the stored gradient
            # partition for trainable params, a zero tensor shaped like
            # ds_tensor for the rest.
            for group_idx, sub_group in enumerate(zero_optimizer.fp16_groups):
                group_grads = []
                for member in sub_group:
                    if member.requires_grad:
                        group_grads.append(
                            zero_optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[member.ds_id])
                    else:
                        group_grads.append(torch.zeros_like(member.ds_tensor))
                zero_optimizer.averaged_gradients[group_idx] = group_grads

        add_pre_backward_hook(set_grad_buffer)

        # Function-scope import: these names are only used on the optimizer path.
        from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states
        offload_passes = (move_opt_states, move_opt_states_sync)
        for _position, scheduled_passes in schedule:
            # Runs once per schedule entry that contains an offload pass.
            if any(candidate in scheduled_passes for candidate in offload_passes):
                init_offload_opt_states(zero_optimizer, handle)

    engine.launch_compile_passes = launch_compile_passes

    patch_fake_tensor()

    # Same value as `compile_config.free_activation and not is_backend_inductor(backend)`:
    # a falsy setting is passed through untouched and the backend check is skipped.
    free_activation = compile_config.free_activation
    if free_activation:
        free_activation = not is_backend_inductor(backend)

    torch._inductor.config.size_asserts = False

    return make_backend(
        backend,
        compile_kwargs=compile_kwargs,
        free_activation=free_activation,
        debug_log=compile_config.debug_log,
    )
|
|