File size: 7,645 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import gc
from typing import List, Dict
import torch
from torch.fx import Graph, Node, GraphModule
from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
NAME = "zero3_compile"
def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
    """Splice an allgather op (and its matching wait) in after a param node.

    The allgather materializes the sharded DeepSpeed param `ds_id`; the wait
    node ensures consumers only run after the gather completes. Returns the
    newly inserted allgather node.
    """
    ag_node = add_postprocess(graph,
                              node,
                              torch.ops.dc.allgather_param.default,
                              extra_args=[graph_id, ds_id],
                              name=f"allgather_ds_param_{node.target}_{ds_id}",
                              meta=_make_node_meta(node, ds_id, True))
    ag_node.meta["val"] = node.meta["val"]

    # The graph output must keep referencing the original param node, not the
    # allgather we just spliced in, so revert that one rewrite on the output.
    get_output_node(graph).replace_input_with(ag_node, node)

    # Pair the allgather with its wait so downstream users see gathered data.
    wait_node = add_postprocess(graph,
                                ag_node,
                                torch.ops.dc.wait_allgather.default,
                                extra_args=[graph_id, ds_id],
                                name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
                                meta=_make_node_meta(node, ds_id, False))
    wait_node.meta["val"] = node.meta["val"]

    return ag_node
def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int):
    """Splice a release op in after `node` to free gathered param `release_node`.

    `n_users` is the total number of real users of the param; it is forwarded
    to the release op as an extra arg. The release op produces no value.
    """
    rel_node = add_postprocess(graph,
                               node,
                               torch.ops.dc.release_param.default,
                               extra_args=[graph_id, ds_id, n_users],
                               name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
                               meta=_make_node_meta(node, ds_id, False))
    rel_node.meta["val"] = None
def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
    """Splice a gradient reduce-scatter op in after the grad node of a param.

    The reduce op produces no value, so its "val" meta is cleared.
    """
    red_node = add_postprocess(graph,
                               grad_node,
                               torch.ops.dc.reduce_grad.default,
                               extra_args=[graph_id, ds_id],
                               name=f"reduce_ds_param_{param_name}",
                               meta=_make_node_meta(grad_node, ds_id, True))
    red_node.meta["val"] = None
def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:
    """Insert allgather/wait for each param node and a release after each real use.

    For every param node an allgather (plus wait) is spliced in, and a release
    op is added after each of its real users so the gathered storage can be
    freed as soon as all users have run. Returns the graph with primal
    placeholders moved back to the head.
    """
    node_to_uses = get_real_uses(graph)
    for pn in param_nodes:
        # Look up the DeepSpeed param id once; the original fetched it twice.
        ds_id = param_manager.ds_ids[pn.name]
        add_allgather(graph_id, graph, pn, ds_id)
        users = node_to_uses[pn]
        for user in users:
            add_release(graph_id, graph, user, pn, ds_id, len(users))
    return move_primals_to_head(graph)
def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
                          param_name_to_grad: Dict[str, Node]) -> Graph:
    """Insert gather/release ops for the backward graph plus grad reduce ops.

    First applies the same gather/release rewrite as the forward pass, then
    splices a reduce op after each param's gradient node. Returns the graph
    with primal placeholders moved back to the head.
    """
    add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)
    for name in param_manager.param_names:
        add_reduce(graph_id, graph, param_name_to_grad[name], name, param_manager.ds_ids[name])
    return move_primals_to_head(graph)
def add_z3_gather_release_fw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[int],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:
    """Rewrite the forward graph with ZeRO-3 allgather/release ops, profile it,
    then reorder it with the memory-aware scheduler. Returns the mutated `gm`."""
    dc_handle = get_deepcompile_handle()
    real_inputs = create_inputs_fn()
    param_indices = profiling_results[graph_id].param_indices
    pm = param_manager[graph_id]

    gm.graph = add_gather_and_release(graph_id, gm.graph, pm, get_param_nodes(gm.graph, param_indices))

    # The graph must be registered with the handle before profiling runs it.
    dc_handle.register_graph_z3(graph_id, [v[1] for v in param_indices])

    profiler = ProfilingInterpreter(gm, debug_log=debug_log)
    profiler.run(*real_inputs)
    del profiler
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    # Shrink "val" meta of DS param nodes to empty tensors; the real values
    # only exist after allgather, and Inductor validates against this meta.
    ds_ids = pm.ds_ids
    for n in gm.graph.nodes:
        if "val" in n.meta and n.name in ds_ids:
            old_val = n.meta["val"]
            n.meta["val"] = torch.empty([0], dtype=old_val.dtype, device=old_val.device)

    gm.graph = fast_free_schedule(
        gm.graph,
        get_accelerator().available_memory(),
        0,  # unused
        debug_log=debug_log)

    if rank == 0 and debug_log:
        print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
    return gm
def add_z3_gather_release_bw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[int],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:
    """Rewrite the backward graph with ZeRO-3 gather/release/reduce ops,
    profile it, then reorder it with the memory-aware scheduler."""
    pm = param_manager[graph_id]
    param_nodes_bw, param_name_to_grad = pm.get_bwd_mapping(gm.graph)
    gm.graph = add_gather_and_reduce(graph_id, gm.graph, pm, param_nodes_bw, param_name_to_grad)

    input_nodes = get_input_nodes(gm.graph)
    real_inputs = create_inputs_fn()
    assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}"

    # Profile once to record per-node timings/memory, then drop the outputs
    # so their storage is reclaimed before scheduling.
    outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs)
    del outputs
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    gm.graph = fast_free_schedule(
        gm.graph,
        get_accelerator().available_memory(),
        0,  # unused
        debug_log=debug_log)
    return gm
def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                          mem_budget: float, param_manager, bwd: bool) -> GraphModule:
    """Entry point: dispatch to the forward or backward ZeRO-3 rewrite pass.

    `mem_budget` is part of the pass interface but is not consumed here.
    """
    rewrite_pass = add_z3_gather_release_bw if bwd else add_z3_gather_release_fw
    return rewrite_pass(gm,
                        graph_id,
                        graph_order,
                        profiling_results,
                        create_inputs_fn,
                        param_manager,
                        debug_log=False)
|