# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import torch
from torch.fx import Graph, Node, GraphModule

from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from ..profilers.comm_profile import create_predictor
from ..graph_param import DSGraphParamManager

NAME = "prefetch"

FUSE_FACTOR = 0.8
MARGIN = 0.1  # fraction of device memory kept free when computing the budget
MAX_FUSE_SIZE = 1e9  # max bytes in a single fused all-gather group
MAX_BUFFERED_SIZE = 4e9  # max bytes of pending prefetches at any point

run_prefetch_pass = False


def print_rank_0(message):
    if dist.get_rank() == 0:
        print(message)


def get_ds_id(node: Node):
    assert node.target == torch.ops.dc.allgather_param.default
    return node.args[2]
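

# schedule_prefetch rewrites the FX graph for one subgraph: walking the node
# list in reverse execution order, it hoists each allgather_param as early as
# the memory budget allows and fuses neighboring all-gathers into a single
# torch.ops.dc.prefetch_params_fused call when the latency predictor estimates
# that the fused transfer costs less than 1.2x the slower of the separate ones.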
def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                      mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:

    # Budget is the device total minus the safety margin; take the minimum
    # across ranks so every rank makes identical scheduling decisions.
    max_mem = get_accelerator().total_memory() * (1 - MARGIN)
    vals_to_bcast = torch.tensor([max_mem], device=torch.device(get_accelerator().current_device()))
    dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
    max_mem = vals_to_bcast[0].item()

    mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem
    op_time = profiling_results[graph_id].bwd_time if bwd else profiling_results[graph_id].fwd_time
    tensor_sizes = profiling_results[graph_id].bwd_tensor_sizes if bwd else profiling_results[graph_id].fwd_tensor_sizes

    mem_dict = {name: (alloc_mem, peak) for name, alloc_mem, delta, peak in mem}
    time_dict = {name: (device_time, wall_time) for name, device_time, wall_time in op_time}
    tensor_size_dict = {name: size for name, size in tensor_sizes}
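
    # Sum the bytes gathered by every allgather_param in this graph; reported
    # in the debug line below.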
    graph = gm.graph
    total_param_size = sum(
        [tensor_size_dict[n.name] for n in graph.nodes if n.target == torch.ops.dc.allgather_param.default])

    print_rank_0(
        f"schedule_prefetch graph_id={graph_id} max_mem={max_mem} available_memory={get_accelerator().available_memory()} memory_allocated={get_accelerator().memory_allocated()} max_allocated={get_accelerator().max_memory_allocated()} total_param_size={total_param_size} margin={MARGIN}"
    )

    # Fill missing values: ops that were not profiled inherit the allocated/peak
    # memory of the preceding profiled op.
    prev_mem = 0
    prev_peak = 0
    for node in graph.nodes:
        if node.name in mem_dict:
            prev_mem = mem_dict[node.name][0]
            prev_peak = mem_dict[node.name][1]
        else:
            print_rank_0(f"node {node.name} not in mem_dict")
            mem_dict[node.name] = (prev_mem, prev_peak)
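
    # create_predictor builds a latency model from the communication profile:
    # it maps a message size in bytes to an estimated collective time.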
    comm_predictor = create_predictor()

    order_rev = list(reversed(graph.nodes))
    new_order_rev = []
    prefetch_ags = []
    prefetch_ag_groups = []
    ag_tensor_size_sum = 0
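
    # Walk the nodes in reverse execution order. prefetch_ags accumulates the
    # group of all-gathers currently being fused, prefetch_ag_groups holds
    # closed groups awaiting placement, and ag_tensor_size_sum tracks the bytes
    # held by all pending groups. The longer a group stays pending during this
    # reverse walk, the earlier its fused prefetch lands in forward order.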
    for i, node in enumerate(order_rev):
        # print_rank_0(
        #     f"Checking node reverse order {node.name} {node.target} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
        # )
        if node.op != "placeholder":
            assert i < len(order_rev) - 1
            assert node.name in mem_dict

            next_node = order_rev[i + 1]
            next_alloc_mem, next_peak = mem_dict[next_node.name]

            # Free up memory: while the predicted peak plus the bytes held by
            # pending prefetches exceed the budget (or too many bytes are
            # buffered), emit the oldest pending group at this position instead
            # of hoisting it any earlier.
            while next_peak + ag_tensor_size_sum > max_mem or ag_tensor_size_sum > MAX_BUFFERED_SIZE:
                if len(prefetch_ag_groups) > 0:
                    # launch prefetch
                    fused_ag_nodes = prefetch_ag_groups.pop(0)
                    total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in fused_ag_nodes])
                    ag_tensor_size_sum -= total_ag_tensor_size
                    new_order_rev.append(fused_ag_nodes)
                    assert len(fused_ag_nodes) > 0
                    # print_rank_0(
                    #     f"Free up memory fused_ag_nodes={fused_ag_nodes} next_alloc_mem={next_alloc_mem} total_ag_tensor_size={total_ag_tensor_size} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
                    # )
                elif len(prefetch_ags) > 0:
                    # close the group being built so the next iteration can emit it
                    prefetch_ag_groups.append(prefetch_ags)
                    prefetch_ags = []
                    # print_rank_0(
                    #     f"Free up memory prefetch_ags={prefetch_ag_groups} next_alloc_mem={next_alloc_mem} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
                    # )
                else:
                    break
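
            # Fusion heuristic: fold this all-gather into the pending group when
            # the predicted fused transfer takes less than 1.2x the slower of the
            # two separate transfers and the fused size stays under MAX_FUSE_SIZE;
            # otherwise close the pending group and start a new one.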
            if node.target == torch.ops.dc.allgather_param.default:
                current_ag_size = sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
                pred_time_current = comm_predictor(current_ag_size)
                pred_time_next = comm_predictor(tensor_size_dict[node.name])
                pred_time_fused = comm_predictor(current_ag_size + tensor_size_dict[node.name])

                do_fuse = max(pred_time_current, pred_time_next) * 1.2 > pred_time_fused and (
                    current_ag_size + tensor_size_dict[node.name]) < MAX_FUSE_SIZE
                # print_rank_0(
                #     f"found allgather_param do_fuse={do_fuse} current_ag_size={current_ag_size} tensor_size_dict[node.name]={tensor_size_dict[node.name]} pred_time_current={pred_time_current} pred_time_next={pred_time_next} pred_time_fused={pred_time_fused}"
                # )

                if len(prefetch_ags) > 0 and not do_fuse:
                    # stop fusing here
                    prefetch_ag_groups.append(prefetch_ags)
                    prefetch_ags = []
                    # print_rank_0(
                    #     f"stop fusing prefetch_ags={prefetch_ag_groups} ag_tensor_size_sum={ag_tensor_size_sum}")
                # else:
                #     print_rank_0(
                #         f"continue fusing ag_tensor_size_sum={ag_tensor_size_sum} ag_size={tensor_size_dict[node.name]} prefetch_ags={prefetch_ags} prefetch_ag_groups={prefetch_ag_groups}"
                #     )
                prefetch_ags.append(node)
                ag_tensor_size_sum += tensor_size_dict[node.name]

        new_order_rev.append(node)
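
        # order_rev ends with the graph's placeholders, so when the next node is
        # a placeholder this is the first real op in forward order: flush every
        # group still pending so each all-gather is scheduled exactly once
        # (reload_parameter nodes do not trigger the flush).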
        if (node.op != "placeholder"
                and node.target != torch.ops.dc.reload_parameter) and order_rev[i + 1].op == "placeholder":
            for ag_group in prefetch_ag_groups:
                assert len(ag_group) > 0
                new_order_rev.append(ag_group)
                total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in ag_group])
                ag_tensor_size_sum -= total_ag_tensor_size
            if len(prefetch_ags) > 0:
                new_order_rev.append(prefetch_ags)
                ag_tensor_size_sum -= sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
            assert ag_tensor_size_sum == 0

        # print_rank_0(
        #     f"node={node} next_alloc_mem={next_alloc_mem} pending_ags={len(prefetch_ags)} ag_tensor_size_sum={ag_tensor_size_sum}"
        # )
        assert ag_tensor_size_sum >= 0
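
    # Rebuild the graph in forward order. Plain Node entries are copied
    # one-to-one; each fused group (a list of allgather_param nodes) becomes a
    # single torch.ops.dc.prefetch_params_fused call placed ahead of those
    # all-gathers.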
    new_graph = Graph()
    env = {}
    for node in reversed(new_order_rev):
        if isinstance(node, Node):
            # print(f"reconstruct {node.name} {node.target}")
            new_node = new_graph.node_copy(node, lambda n: env[n.name])
            env[node.name] = new_node
        else:
            param_nodes = [ag_node.args[0] for ag_node in node]
            param_nodes_copy = [env[param_node.name] for param_node in param_nodes]
            ds_ids = [get_ds_id(ag_node) for ag_node in node]
            new_graph.call_function(torch.ops.dc.prefetch_params_fused.default,
                                    args=(graph_id, param_nodes_copy, ds_ids))
    new_graph.lint()

    gm.graph = new_graph

    return gm
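

# A minimal sketch of how this pass might be invoked (hypothetical driver code:
# the pass-manager plumbing, the profiling_results container, and param_manager
# all come from the surrounding compile pipeline, not from this file):
#
#   gm = schedule_prefetch(gm, graph_id, graph_order, profiling_results,
#                          create_inputs_fn, mem_budget, param_manager, bwd=False)
#
# The returned GraphModule keeps the original allgather_param nodes and inserts
# torch.ops.dc.prefetch_params_fused calls ahead of them, so the gathers are
# already in flight by the time the parameters are needed.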