import gc

from typing import List, Dict

import torch
from torch.fx import Graph, Node, GraphModule

from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

NAME = "zero3_compile"


def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
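    """Insert an allgather for a ZeRO-3 partitioned parameter.

    Appends a ``dc.allgather_param`` node after ``node`` and a matching
    ``dc.wait_allgather`` node after that, so downstream users consume the
    fully gathered parameter. Returns the new allgather node.
    """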
    new_ag_node = add_postprocess(graph,
                                  node,
                                  torch.ops.dc.allgather_param.default,
                                  extra_args=[graph_id, ds_id],
                                  name=f"allgather_ds_param_{node.target}_{ds_id}",
                                  meta=_make_node_meta(node, ds_id, True))
    new_ag_node.meta["val"] = node.meta["val"]

    # add_postprocess rewires every user of `node` to the allgather node,
    # including the graph output; restore the original parameter as the
    # output so the gathered tensor is not returned from the graph.
    output_node = get_output_node(graph)
    output_node.replace_input_with(new_ag_node, node)

    new_wait_node = add_postprocess(graph,
                                    new_ag_node,
                                    torch.ops.dc.wait_allgather.default,
                                    extra_args=[graph_id, ds_id],
                                    name=f"wait_allgather_ds_param_{node.target}_{ds_id}",
                                    meta=_make_node_meta(node, ds_id, False))
    new_wait_node.meta["val"] = node.meta["val"]

    return new_ag_node


def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int):
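    """Insert a ``dc.release_param`` op after ``node``, a user of the gathered
    parameter ``release_node``. ``n_users`` is the total number of users of
    the parameter, which lets the runtime know when its last use has been
    released.
    """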
    new_node = add_postprocess(graph,
                               node,
                               torch.ops.dc.release_param.default,
                               extra_args=[graph_id, ds_id, n_users],
                               name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
                               meta=_make_node_meta(node, ds_id, False))
    new_node.meta["val"] = None


def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
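    """Insert a ``dc.reduce_grad`` op after ``grad_node`` to reduce the
    gradient of ``param_name`` across ranks.
    """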
    new_node = add_postprocess(graph,
                               grad_node,
                               torch.ops.dc.reduce_grad.default,
                               extra_args=[graph_id, ds_id],
                               name=f"reduce_ds_param_{param_name}",
                               meta=_make_node_meta(grad_node, ds_id, True))
    new_node.meta["val"] = None


def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:
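    """Add an allgather for every parameter node and a release after each of
    its real users, then move the primal inputs to the head of the graph.
    """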
    node_to_uses = get_real_uses(graph)
    for pn in param_nodes:
        ds_id = param_manager.ds_ids[pn.name]
        add_allgather(graph_id, graph, pn, ds_id)
        users = node_to_uses[pn]
        for user in users:
            add_release(graph_id, graph, user, pn, ds_id, len(users))

    return move_primals_to_head(graph)


def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
                          param_name_to_grad: Dict[str, Node]) -> Graph:
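    """Backward counterpart of ``add_gather_and_release``: gather and release
    the parameters used in the backward graph, then insert a gradient reduce
    for every parameter.
    """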
    add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)

    for param_name in param_manager.param_names:
        add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name])

    return move_primals_to_head(graph)


def add_z3_gather_release_fw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[int],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:
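    """Instrument the forward graph with ZeRO-3 gather/release ops, profile it
    on real inputs, and reorder it with ``fast_free_schedule``.
    """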
    nz3 = get_deepcompile_handle()

    real_inputs = create_inputs_fn()
    param_indices = profiling_results[graph_id].param_indices

    gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
                                      get_param_nodes(gm.graph, param_indices))

    nz3.register_graph_z3(graph_id, [v[1] for v in param_indices])

    # Profile the instrumented graph on real inputs, then drop the profiler
    # state before scheduling.
    profiler = ProfilingInterpreter(gm, debug_log=debug_log)
    profiler.run(*real_inputs)
    del profiler
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    for n in gm.graph.nodes:
        is_ds_param = n.name in param_manager[graph_id].ds_ids
        if "val" in n.meta and is_ds_param:
            # A DS-managed parameter stays partitioned until it is gathered at
            # runtime, so record an empty placeholder instead of the full-size
            # example value.
            n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)

    gm.graph = fast_free_schedule(
        gm.graph,
        get_accelerator().available_memory(),
        0,
        debug_log=debug_log)

    if rank == 0 and debug_log:
        print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    return gm


def add_z3_gather_release_bw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[int],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:
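    """Instrument the backward graph with ZeRO-3 gather/release ops and
    gradient reduces, profile it on real inputs, and reorder it with
    ``fast_free_schedule``.
    """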
    param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)
    gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)

    input_nodes = get_input_nodes(gm.graph)
    real_inputs = create_inputs_fn()
    assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}"

    # Profile the instrumented graph, then free the profiling outputs before
    # scheduling.
    real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs)
    del real_outputs
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    gm.graph = fast_free_schedule(
        gm.graph,
        get_accelerator().available_memory(),
        0,
        debug_log=debug_log)

    return gm


def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                          mem_budget: float, param_manager, bwd: bool) -> GraphModule:
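    """Entry point for the pass: dispatch to the forward or backward variant.
    ``mem_budget`` is accepted for interface compatibility but is not used
    here.
    """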
    if bwd:
        return add_z3_gather_release_bw(gm,
                                        graph_id,
                                        graph_order,
                                        profiling_results,
                                        create_inputs_fn,
                                        param_manager,
                                        debug_log=False)
    return add_z3_gather_release_fw(gm,
                                    graph_id,
                                    graph_order,
                                    profiling_results,
                                    create_inputs_fn,
                                    param_manager,
                                    debug_log=False)