# Provenance: uploaded by jamtur01 via huggingface_hub ("Upload folder using huggingface_hub"), commit 9c6594c (verified).
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from torch.fx import GraphModule

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

from ..util import get_deepcompile_handle
from ..graph_param import DSGraphParamManager
# String name of this pass; matches the selective_gather function defined below.
NAME = "selective_gather"
# Module-level integer state, both starting at 0. Neither is referenced by
# selective_gather in this file.
# NOTE(review): possibly vestigial — confirm no other module reads them before removing.
max_alloc_mem, last_optimize_step = 0, 0
def _find_last_backward_graph_id(graph_order: List[Tuple[int, bool]]):
    """Return the id of the first graph in ``graph_order`` whose ``needs_bwd`` flag is set.

    The pass treats this graph as the "last backward graph". Presumably this is
    because backward graphs execute in reverse order of their forward graphs;
    confirm against the scheduler. Returns ``None`` when no graph needs a backward.
    """
    for g_id, needs_bwd in graph_order:
        if needs_bwd:
            return g_id
    return None


def _profiled_peak_memory(profiling_results):
    """Return the largest ``m[3]`` over every graph's ``fwd_mem`` and ``bwd_mem`` records.

    ``m[3]`` is what this pass uses as "peak memory". The record layout itself is
    not visible in this file. An empty record list contributes 0 instead of raising.
    Rank 0 prints the running values for each profiled graph.
    """
    peak_mem = 0
    for g_id, prof in profiling_results.items():
        fwd_max_mem = max((m[3] for m in prof.fwd_mem), default=0)
        bwd_max_mem = max((m[3] for m in prof.bwd_mem), default=0)
        peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
        if dist.get_rank() == 0:
            print(
                f"selective_gather graph_id={g_id} max_mem={peak_mem} fwd_max_mem={fwd_max_mem} bwd_max_mem={bwd_max_mem}"
            )
    return peak_mem


def _persistent_ds_ids(param_manager):
    """Return the set of ds_ids whose underlying parameter tensor has ``ds_persist`` set."""
    return {
        pm.ds_ids[name]
        for _, pm in param_manager.items()
        for name, ds_param in pm.params.items()
        if ds_param.param.ds_persist
    }


def _collect_allgather_stats(param_manager, profiling_results):
    """Collect per-ds_id size and summed ``allgather_param`` device time.

    Each size is first seeded from the parameter's ``numel * dtype.itemsize``.
    It is then overwritten by the ``tensor_size`` meta recorded on any profiled
    ``allgather_param`` node, where the node's ``args[2]`` is the ds_id. Device
    times from the forward graph and, when present, the backward graph are summed
    per ds_id.

    Returns:
        ``(ds_id_to_size, ds_id_to_time)``. Insertion order of ``ds_id_to_size``
        is deterministic, and the caller's stable sort uses it as the tie-break.
    """
    ds_id_to_size = {}
    ds_id_to_time = defaultdict(float)
    for g_id, pm in param_manager.items():
        for param_name, param in pm.params.items():
            ds_id_to_size[pm.ds_ids[param_name]] = param.numel * param.dtype.itemsize

        profile = profiling_results[g_id]
        graphs = [profile.fwd_graph]
        if profile.bwd_graph is not None:
            graphs.append(profile.bwd_graph)
        for graph in graphs:
            for n in graph.nodes:
                if n.target == torch.ops.dc.allgather_param.default:
                    ds_id = n.args[2]
                    assert "tensor_size" in n.meta
                    ds_id_to_size[ds_id] = n.meta["tensor_size"]
                    assert "device_time" in n.meta
                    ds_id_to_time[ds_id] += n.meta["device_time"]
    return ds_id_to_size, ds_id_to_time


def _min_total_memory_across_ranks():
    """Return the smallest accelerator ``total_memory()`` reported by any rank.

    Uses ``dist.all_reduce`` (a collective), so every rank must reach this call.
    """
    accelerator = get_accelerator()
    total = torch.tensor([accelerator.total_memory()], device=torch.device(accelerator.current_device()))
    dist.all_reduce(total, dist.ReduceOp.MIN)
    return total[0].item()


def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
                     create_inputs_fn, mem_budget: float, param_manager: Dict[int, DSGraphParamManager],
                     bwd: bool) -> GraphModule:
    """Mark the most gather-expensive parameters as persistent, within a memory budget.

    The pass does work only when ``bwd`` is true and ``graph_id`` is the first graph
    in ``graph_order`` flagged as needing a backward pass. Otherwise ``gm`` is
    returned untouched. When it runs, it:

    1. Takes the peak of the profiled forward/backward memory records across graphs.
    2. For every ds_id not already persistent, collects its size and summed
       ``allgather_param`` device time from the profiled graphs.
    3. Ranks those ds_ids by device time per unit size, highest first.
    4. Computes the budget as
       ``min_total_memory_over_ranks * (1 - MEM_MARGIN) - peak_mem``.
    5. Walks the ranking and calls ``set_persistent`` on the deepcompile handle
       until the next ds_id would exceed the budget.

    Args:
        gm: Graph module. It is returned unchanged; the pass only acts through
            the deepcompile handle.
        graph_id: Id of the graph currently being processed.
        graph_order: ``(graph_id, needs_bwd)`` pairs.
        profiling_results: Mapping from graph id to a profiling record exposing
            ``fwd_mem``, ``bwd_mem``, ``fwd_graph`` and ``bwd_graph``.
        create_inputs_fn: Unused here; presumably kept to match the shared pass signature.
        mem_budget: Unused here; the budget is derived from accelerator memory instead.
        param_manager: Mapping from graph id to its ``DSGraphParamManager``.
        bwd: Whether this invocation is for a backward graph.

    Returns:
        ``gm``, unmodified.
    """
    if not bwd:
        return gm

    # Run only on the last backward graph.
    last_backward_graph_id = _find_last_backward_graph_id(graph_order)
    if last_backward_graph_id is None or graph_id != last_backward_graph_id:
        return gm

    peak_mem = _profiled_peak_memory(profiling_results)
    persistent_ds_ids = _persistent_ds_ids(param_manager)
    ds_id_to_size, ds_id_to_time = _collect_allgather_stats(param_manager, profiling_results)

    def time_per_size(ds_id):
        size = ds_id_to_size[ds_id]
        # A zero-size entry costs nothing to keep resident, so rank it first
        # instead of dividing by zero.
        return ds_id_to_time[ds_id] / size if size > 0 else float("inf")

    # Candidates are everything not already persistent, most gather time per unit
    # size first. The sort is stable, so ties keep discovery order.
    ds_ids = [ds_id for ds_id in ds_id_to_size if ds_id not in persistent_ds_ids]
    ds_ids.sort(key=time_per_size, reverse=True)

    total_mem = _min_total_memory_across_ranks()
    MEM_MARGIN = 0.1  # fraction of total memory held back from the budget
    available_mem = total_mem * (1 - MEM_MARGIN) - peak_mem
    if dist.get_rank() == 0:
        print(
            f"selective_gather max_mem={peak_mem} total_mem={total_mem} MEM_MARGIN={MEM_MARGIN} available_mem={available_mem}"
        )

    ds_id_to_param = {
        g_pm.ds_ids[name]: ds_param.param
        for _, g_pm in param_manager.items()
        for name, ds_param in g_pm.params.items()
    }

    persistent_mem = 0
    nz3 = get_deepcompile_handle()
    for ds_id in ds_ids:
        size = ds_id_to_size[ds_id]
        # Greedy prefix: stop at the first candidate that does not fit.
        # NOTE(review): `continue` would pack the budget tighter with smaller,
        # lower-ranked ds_ids; `break` is kept to preserve the selection policy.
        if persistent_mem + size > available_mem:
            break
        persistent_mem += size
        param_obj = ds_id_to_param[ds_id]
        nz3.set_persistent(ds_id)
        if dist.get_rank() == 0:
            print(f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}")
    return gm
# def make_selective_gather(z3_optimizer, nz3):
# def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[int], profiling_results,
# mem_budget: float, param_manager, bwd: bool) -> Graph:
# return selective_gather(graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd,
# z3_optimizer, nz3)
# return selective_gather_wrapper