jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import gc
from typing import List, Dict
import torch
from torch.fx import Graph, Node, GraphModule
from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
NAME = "zero3_compile"
def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
new_ag_node = add_postprocess(graph,
node,
torch.ops.dc.allgather_param.default,
extra_args=[graph_id, ds_id],
name=f"allgather_ds_param_{node.target}_{ds_id}",
meta=_make_node_meta(node, ds_id, True))
new_ag_node.meta["val"] = node.meta["val"]
# Set the previous node back to output
# We don't want to change the output node to allgather
output_node = get_output_node(graph)
output_node.replace_input_with(new_ag_node, node)
# Add wait as well
new_wait_node = add_postprocess(graph,
new_ag_node,
torch.ops.dc.wait_allgather.default,
extra_args=[graph_id, ds_id],
name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
meta=_make_node_meta(node, ds_id, False))
new_wait_node.meta["val"] = node.meta["val"]
return new_ag_node
def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int):
new_node = add_postprocess(graph,
node,
torch.ops.dc.release_param.default,
extra_args=[graph_id, ds_id, n_users],
name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
meta=_make_node_meta(node, ds_id, False))
new_node.meta["val"] = None
def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
new_node = add_postprocess(graph,
grad_node,
torch.ops.dc.reduce_grad.default,
extra_args=[graph_id, ds_id],
name=f"reduce_ds_param_{param_name}",
meta=_make_node_meta(grad_node, ds_id, True))
new_node.meta["val"] = None
def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:
node_to_uses = get_real_uses(graph)
for pn in param_nodes:
add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name])
ds_id = param_manager.ds_ids[pn.name]
users = node_to_uses[pn]
for user in users:
add_release(graph_id, graph, user, pn, ds_id, len(users))
return move_primals_to_head(graph)
def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
param_name_to_grad: Dict[str, Node]) -> Graph:
add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)
for param_name in param_manager.param_names:
add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name])
return move_primals_to_head(graph)
def add_z3_gather_release_fw(gm: GraphModule,
graph_id: int,
graph_order: List[int],
profiling_results,
create_inputs_fn,
param_manager,
debug_log=False) -> GraphModule:
nz3 = get_deepcompile_handle()
real_inputs = create_inputs_fn()
param_indices = profiling_results[graph_id].param_indices
gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
get_param_nodes(gm.graph, param_indices))
nz3.register_graph_z3(graph_id, [v[1] for v in param_indices]) # Need this before profiling
profiler = ProfilingInterpreter(gm, debug_log=debug_log)
profiler.run(*real_inputs)
del profiler
gc.collect()
get_accelerator().empty_cache()
rank = dist.get_rank()
graph_index = get_index_by_graph_id(graph_order, graph_id)
if rank == 0 and debug_log:
print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
for n in gm.graph.nodes:
is_ds_param = n.name in param_manager[graph_id].ds_ids
if "val" in n.meta and is_ds_param:
# Used for Inductor's validation
n.meta["val"] = torch.empty([0], dtype=n.meta['val'].dtype, device=n.meta['val'].device)
gm.graph = fast_free_schedule(
gm.graph,
get_accelerator().available_memory(),
0, # unused
debug_log=debug_log)
if rank == 0 and debug_log:
print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
return gm
def add_z3_gather_release_bw(gm: GraphModule,
graph_id: int,
graph_order: List[int],
profiling_results,
create_inputs_fn,
param_manager,
debug_log=False) -> GraphModule:
param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)
gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)
input_nodes = get_input_nodes(gm.graph)
real_inputs = create_inputs_fn()
assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}"
real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs)
del real_outputs
gc.collect()
get_accelerator().empty_cache()
rank = dist.get_rank()
graph_index = get_index_by_graph_id(graph_order, graph_id)
if rank == 0 and debug_log:
print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
gm.graph = fast_free_schedule(
gm.graph,
get_accelerator().available_memory(),
0, # unused
debug_log=debug_log)
return gm
def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager, bwd: bool) -> GraphModule:
if bwd:
return add_z3_gather_release_bw(gm,
graph_id,
graph_order,
profiling_results,
create_inputs_fn,
param_manager,
debug_log=False)
return add_z3_gather_release_fw(gm,
graph_id,
graph_order,
profiling_results,
create_inputs_fn,
param_manager,
debug_log=False)