# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Dict, List, Callable
import time
import gc

import torch
from torch.fx import Graph, GraphModule

try:
    import torch.utils._pytree as pytree
    import torch._dynamo
    import torch._inductor  # needed for the "inductor" backend branch below
    from functorch.compile import make_boxed_func
    from torch._functorch.aot_autograd import aot_module_simplified
    from torch._subclasses.fake_tensor import unset_fake_temporarily
    from torch._subclasses.fake_tensor import is_fake
except ImportError:
    pass

from deepspeed.accelerator import get_accelerator

from .fx import add_free_activations
from .graph_param import DSGraphParamManager
from .profilers import ProfilingResult
from .profilers.graph_profile import MemoryProfilingInterpreter
from .patch_compiled_func import patch_compiled_func, unpatch_compiled_func, get_backward_inputs
from .util import get_input_nodes, get_activation_node_names, get_index_by_graph_id, get_deepcompile_handle, log_rank0
from .partitioner import get_wrapped_partitioner
from .inductor import register_custom_ops, patch_create_aot_dispatcher_function

remaining_schedule = None
next_pass_step = -1
next_passes = None
current_passes = None

param_manager: Dict[int, DSGraphParamManager] = {}


class GraphOrder:
    """Tracks the order in which dynamo frames are compiled and whether each needs backward."""

    def __init__(self):
        self.ordered_frames = []
        self.frames = {}

    def add_graph(self, graph_id, frame_id, needs_backward):
        if frame_id not in self.ordered_frames:
            self.ordered_frames.append(frame_id)
        self.frames[frame_id] = (graph_id, needs_backward)

    def get_graph_order(self):
        return [self.frames[frame_id] for frame_id in self.ordered_frames]

    def clear(self):
        # Clear both containers: leaving stale frame ids in ordered_frames would
        # make get_graph_order() raise a KeyError after a reset.
        self.ordered_frames.clear()
        self.frames.clear()


graph_order_with_frame_id = GraphOrder()
frames_needing_bwd = set()

profiling_results: Dict[int, ProfilingResult] = {}
opt_pass_times = []

opt_passes = {}

fwd_real_inputs = []  # queue of real input tuples, consumed FIFO by make_fw_graph


def register_compile_pass(name: str, opt_pass_fn):
    opt_passes[name] = opt_pass_fn


def init_schedule(schedule):

    assert isinstance(schedule, list), f"schedule should be a list, but got {type(schedule)}"

    for step, passes in schedule:
        assert isinstance(step, int), f"Each step in schedule should be an integer, but got {type(step)}"
        assert isinstance(passes, list), f"Passes at a certain step should be a list, but got {type(passes)}"

    global remaining_schedule
    remaining_schedule = schedule


def launch_compile_passes(global_steps: int):
    global next_pass_step, next_passes

    if len(remaining_schedule) > 0 and global_steps == remaining_schedule[0][0]:
        _, next_passes = remaining_schedule.pop(0)
        log_rank0(f"Launching compile passes: global_steps={global_steps} passes={next_passes}", True)

        torch._dynamo.reset()
        get_deepcompile_handle().reset()
        graph_order_with_frame_id.clear()
        profiling_results.clear()
        param_manager.clear()


def set_time_and_tensor_size(graph_id, graph: Graph, mem, bwd, profiling_results):
    node_time = []
    tensor_sizes = []

    for n in graph.nodes:
        node_time.append((n.name, n.meta.get("device_time", 0.0), n.meta.get("wall_time", 0.0)))
        tensor_sizes.append((n.name, n.meta.get("tensor_size", 0)))

    if bwd:
        profiling_results[graph_id].bwd_graph = graph
        profiling_results[graph_id].bwd_time = node_time
        profiling_results[graph_id].bwd_tensor_sizes = tensor_sizes
        profiling_results[graph_id].bwd_mem = mem
    else:
        profiling_results[graph_id].fwd_graph = graph
        profiling_results[graph_id].fwd_time = node_time
        profiling_results[graph_id].fwd_tensor_sizes = tensor_sizes
        profiling_results[graph_id].fwd_mem = mem
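

# The helpers below turn the fake tensors and SymInts that dynamo hands to the
# backend into concrete example values (SymInt hints, dummy strided tensors) so
# that memory profiling and the optimization passes can run on real data.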
def evaluate_symint_from_shape_env(sym_int_v):
    assert isinstance(sym_int_v, torch.SymInt)
    # shape_env = sym_int_v.node.shape_env
    # v = shape_env.evaluate_sym_node(sym_int_v.node)
    return sym_int_v.node.hint


def set_example_values_to_symints(real_inputs):
    real_inputs_ret = []
    for v in real_inputs:
        if isinstance(v, torch.Tensor):
            if is_fake(v):
                shape = []
                for fs in v.shape:
                    if isinstance(fs, torch.SymInt):
                        shape.append(evaluate_symint_from_shape_env(fs))
                    else:
                        shape.append(fs)
                stride = []
                for fs in v.stride():
                    if isinstance(fs, torch.SymInt):
                        stride.append(evaluate_symint_from_shape_env(fs))
                    else:
                        stride.append(fs)
                with unset_fake_temporarily():
                    dummy_v = torch.ones(shape, dtype=v.dtype, layout=v.layout, device=v.device,
                                         requires_grad=v.requires_grad).as_strided(shape, stride)
                real_inputs_ret.append(dummy_v)
            else:
                real_inputs_ret.append(v)
        else:
            if isinstance(v, torch.SymInt):
                real_inputs_ret.append(evaluate_symint_from_shape_env(v))
            else:
                real_inputs_ret.append(v)

    return tuple(real_inputs_ret)


def run_opt_passes(opt_passes: List[Callable],
                   gm: GraphModule,
                   graph_id: int,
                   graph_order: List[int],
                   profiling_results,
                   create_inputs_fn,
                   mem_budget: float,
                   param_manager,
                   bwd: bool,
                   debug_log=False) -> None:

    with unset_fake_temporarily():
        get_accelerator().synchronize()
        gc.collect()
        get_accelerator().empty_cache()

    for i, opt_pass_fn in enumerate(opt_passes):
        log_rank0(f"Running opt pass {i} for graph {graph_id}. bwd={bwd}", enable=debug_log)

        gm_new = opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager,
                             bwd)
        if gm_new is not None:
            gm = gm_new
            gm.graph.lint()
            gm.recompile()

            # Re-profile memory only when a pass actually rewrote the graph.
            mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log)
            mem_prof.run(*create_inputs_fn())
            mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record]

            set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results)

    with unset_fake_temporarily():
        get_accelerator().synchronize()
        gc.collect()
        get_accelerator().empty_cache()


def make_backend(backend, compile_kwargs={}, free_activation=False, debug_log=False):
    """Create a dynamo backend that profiles each graph and applies the scheduled compile passes."""

    register_custom_ops()

    def backend_fn(gm: GraphModule, real_inputs):
        graph_id = id(gm.graph)

        needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)
        frame_id = gm.meta["dynamo_compile_id"].frame_id

        graph_order_with_frame_id.add_graph(graph_id, frame_id, needs_backward)

        if needs_backward:
            if len(frames_needing_bwd) == 0:
                patch_compiled_func()
            frames_needing_bwd.add(frame_id)

        graph_order = graph_order_with_frame_id.get_graph_order()

        z3_partition = any(hasattr(v, "ds_id") for v in real_inputs)
        if z3_partition:
            param_indices = [(i, input_val.ds_id, input_val.ds_shape) for i, input_val in enumerate(real_inputs)
                             if isinstance(input_val, torch.nn.Parameter)]
        else:
            assert all(hasattr(v, "param_id") for v in real_inputs
                       if isinstance(v, torch.nn.Parameter)), "All param inputs should have param_id"
            param_indices = [(i, input_val.param_id, input_val.shape) for i, input_val in enumerate(real_inputs)
                             if isinstance(input_val, torch.nn.Parameter)]

        global fwd_real_inputs
        fwd_real_inputs.append(real_inputs)

        global profiling_results
        if graph_id not in profiling_results:
            profiling_results[graph_id] = ProfilingResult()
        profiling_results[graph_id].param_indices = param_indices
        profiling_results[graph_id].needs_backward = needs_backward

        def make_fw_graph(gm, sample_inputs):
            time_start = time.time()
            graph_index = len(graph_order) - 1

            real_inputs = fwd_real_inputs.pop(0)
            real_inputs = set_example_values_to_symints(real_inputs)
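
            # Map the graph's placeholder nodes to DeepSpeed parameters, then run the
            # scheduled passes. Extra sample_inputs beyond the original real inputs
            # (e.g. RNG state, hence the name real_inputs_with_rng) are passed through.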
            param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices)

            real_inputs_with_rng = real_inputs + tuple(sample_inputs[len(real_inputs):])
            run_opt_passes(
                opt_passes=next_passes,
                gm=gm,
                graph_id=graph_id,
                graph_order=graph_order,
                profiling_results=profiling_results,
                create_inputs_fn=lambda: real_inputs_with_rng,
                mem_budget=.0,  # unused
                param_manager=param_manager,
                bwd=False,
                debug_log=debug_log)

            opt_pass_times.append(("fwd", graph_index, graph_id, time.time() - time_start))

            log_rank0(f"Fwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()}",
                      enable=debug_log)

            return gm.graph

        def make_bw_graph(gm, sample_inputs):
            time_start = time.time()

            graph_index = get_index_by_graph_id(graph_order, graph_id)
            log_rank0(
                f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
                enable=debug_log)

            bwd_inputs_stack = get_backward_inputs()

            if len(bwd_inputs_stack) == 0:
                # Dynamo calls the bw compiler ahead of time when symints are saved for backward.
                # See aot_dispatch_autograd in jit_compile_runtime_wrappers for details.
                # As we currently use the actual bwd input values in the bw compiler, we make dummy data for profiling.
                bwd_real_inputs = set_example_values_to_symints(sample_inputs)
            else:
                bwd_real_inputs = bwd_inputs_stack.pop()

            run_opt_passes(
                opt_passes=next_passes,
                gm=gm,
                graph_id=graph_id,
                graph_order=graph_order,
                profiling_results=profiling_results,
                create_inputs_fn=lambda: tuple(bwd_real_inputs),
                mem_budget=.0,  # unused
                param_manager=param_manager,
                bwd=True,
                debug_log=debug_log)

            # assert graph_id in param_manager, f"Graph {graph_id} not found in param_manager"

            if free_activation:
                param_nodes_bw, _ = param_manager[graph_id].get_bwd_mapping(gm.graph)
                param_names = [n.name for n in param_nodes_bw]
                non_param_input_names = [n.name for n in get_input_nodes(gm.graph) if n.name not in param_names]
                add_free_activations(graph_id, gm.graph,
                                     get_activation_node_names(gm.graph, param_nodes_bw, non_param_input_names))

            frames_needing_bwd.remove(frame_id)
            if len(frames_needing_bwd) == 0:
                unpatch_compiled_func()

            log_rank0(
                f"Bwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
                enable=debug_log)

            opt_pass_times.append(("bwd", graph_index, graph_id, time.time() - time_start))

            return gm.graph

        if backend == "eager":

            def make_compiler_fn(make_graph_fn):

                def compiler_fn(gm, sample_inputs):
                    return None if make_graph_fn(gm, sample_inputs) is None else make_boxed_func(gm.forward)

                return compiler_fn

            aot_mod = aot_module_simplified(gm,
                                            real_inputs,
                                            fw_compiler=make_compiler_fn(make_fw_graph),
                                            bw_compiler=make_compiler_fn(make_bw_graph),
                                            partition_fn=get_wrapped_partitioner(param_indices))
            return torch._dynamo.optimize(**compile_kwargs)(aot_mod)
        elif backend == "inductor":
            patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs,
                                                 param_indices, param_manager)

            from .partitioner import get_wrapped_choose_saved_values_set
            torch._functorch.partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)

            return torch._inductor.compile(gm, real_inputs)

        raise ValueError(f"Unsupported backend {backend}")

    return backend_fn
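

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): my_pass and the training loop below are
# hypothetical. A schedule pairs a global step with the list of pass callables
# to launch at that step; each pass must match the signature that
# run_opt_passes invokes and may return a rewritten GraphModule (or None to
# keep the current one).
#
#     def my_pass(gm, graph_id, graph_order, profiling_results,
#                 create_inputs_fn, mem_budget, param_manager, bwd):
#         return gm  # rewrite gm.graph here
#
#     init_schedule([(0, []), (10, [my_pass])])
#     model = torch.compile(model, backend=make_backend("eager"))
#     for step in range(num_steps):
#         launch_compile_passes(step)
#         loss = model(inputs)
#         loss.backward()
# ---------------------------------------------------------------------------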