|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List, Callable |
|
import time |
|
import gc |
|
|
|
import torch |
|
from torch.fx import Graph, GraphModule |
|
|
|
try: |
|
import torch.utils._pytree as pytree |
|
import torch._dynamo |
|
from functorch.compile import make_boxed_func |
|
from torch._functorch.aot_autograd import aot_module_simplified |
|
from torch._subclasses.fake_tensor import unset_fake_temporarily |
|
from torch._subclasses.fake_tensor import is_fake |
|
except ImportError: |
|
pass |
|
|
|
from deepspeed.accelerator import get_accelerator |
|
|
|
from .fx import add_free_activations |
|
from .graph_param import DSGraphParamManager |
|
from .profilers import ProfilingResult |
|
from .profilers.graph_profile import MemoryProfilingInterpreter |
|
from .patch_compiled_func import patch_compiled_func, unpatch_compiled_func, get_backward_inputs |
|
from .util import get_input_nodes, get_activation_node_names, get_index_by_graph_id, get_deepcompile_handle, log_rank0 |
|
from .partitioner import get_wrapped_partitioner |
|
from .inductor import register_custom_ops, patch_create_aot_dispatcher_function |
|
|
|
# Scheduling state shared by init_schedule() and launch_compile_passes().
# ``remaining_schedule`` is the list of (step, passes) entries that have not
# been launched yet and ``next_passes`` is the pass list most recently taken
# from it; both are None until they are assigned.  ``current_passes`` also
# starts out as None.
remaining_schedule = next_passes = current_passes = None

# Declared ``global`` by launch_compile_passes(); no assignment other than
# this initial value is visible in this part of the file.
next_pass_step = -1

# graph id -> DSGraphParamManager.  Entries are created while forward graphs
# are processed and the whole mapping is cleared when new passes are launched.
param_manager: Dict[int, DSGraphParamManager] = dict()
|
|
|
|
|
class GraphOrder:
    """Remember one ``(graph_id, needs_backward)`` pair per frame, in first-seen frame order."""

    def __init__(self):
        # Frame ids in the order in which they were first registered.
        self.ordered_frames = []
        # frame_id -> (graph_id, needs_backward) of the latest graph registered for that frame.
        self.frames = {}

    def add_graph(self, graph_id, frame_id, needs_backward):
        """Register (or replace) the graph recorded for ``frame_id``.

        A frame keeps the position it got when it was first seen; registering
        it again only replaces the stored ``(graph_id, needs_backward)`` pair.
        """
        if frame_id not in self.ordered_frames:
            self.ordered_frames.append(frame_id)

        self.frames[frame_id] = (graph_id, needs_backward)

    def get_graph_order(self):
        """Return the stored ``(graph_id, needs_backward)`` pairs in frame order."""
        return [self.frames[frame_id] for frame_id in self.ordered_frames]

    def clear(self):
        """Forget every registered frame.

        Both containers are emptied together. Clearing only ``frames`` would
        leave stale ids in ``ordered_frames`` and make the next
        ``get_graph_order()`` call fail with ``KeyError``.
        """
        self.ordered_frames.clear()
        self.frames.clear()
|
|
|
|
|
# Per-frame record of the graphs handed to the backend, kept in the order in
# which their frames were first seen.
graph_order_with_frame_id = GraphOrder()

# Frame ids that need a backward graph and whose backward graph has not been
# processed yet.
frames_needing_bwd = set()

# graph id -> ProfilingResult collected for that graph.
profiling_results: Dict[int, ProfilingResult] = dict()

# ("fwd" | "bwd", graph_index, graph_id, elapsed_seconds) tuples, one per
# processed forward/backward graph.
opt_pass_times = list()

# Pass name -> pass function, filled by register_compile_pass().
opt_passes = dict()

# Inputs of graphs handed to the backend, consumed first-in-first-out when the
# matching forward graph is processed.
fwd_real_inputs = list()
|
|
|
|
|
def register_compile_pass(name: str, opt_pass_fn):
    """Store ``opt_pass_fn`` in the module-level ``opt_passes`` registry under ``name``.

    An existing entry with the same name is replaced.
    """
    opt_passes.update({name: opt_pass_fn})
|
|
|
|
|
def init_schedule(schedule):
    """Validate ``schedule`` and install it as the module-level ``remaining_schedule``.

    ``schedule`` must be a list whose entries unpack into ``(step, passes)``
    with an integer ``step`` and a list ``passes``.  The list object itself is
    stored (no copy is made), and nothing is stored if validation fails.
    """
    global remaining_schedule

    assert isinstance(schedule, list), f"schedule should be a list, but got {type(schedule)}"

    for entry in schedule:
        step, passes = entry
        assert isinstance(step, int), f"Each step in schedule should be an integer, but got {type(step)}"
        assert isinstance(passes, list), f"Passes at a certain step should be a list, but got {type(passes)}"

    remaining_schedule = schedule
|
|
|
|
|
def launch_compile_passes(global_steps: int):
    """Activate the next scheduled pass list when ``global_steps`` reaches its step.

    Nothing happens unless the first entry of ``remaining_schedule`` is
    scheduled for exactly ``global_steps``.  When it is, the entry is removed
    from the schedule, its passes become ``next_passes``, and the compilation
    state (dynamo, the DeepCompile handle, graph order, profiling results and
    parameter managers) is reset.
    """
    global next_pass_step, next_passes

    schedule_is_empty = len(remaining_schedule) == 0
    if schedule_is_empty or global_steps != remaining_schedule[0][0]:
        return

    _step, next_passes = remaining_schedule.pop(0)
    log_rank0(f"Launching compile passes: global_steps={global_steps} passes={next_passes}", True)

    # Drop everything that was compiled or recorded for the previous passes.
    torch._dynamo.reset()
    get_deepcompile_handle().reset()
    graph_order_with_frame_id.clear()
    profiling_results.clear()
    param_manager.clear()
|
|
|
|
|
def set_time_and_tensor_size(graph_id, graph: Graph, mem, bwd, profiling_results):
    """Copy per-node profiling metadata of ``graph`` into ``profiling_results[graph_id]``.

    For every node, ``device_time``, ``wall_time`` and ``tensor_size`` are read
    from ``node.meta`` (missing entries count as ``0.0`` / ``0.0`` / ``0``).
    The graph, the timing list, the tensor-size list and ``mem`` are stored on
    the ``bwd_*`` attributes of the result when ``bwd`` is true and on the
    ``fwd_*`` attributes otherwise.
    """
    node_time = []
    tensor_sizes = []
    for node in graph.nodes:
        meta = node.meta
        node_time.append((node.name, meta.get("device_time", 0.0), meta.get("wall_time", 0.0)))
        tensor_sizes.append((node.name, meta.get("tensor_size", 0)))

    result = profiling_results[graph_id]
    direction = "bwd" if bwd else "fwd"
    for suffix, value in (("graph", graph), ("time", node_time), ("tensor_sizes", tensor_sizes), ("mem", mem)):
        setattr(result, f"{direction}_{suffix}", value)
|
|
|
|
|
def evaluate_symint_from_shape_env(sym_int_v):
    """Return the ``hint`` stored on the node of a ``torch.SymInt``.

    The hint is presumably the concrete example value recorded for the symbol
    during tracing (to be confirmed against the torch documentation).
    Anything that is not a ``torch.SymInt`` trips the assertion.
    """
    assert isinstance(sym_int_v, torch.SymInt)

    sym_node = sym_int_v.node
    return sym_node.hint
|
|
|
|
|
def set_example_values_to_symints(real_inputs): |
|
real_inputs_ret = [] |
|
for v in real_inputs: |
|
if isinstance(v, torch.Tensor): |
|
if is_fake(v): |
|
shape = [] |
|
for fs in v.shape: |
|
if isinstance(fs, torch.SymInt): |
|
shape.append(evaluate_symint_from_shape_env(fs)) |
|
else: |
|
shape.append(fs) |
|
stride = [] |
|
for fs in v.stride(): |
|
if isinstance(fs, torch.SymInt): |
|
stride.append(evaluate_symint_from_shape_env(fs)) |
|
else: |
|
stride.append(fs) |
|
with unset_fake_temporarily(): |
|
dummy_v = torch.ones(shape, |
|
dtype=v.dtype, |
|
layout=v.layout, |
|
device=v.device, |
|
requires_grad=v.requires_grad).as_strided(shape, stride) |
|
real_inputs_ret.append(dummy_v) |
|
else: |
|
real_inputs_ret.append(v) |
|
else: |
|
if isinstance(v, torch.SymInt): |
|
real_inputs_ret.append(evaluate_symint_from_shape_env(v)) |
|
else: |
|
real_inputs_ret.append(v) |
|
|
|
return tuple(real_inputs_ret) |
|
|
|
|
|
def run_opt_passes(opt_passes: List[Callable],
                   gm: GraphModule,
                   graph_id: int,
                   graph_order: List[int],
                   profiling_results,
                   create_inputs_fn,
                   mem_budget: float,
                   param_manager,
                   bwd: bool,
                   debug_log=False) -> None:
    """Apply each optimization pass to ``gm`` and refresh its profiling data.

    Every pass is called as ``pass_fn(gm, graph_id, graph_order,
    profiling_results, create_inputs_fn, mem_budget, param_manager, bwd)``.
    When a pass returns a module, that module replaces ``gm`` for the following
    passes; it is linted, recompiled and memory-profiled with the inputs from
    ``create_inputs_fn()``, and the outcome is written to ``profiling_results``
    through ``set_time_and_tensor_size``.  Device memory is released before the
    first pass and after every pass.
    """

    def _release_device_memory():
        # With fake-tensor mode suspended: wait for the device, collect
        # garbage, then empty the accelerator's cache.
        with unset_fake_temporarily():
            get_accelerator().synchronize()
            gc.collect()
            get_accelerator().empty_cache()

    _release_device_memory()

    for pass_idx, pass_fn in enumerate(opt_passes):
        log_rank0(f"Running opt pass {pass_idx} for graph {graph_id}. bwd={bwd}", enable=debug_log)

        rewritten = pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget,
                            param_manager, bwd)
        if rewritten is not None:
            gm = rewritten
            gm.graph.lint()
            gm.recompile()

            # Re-profile memory on the rewritten module so later passes see
            # up-to-date numbers.
            profiler = MemoryProfilingInterpreter(gm, debug_log=debug_log)
            profiler.run(*create_inputs_fn())
            mem = [(node_name, allocated, delta, peak) for node_name, allocated, delta, peak in profiler.mem_record]

            set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results)

        _release_device_memory()
|
|
|
|
|
def make_backend(backend, compile_kwargs=None, free_activation=False, debug_log=False):
    """Create a dynamo backend function that applies the currently launched compile passes.

    Args:
        backend: ``"eager"`` or ``"inductor"``; selects how the processed
            graphs are finally compiled.
        compile_kwargs: Keyword arguments forwarded to
            ``torch._dynamo.optimize`` when ``backend == "eager"``.  ``None``
            (the default) means no extra arguments.
        free_activation: When true, the backward graph is post-processed with
            ``add_free_activations``.
        debug_log: Enables the verbose ``log_rank0`` messages and is passed on
            to ``run_opt_passes``.

    Returns:
        ``backend_fn(gm, real_inputs)``.  It records the graph in the
        module-level bookkeeping and compiles it with forward/backward hooks
        that run the passes stored in ``next_passes``.

    Raises:
        ValueError: If ``backend`` is neither ``"eager"`` nor ``"inductor"``.
    """
    # Reject an unknown backend before anything is registered or patched.
    # Otherwise the error would only surface at the end of ``backend_fn``,
    # after the module-level bookkeeping had already been modified.
    if backend not in ("eager", "inductor"):
        raise ValueError(f"Unsupported backend {backend}")

    # A ``{}`` default would be a single dict shared by every call.
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs

    register_custom_ops()

    def backend_fn(gm: GraphModule, real_inputs):
        # The identity of the graph object serves as the graph id everywhere below.
        graph_id = id(gm.graph)

        needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)
        frame_id = gm.meta["dynamo_compile_id"].frame_id
        graph_order_with_frame_id.add_graph(graph_id, frame_id, needs_backward)

        if needs_backward:
            # The compiled-function patch is installed when the first frame
            # needing a backward pass shows up and removed again in
            # make_bw_graph once no such frame is left.
            if not frames_needing_bwd:
                patch_compiled_func()
            frames_needing_bwd.add(frame_id)

        graph_order = graph_order_with_frame_id.get_graph_order()

        # Inputs carrying ``ds_id`` are described by (index, ds_id, ds_shape);
        # otherwise every parameter must carry ``param_id`` and is described
        # by (index, param_id, shape).
        z3_partition = any(hasattr(v, "ds_id") for v in real_inputs)
        if z3_partition:
            param_indices = [(i, input_val.ds_id, input_val.ds_shape) for i, input_val in enumerate(real_inputs)
                             if isinstance(input_val, torch.nn.Parameter)]
        else:
            assert all(hasattr(v, "param_id") for v in real_inputs
                       if isinstance(v, torch.nn.Parameter)), "All param inputs should have param_id"
            param_indices = [(i, input_val.param_id, input_val.shape) for i, input_val in enumerate(real_inputs)
                             if isinstance(input_val, torch.nn.Parameter)]

        # Queue the inputs; make_fw_graph takes them back out in FIFO order.
        global fwd_real_inputs
        fwd_real_inputs.append(real_inputs)

        global profiling_results
        if graph_id not in profiling_results:
            profiling_results[graph_id] = ProfilingResult()
            profiling_results[graph_id].param_indices = param_indices
            profiling_results[graph_id].needs_backward = needs_backward

        def make_fw_graph(gm, sample_inputs):
            """Run the launched passes on the forward graph and return ``gm.graph``."""
            time_start = time.time()
            graph_index = len(graph_order) - 1
            real_inputs = fwd_real_inputs.pop(0)
            real_inputs = set_example_values_to_symints(real_inputs)

            param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices)

            # ``sample_inputs`` may be longer than the queued inputs; the
            # surplus tail is appended unchanged.
            real_inputs_with_rng = real_inputs + tuple(sample_inputs[len(real_inputs):])
            run_opt_passes(
                opt_passes=next_passes,
                gm=gm,
                graph_id=graph_id,
                graph_order=graph_order,
                profiling_results=profiling_results,
                create_inputs_fn=lambda: real_inputs_with_rng,
                mem_budget=.0,
                param_manager=param_manager,
                bwd=False,
                debug_log=debug_log)

            opt_pass_times.append(("fwd", graph_index, graph_id, time.time() - time_start))

            log_rank0(f"Fwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()}",
                      enable=debug_log)

            return gm.graph

        def make_bw_graph(gm, sample_inputs):
            """Run the launched passes on the backward graph and return ``gm.graph``."""
            time_start = time.time()

            graph_index = get_index_by_graph_id(graph_order, graph_id)
            log_rank0(
                f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
                enable=debug_log)

            bwd_inputs_stack = get_backward_inputs()

            if len(bwd_inputs_stack) == 0:
                # No backward inputs were captured: fall back to the sample
                # inputs with fake tensors / SymInts made concrete.
                bwd_real_inputs = set_example_values_to_symints(sample_inputs)
            else:
                bwd_real_inputs = bwd_inputs_stack.pop()

            run_opt_passes(
                opt_passes=next_passes,
                gm=gm,
                graph_id=graph_id,
                graph_order=graph_order,
                profiling_results=profiling_results,
                create_inputs_fn=lambda: tuple(bwd_real_inputs),
                mem_budget=.0,
                param_manager=param_manager,
                bwd=True,
                debug_log=debug_log)

            if free_activation:
                param_nodes_bw, _ = param_manager[graph_id].get_bwd_mapping(gm.graph)
                # Only used for membership tests, so a set is sufficient.
                param_names = {n.name for n in param_nodes_bw}
                non_param_input_names = [n.name for n in get_input_nodes(gm.graph) if n.name not in param_names]
                add_free_activations(graph_id, gm.graph,
                                     get_activation_node_names(gm.graph, param_nodes_bw, non_param_input_names))

            # This frame's backward graph is done; remove the patch when it
            # was the last frame waiting.
            frames_needing_bwd.remove(frame_id)
            if not frames_needing_bwd:
                unpatch_compiled_func()

            log_rank0(
                f"Bwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
                enable=debug_log)

            opt_pass_times.append(("bwd", graph_index, graph_id, time.time() - time_start))

            return gm.graph

        if backend == "eager":

            def make_compiler_fn(make_graph_fn):

                def compiler_fn(gm, sample_inputs):
                    # A ``None`` result from the graph hook is propagated;
                    # otherwise the module's forward is wrapped as a boxed function.
                    return None if make_graph_fn(gm, sample_inputs) is None else make_boxed_func(gm.forward)

                return compiler_fn

            aot_mod = aot_module_simplified(gm,
                                            real_inputs,
                                            fw_compiler=make_compiler_fn(make_fw_graph),
                                            bw_compiler=make_compiler_fn(make_bw_graph),
                                            partition_fn=get_wrapped_partitioner(param_indices))
            return torch._dynamo.optimize(**compile_kwargs)(aot_mod)
        elif backend == "inductor":
            patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs,
                                                 param_indices, param_manager)
            from .partitioner import get_wrapped_choose_saved_values_set
            torch._functorch.partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)

            return torch._inductor.compile(gm, real_inputs)

        # Not reachable after the check at the top of make_backend; kept as a guard.
        raise ValueError(f"Unsupported backend {backend}")

    return backend_fn
|
|