File size: 5,466 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from collections import defaultdict
from typing import List
import torch
from torch.fx import GraphModule
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from ..util import get_deepcompile_handle
from ..graph_param import DSGraphParamManager
# Registration name of this optimization pass.
NAME = "selective_gather"
# NOTE(review): the two globals below are not referenced anywhere in this
# file's visible code — presumably leftover cross-step state; confirm before
# removing.
max_alloc_mem = 0
last_optimize_step = 0
def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                     mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
    """Pin the most communication-expensive ZeRO-3 params as persistent.

    Runs once, on the last backward graph. Using the profiler's per-param
    all-gather device times and sizes, candidates are ranked by time saved
    per byte kept resident, then greedily marked persistent until the spare
    device memory (total capacity minus observed peak, minus a safety
    margin) is exhausted. The graph itself is returned unmodified.

    Args:
        gm: Graph module being compiled (returned unchanged).
        graph_id: Id of the graph currently being processed.
        graph_order: (graph_id, needs_bwd) pairs in forward execution order.
        profiling_results: Per-graph profiling data (memory and node timings).
        create_inputs_fn: Unused here; part of the common pass signature.
        mem_budget: Unused here; part of the common pass signature.
        param_manager: Per-graph DSGraphParamManager mapping params to ds_ids.
        bwd: True when called for a backward graph.

    Returns:
        The input ``gm``, always unmodified; this pass only flips runtime
        persistence flags via the deepcompile handle.
    """
    if not bwd:
        return gm

    # Backward graphs execute in reverse forward order, so the first entry
    # in graph_order that needs backward is the LAST backward graph to run.
    last_backward_graph_id = None
    for g_id, needs_bwd in graph_order:
        if needs_bwd:
            last_backward_graph_id = g_id
            break

    # Run only on the last backward graph, i.e. once per optimization step.
    if last_backward_graph_id is None or graph_id != last_backward_graph_id:
        return gm

    # Peak device memory observed across all profiled graphs
    # (m[3] is the peak-memory field of each profiling record).
    peak_mem = 0
    for prof_graph_id, prof in profiling_results.items():
        fwd_max_mem = max(m[3] for m in prof.fwd_mem)
        bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
        peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
        if dist.get_rank() == 0:
            print(
                f"selective_gather graph_id={prof_graph_id} max_mem={peak_mem} fwd_max_mem={fwd_max_mem} bwd_max_mem={bwd_max_mem}"
            )

    # Params the runtime already keeps persistent are excluded from selection.
    persistent_ds_ids = set()
    for g_pm in param_manager.values():
        for name, ds_param in g_pm.params.items():
            if ds_param.param.ds_persist:
                persistent_ds_ids.add(g_pm.ds_ids[name])

    ds_id_to_size = {}
    ds_id_to_time = defaultdict(float)
    ds_id_to_prof_dtime = defaultdict(float)
    ds_id_to_prof_wtime = defaultdict(float)

    for g_id, pm in param_manager.items():
        for param_name, param in pm.params.items():
            # Fallback size from the param itself; refined below from the
            # profiled all-gather node metadata when available.
            ds_id_to_size[pm.ds_ids[param_name]] = param.numel * param.dtype.itemsize

        profile = profiling_results[g_id]
        # Accumulate all-gather cost per ds_id (n.args[2] is the ds_id).
        for n in profile.fwd_graph.nodes:
            if n.target == torch.ops.dc.allgather_param.default:
                assert "tensor_size" in n.meta
                ds_id_to_size[n.args[2]] = n.meta["tensor_size"]
                assert "device_time" in n.meta
                ds_id_to_time[n.args[2]] += n.meta["device_time"]
                ds_id_to_prof_dtime[n.args[2]] = n.meta["device_time"]
                ds_id_to_prof_wtime[n.args[2]] = n.meta["wall_time"]

        if profile.bwd_graph is not None:
            for n in profile.bwd_graph.nodes:
                if n.target == torch.ops.dc.allgather_param.default:
                    assert "tensor_size" in n.meta
                    ds_id_to_size[n.args[2]] = n.meta["tensor_size"]
                    assert "device_time" in n.meta
                    ds_id_to_time[n.args[2]] += n.meta["device_time"]

    # Rank candidates by all-gather time per byte: params with the highest
    # ratio save the most communication time per byte of memory kept resident.
    ds_ids = [ds_id for ds_id in ds_id_to_size if ds_id not in persistent_ds_ids]
    ds_ids.sort(key=lambda ds_id: ds_id_to_time[ds_id] / ds_id_to_size[ds_id], reverse=True)
    sorted_ds_ids = {ds_id: ds_id_to_size[ds_id] for ds_id in ds_ids}

    accelerator = get_accelerator()
    total_mem = accelerator.total_memory()
    # Agree on the smallest device capacity across ranks so every rank makes
    # identical persistence decisions.
    vals_to_bcast = torch.tensor([total_mem], device=torch.device(accelerator.current_device()))
    dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
    total_mem = vals_to_bcast[0].item()

    MEM_MARGIN = 0.1  # keep 10% of device memory as headroom
    available_mem = total_mem * (1 - MEM_MARGIN) - peak_mem
    if dist.get_rank() == 0:
        print(
            f"selective_gather max_mem={peak_mem} total_mem={total_mem} MEM_MARGIN={MEM_MARGIN} available_mem={available_mem}"
        )

    # Map ds_id back to the live parameter object (for shape reporting only).
    ds_id_to_param = {}
    for g_pm in param_manager.values():
        for name, ds_param in g_pm.params.items():
            ds_id_to_param[g_pm.ds_ids[name]] = ds_param.param

    # Greedily pin params — best time/byte first — until the budget is spent.
    persistent_mem = 0
    nz3 = get_deepcompile_handle()
    for ds_id, size in sorted_ds_ids.items():
        if persistent_mem + size > available_mem:
            break
        persistent_mem += size
        nz3.set_persistent(ds_id)
        if dist.get_rank() == 0:
            param_obj = ds_id_to_param[ds_id]
            print(f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}")

    return gm
# def make_selective_gather(z3_optimizer, nz3):
# def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[int], profiling_results,
# mem_budget: float, param_manager, bwd: bool) -> Graph:
# return selective_gather(graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd,
# z3_optimizer, nz3)
# return selective_gather_wrapper
|