# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload

from .passes import zero3_compile, prefetch, selective_gather, offload_parameters
from .backend import make_backend, launch_compile_passes, init_schedule
from .patch_fake_tensor import patch_fake_tensor
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor

WARMUP = 5


def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):

    optimizer = engine.optimizer
    use_opt = not isinstance(optimizer, DeepSpeedZeRoOffload)

    if use_opt and hasattr(optimizer, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer'):
        optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer = None
        get_accelerator().empty_cache()

    dc = get_deepcompile_handle()
    dc.init(engine.data_parallel_group, engine.zero_reduce_bucket_size(), compile_config.double_buffer,
            compile_config.symmetric_memory, is_backend_inductor(backend), compile_config.sync_before_reduce,
            compile_config.sync_after_reduce, compile_config.sync_before_allgather,
            compile_config.sync_after_allgather)

    # Unset hooks
    for m in engine.module.modules():
        m._parameters = m._original_parameters
    if use_opt:
        optimizer.parameter_offload._remove_module_hooks()
        for hook in optimizer._grad_acc_hooks:
            hook.remove()
        optimizer._grad_acc_hooks.clear()

    # Unpatch linear
    if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
        torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk

    if compile_config.symmetric_memory:
        group_name = engine.data_parallel_group.group_name
        dist.enable_symm_mem_for_group(group_name)

    for p in engine.module.parameters():
        grad_buffer = torch.Tensor()
        if use_opt:
            grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id]
        # Disable persistent param
        p.ds_persist = False
        dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)

    if schedule is None:
        schedule = []
        if compile_config.offload_parameters:
            schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd]))
        else:
            schedule.append((0, [zero3_compile.add_z3_gather_release]))
        schedule.append(
            (WARMUP,
             [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))

    init_schedule(schedule)

    if use_opt:

        def set_grad_buffer():
            for i, sub_group in enumerate(optimizer.fp16_groups):
                optimizer.averaged_gradients[i] = [
                    optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
                    if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group
                ]

        add_pre_backward_hook(set_grad_buffer)

    # Offloading optimizer states needs additional setup
    from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states

    for _, passes in schedule:
        if move_opt_states in passes or move_opt_states_sync in passes:
            init_offload_opt_states(optimizer, dc)

    engine.launch_compile_passes = launch_compile_passes

    patch_fake_tensor()

    free_activation = compile_config.free_activation and not is_backend_inductor(backend)

    torch._inductor.config.size_asserts = False

    return make_backend(backend,
                        compile_kwargs=compile_kwargs,
                        free_activation=free_activation,
                        debug_log=compile_config.debug_log)
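
# Usage sketch (assumption, not part of this module): init_z3 is expected to be
# invoked by the DeepSpeed engine when deepcompile is enabled for ZeRO-3, and the
# object it returns is used as a torch.compile backend. A minimal illustration,
# assuming an already-initialized `engine` and `compile_config`:
#
#     backend = init_z3(engine, "inductor", compile_config, compile_kwargs={})
#     compiled_module = torch.compile(engine.module, backend=backend)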