# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from typing import List
import torch
from torch.fx import Graph, Node, GraphModule
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist
from ..profilers.comm_profile import create_predictor
from ..graph_param import DSGraphParamManager
NAME = "prefetch"
FUSE_FACTOR = 0.8
MARGIN = 0.1
MAX_FUSE_SIZE = 1e9
MAX_BUFFERED_SIZE = 4e9
run_prefetch_pass = False


def print_rank_0(message):
    if dist.get_rank() == 0:
        print(message)


def get_ds_id(node: Node):
    assert node.target == torch.ops.dc.allgather_param.default
    return node.args[2]


def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                      mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
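    """Schedule ZeRO-3 parameter all-gathers ahead of their first use.

    Walks the graph in reverse execution order, groups allgather_param nodes
    into fused prefetch launches when the profiled communication times favor
    fusion, and places each launch as early as the profiled memory peaks
    allow. An illustrative sketch of the effect (hypothetical node names):

        before:  compute_a -> allgather_p1 -> use_p1 -> allgather_p2 -> use_p2
        after:   prefetch_params_fused([p1, p2]) -> compute_a -> use_p1 -> use_p2
    """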
    # Per-rank memory budget: total device memory minus a safety margin, then
    # the minimum across ranks so that all ranks agree on the same schedule.
    max_mem = get_accelerator().total_memory() * (1 - MARGIN)
    vals_to_bcast = torch.tensor([max_mem], device=torch.device(get_accelerator().current_device()))
    dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
    max_mem = vals_to_bcast[0].item()
    # Profiling tables for the direction being scheduled.
    mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem
    op_time = profiling_results[graph_id].bwd_time if bwd else profiling_results[graph_id].fwd_time
    tensor_sizes = profiling_results[graph_id].bwd_tensor_sizes if bwd else profiling_results[graph_id].fwd_tensor_sizes
    mem_dict = {name: (alloc_mem, peak) for name, alloc_mem, _, peak in mem}
    time_dict = {name: (device_time, wall_time) for name, device_time, wall_time in op_time}
    tensor_size_dict = {name: size for name, size in tensor_sizes}
graph = gm.graph
    total_param_size = sum(tensor_size_dict[n.name] for n in graph.nodes
                           if n.target == torch.ops.dc.allgather_param.default)
    print_rank_0(f"schedule_prefetch graph_id={graph_id} max_mem={max_mem} "
                 f"available_memory={get_accelerator().available_memory()} "
                 f"memory_allocated={get_accelerator().memory_allocated()} "
                 f"max_allocated={get_accelerator().max_memory_allocated()} "
                 f"total_param_size={total_param_size} margin={MARGIN}")
    # The profiler does not record an entry for every node; carry the last
    # seen (allocated, peak) values forward so every node has one.
    prev_mem = 0
    prev_peak = 0
for node in graph.nodes:
if node.name in mem_dict:
prev_mem = mem_dict[node.name][0]
prev_peak = mem_dict[node.name][1]
else:
print_rank_0(f"node {node.name} not in mem_dict")
mem_dict[node.name] = (prev_mem, prev_peak)
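
    # Predictor that maps an all-gather message size to an estimated latency,
    # built from the communication profile.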
comm_predictor = create_predictor()
    order_rev = list(reversed(graph.nodes))
    new_order_rev = []  # reverse-order schedule; items are Nodes or fused groups (lists of Nodes)
    prefetch_ags = []  # all-gather nodes being fused into the current group
    prefetch_ag_groups = []  # completed groups waiting to be placed
    ag_tensor_size_sum = 0  # bytes held by all pending prefetches
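
    # Walk the graph in reverse execution order. Collecting an all-gather into
    # a pending group moves its launch earlier in the final (forward) order;
    # groups are flushed whenever the buffered bytes would push the profiled
    # peak of an earlier node over the budget.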
for i, node in enumerate(order_rev):
if node.op != "placeholder":
assert i < len(order_rev) - 1
assert node.name in mem_dict
next_node = order_rev[i + 1]
next_alloc_mem, next_peak = mem_dict[next_node.name]
            # Flush pending prefetch groups until the next node's profiled peak
            # plus the buffered prefetch bytes fits under the budget (and the
            # buffered bytes stay under MAX_BUFFERED_SIZE).
while next_peak + ag_tensor_size_sum > max_mem or ag_tensor_size_sum > MAX_BUFFERED_SIZE:
                if len(prefetch_ag_groups) > 0:
                    # Launch the oldest completed group at this point; its
                    # buffered bytes are no longer counted as pending.
                    fused_ag_nodes = prefetch_ag_groups.pop(0)
                    total_ag_tensor_size = sum(tensor_size_dict[ag_node.name] for ag_node in fused_ag_nodes)
                    ag_tensor_size_sum -= total_ag_tensor_size
                    new_order_rev.append(fused_ag_nodes)
                    assert len(fused_ag_nodes) > 0
                elif len(prefetch_ags) > 0:
                    # Close the group currently being built so the next
                    # iteration can flush it.
                    prefetch_ag_groups.append(prefetch_ags)
                    prefetch_ags = []
                else:
                    break
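
        # Decide whether this all-gather joins the current fusion group: fuse
        # when the predicted fused launch is at most ~20% slower than the
        # slower of the two separate launches and the fused message stays
        # under MAX_FUSE_SIZE.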
if node.target == torch.ops.dc.allgather_param.default:
current_ag_size = sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
pred_time_current = comm_predictor(current_ag_size)
pred_time_next = comm_predictor(tensor_size_dict[node.name])
pred_time_fused = comm_predictor(current_ag_size + tensor_size_dict[node.name])
do_fuse = max(pred_time_current, pred_time_next) * 1.2 > pred_time_fused and (
current_ag_size + tensor_size_dict[node.name]) < MAX_FUSE_SIZE
            if len(prefetch_ags) > 0 and not do_fuse:
                # Stop fusing here: close the current group and start a new
                # one with this all-gather.
                prefetch_ag_groups.append(prefetch_ags)
                prefetch_ags = []
            prefetch_ags.append(node)
            ag_tensor_size_sum += tensor_size_dict[node.name]
new_order_rev.append(node)
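
        # At the boundary between the placeholders and the first real op (the
        # start of the graph in forward order), flush everything that is still
        # pending so no prefetch is left unscheduled.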
if (node.op != "placeholder"
and node.target != torch.ops.dc.reload_parameter) and order_rev[i + 1].op == "placeholder":
for ag_group in prefetch_ag_groups:
assert len(ag_group) > 0
new_order_rev.append(ag_group)
total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in ag_group])
ag_tensor_size_sum -= total_ag_tensor_size
if len(prefetch_ags) > 0:
new_order_rev.append(prefetch_ags)
ag_tensor_size_sum -= sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
assert ag_tensor_size_sum == 0
assert ag_tensor_size_sum >= 0
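
    # Rebuild the graph in forward order: plain nodes are copied as-is, and
    # each fused group becomes a single prefetch_params_fused call at the
    # position where the group was scheduled.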
new_graph = Graph()
env = {}
    for node in reversed(new_order_rev):
        if isinstance(node, Node):
            new_node = new_graph.node_copy(node, lambda n: env[n.name])
            env[node.name] = new_node
        else:
            # node is a fused group: a list of allgather_param nodes.
            param_nodes = [ag_node.args[0] for ag_node in node]
            param_nodes_copy = [env[param_node.name] for param_node in param_nodes]
            ds_ids = [get_ds_id(ag_node) for ag_node in node]
            new_graph.call_function(torch.ops.dc.prefetch_params_fused.default,
                                    args=(graph_id, param_nodes_copy, ds_ids))
new_graph.lint()
gm.graph = new_graph
return gm