File size: 11,041 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import time
from typing import Any, Tuple, Dict
import statistics
import torch
from torch.fx import GraphModule, Interpreter
from torch.fx.node import map_aggregate
try:
from torch.utils._pytree import tree_all, tree_leaves
from torch._subclasses.fake_tensor import unset_fake_temporarily, is_fake
except ImportError:
# Unsupported torch version
pass
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from ..util import is_comm_op, is_release_node, get_deepcompile_handle
def _all_real_if_tensor(args):
return tree_all(lambda x: not torch.is_tensor(x) or not is_fake(x), args)
def _to(v, device):
if torch.is_tensor(v):
with unset_fake_temporarily():
return v.to(device)
return v
def _args_to_key(v):
def _tensor_to_key(v) -> str:
if torch.is_tensor(v):
if v.numel() == 1:
return f"{v.dtype}{v.device}{v.item()}"
else:
return f"{v.dtype}{v.device}{v.shape}"
return str(v)
return map_aggregate(v, _tensor_to_key)
def _node_size(out):
return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)])
def _get_mem_usage_out_of_torch():
adjust = 0
try:
import pynvml
pynvml.nvmlInit()
current_dev_id = get_accelerator().current_device()
handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
torch_alloc = get_accelerator().memory_allocated()
adjust = info.used - torch_alloc
except:
# pynvml not available
pass
return adjust
# https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html
class ProfilingInterpreter(Interpreter):
    """FX interpreter that runs a graph on real tensors and records per-node costs.

    For every non-placeholder/output node it stores into ``n.meta``:
      - ``device_time``: mean accelerator-event time per timed run;
      - ``wall_time``: mean host wall-clock time per timed run, in ms;
      - ``alloc_mem``: change in allocated memory across the runs, plus the
        out-of-torch offset;
      - ``max_mem``: growth of peak allocated memory, plus the same offset;
      - ``tensor_size``: byte size of the node's tensor outputs.

    Results are cached per (target, argument keys). When distributed is
    initialized, the cache decision is made collectively and the measurements
    are averaged across ranks.
    """

    def __init__(self, gm: GraphModule, iteration: int = 10, warmup: int = 5, debug_log=False):
        """
        Args:
            gm: graph module to profile.
            iteration: number of timed runs per node on a cache miss (must be > 0).
            warmup: number of untimed runs before timing on a cache miss (must be >= 0).
            debug_log: print per-node measurements (rank 0 only when distributed).
        """
        super().__init__(gm)
        self.nz3 = get_deepcompile_handle()

        assert iteration > 0
        assert warmup >= 0
        self.iteration = iteration
        self.warmup = warmup
        self.device = torch.device(get_accelerator().current_device())
        # (target, args key, kwargs key) -> (device_time, wall_time, alloc_mem, max_mem, tensor_size)
        self.cache: Dict[Tuple, Any] = {}
        self.distributed = dist.is_initialized()
        # ds_id -> alloc_mem recorded for that parameter's allgather_param node
        self.allgather_mem: Dict[int, int] = {}
        self.debug_log = debug_log
        self.mem_usage_out_of_torch = 0

    def run(self, *args) -> Any:
        """Run the graph with profiling enabled.

        Args:
            args: inputs to the graph. Tensors in the inputs must be real tensors,
                not fake tensors. args can contain ds parameters.

        Returns:
            The output of the graph (real tensors), or ``None`` if profiling failed.
            Errors are printed rather than raised.
        """
        # Initialized up front: otherwise a caught exception left this unbound and
        # the final ``return`` raised UnboundLocalError, masking the real error.
        return_val = None
        try:
            assert _all_real_if_tensor(args), "Inputs must be real tensors"
            self.nz3.enable_profiling(True)

            with unset_fake_temporarily():
                # Presumably fork_rng restores RNG state on exit, so the repeated
                # profiling runs do not perturb the caller's RNG stream.
                with get_accelerator().random().fork_rng(devices=[self.device]):
                    self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()
                    return_val = super().run(*args)
        except Exception as e:
            msg = getattr(e, "msg", str(e))
            print(f"Profiling error {msg}")
        finally:
            self.nz3.clear_all_gathered_params()
            self.nz3.enable_profiling(False)
        return return_val

    def run_node(self, n: torch.fx.Node) -> Any:
        """Execute ``n`` on real tensors and record its cost in ``n.meta``.

        Placeholder and output nodes are run once and get zero cost.

        Other nodes are warmed up and then timed ``self.iteration`` times. They run
        exactly once instead when either:
          - the same call is already cached on every rank, or
          - the node is a release node.
        """
        if n.op in {"placeholder", "output"}:
            n.meta["device_time"] = 0.0
            n.meta["wall_time"] = 0.0
            n.meta["alloc_mem"] = 0
            # "max_mem" is the key every other node uses; "max_memory" is kept for
            # any reader of the key previously written here.
            n.meta["max_mem"] = 0
            n.meta["max_memory"] = 0
            # NOTE(review): _node_size receives the Node itself, not its value, so
            # unless Node is a registered pytree type this evaluates to 0. Confirm
            # whether the size of the node's value was intended.
            n.meta["tensor_size"] = _node_size(n)
            return super().run_node(n)

        args, kwargs = self.fetch_args_kwargs_from_env(n)
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)

        def rebuild_param_if_necessary(v):
            # All-gather DeepSpeed parameters (objects carrying ds_id) before use.
            if hasattr(v, "ds_id"):
                v.all_gather(param_list=[v])
            return v

        args = map_aggregate(args, rebuild_param_if_necessary)
        args = map_aggregate(args, lambda x: _to(x, self.device))
        kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))

        cache_key = (n.target, _args_to_key(args), _args_to_key(kwargs))
        cache_hit = cache_key in self.cache

        # A hit only if every rank has it cached (the sum of miss flags is 0).
        # Presumably this keeps all ranks running the same number of iterations,
        # which collective ops require.
        cache_hit_flag = torch.tensor([0 if cache_hit else 1], device=self.device, dtype=torch.int)
        if self.distributed:
            dist.all_reduce(cache_hit_flag, dist.ReduceOp.SUM)
        cache_hit = cache_hit_flag.item() == 0

        if cache_hit:
            device_time, wall_time, alloc_mem, max_mem, tensor_size = self.cache[cache_key]
            n.meta["device_time"] = device_time
            n.meta["wall_time"] = wall_time
            n.meta["alloc_mem"] = alloc_mem
            n.meta["max_mem"] = max_mem
            n.meta["tensor_size"] = tensor_size

        is_release_op = is_release_node(n)
        run_only_once = cache_hit or is_release_op
        iteration = 1 if run_only_once else self.iteration
        accelerator = get_accelerator()
        start_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]
        end_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]

        accelerator.reset_peak_memory_stats()
        alloc_mem_start = accelerator.memory_allocated()
        max_mem_start = accelerator.max_memory_allocated()

        if not run_only_once:
            for _ in range(self.warmup):
                out = getattr(self, n.op)(n.target, args, kwargs)

        if is_comm_op(n):
            assert self.distributed, f"Distributed environment is not initialized but comm operator {n.name} {n.target} is used."
            dist.barrier()

        start = time.time()
        for start_event, end_event in zip(start_events, end_events):
            start_event.record()
            out = getattr(self, n.op)(n.target, args, kwargs)
            end_event.record()
        accelerator.synchronize()
        walltime_sum = time.time() - start

        if is_comm_op(n):
            dist.barrier()

        alloc_mem = accelerator.memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch
        max_memory = accelerator.max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch
        tensor_size = _node_size(out)

        def partition_param_if_necessary(v):
            # Re-partition DeepSpeed parameters gathered above unless marked persistent.
            if hasattr(v, "ds_id") and not v.ds_persist:
                v.partition(param_list=[v], has_been_updated=False)
            return v

        args = map_aggregate(args, partition_param_if_necessary)

        if not cache_hit:
            device_time = statistics.mean([s.elapsed_time(e) for s, e in zip(start_events, end_events)])
            wall_time = walltime_sum / iteration * 1000

            with unset_fake_temporarily():
                # NOTE(review): this tensor has torch's default float dtype (float32),
                # so byte counts above 2**24 are rounded. Consider float64 where the
                # backend supports it.
                vals_to_bcast = torch.tensor([device_time, wall_time, alloc_mem, max_memory, tensor_size],
                                             device=self.device)
                if self.distributed:
                    dist.all_reduce(vals_to_bcast, dist.ReduceOp.AVG)
                n.meta["device_time"] = vals_to_bcast[0].item()
                n.meta["wall_time"] = vals_to_bcast[1].item()
                n.meta["alloc_mem"] = int(vals_to_bcast[2].item())
                n.meta["max_mem"] = int(vals_to_bcast[3].item())
                n.meta["tensor_size"] = int(vals_to_bcast[4].item())
                self.cache[cache_key] = (n.meta["device_time"], n.meta["wall_time"], n.meta["alloc_mem"],
                                         n.meta["max_mem"], n.meta["tensor_size"])

        if is_release_op:
            # Credit back the memory recorded by the matching allgather; args[2] is
            # presumably the parameter's ds_id (it keys allgather_mem below) — confirm.
            n.meta["alloc_mem"] = -self.allgather_mem.get(args[2], 0)

        # Check debug_log first, and only query the rank when distributed is
        # initialized. Previously dist.get_rank() ran for every node, even in
        # non-distributed runs.
        if self.debug_log and (not self.distributed or dist.get_rank() == 0):
            print(
                f"{n.target} {n.meta['device_time']:.2f}ms {n.meta['wall_time']:.2f}ms alloc_mem={n.meta['alloc_mem'] / 1024 / 1024:.2f}MB max_mem={n.meta['max_mem'] / 1024 / 1024:.2f}MB tensor_size={n.meta['tensor_size']}"
            )

        if n.target == torch.ops.dc.allgather_param.default:
            # Return the parameter object itself (not the gathered tensor) and
            # remember how much memory its gather allocated.
            out = args[0]
            assert hasattr(out, "ds_id")
            if not out.ds_persist:
                self.nz3.invalidate_gathered_param(args[2])
            self.allgather_mem[out.ds_id] = n.meta["alloc_mem"]

        return out
class MemoryProfilingInterpreter(Interpreter):
    """FX interpreter that records allocated and peak memory after each node.

    Each executed node (including placeholders and the output) appends the tuple
    ``(node name, allocated bytes, delta from previous record, peak bytes)`` to
    ``mem_record``.

    Memory values include the out-of-torch offset measured in ``run``. When
    distributed is initialized, they are the maximum across ranks.
    """

    def __init__(self, gm: GraphModule, debug_log=False):
        super().__init__(gm)
        self.nz3 = get_deepcompile_handle()
        self.device = torch.device(get_accelerator().current_device())
        # List of (node name, memory, delta, max_mem) tuples; see dump().
        self.mem_record = []
        # NOTE(review): this baseline excludes mem_usage_out_of_torch, while every
        # later reading includes it. The first recorded delta therefore also
        # contains that offset — confirm whether this is intended.
        self.last_alloc = get_accelerator().memory_allocated()
        self.node_counter = 0
        self.node_num = len(gm.graph.nodes)
        self.debug_log = debug_log
        # Overwritten by run(); initialized so run_node() never hits a missing attribute.
        self.mem_usage_out_of_torch = 0

    def run(self, *args) -> Any:
        """Run the graph on real-tensor ``args`` while recording memory per node.

        Returns:
            The graph output, or ``None`` if an error occurred. Errors are printed,
            not raised.
        """
        # Initialized up front: otherwise a caught exception left this unbound and
        # the final ``return`` raised UnboundLocalError.
        return_val = None
        try:
            assert _all_real_if_tensor(args), "Inputs must be real tensors"
            self.nz3.enable_profiling(True)
            self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()

            with unset_fake_temporarily():
                with get_accelerator().random().fork_rng(devices=[self.device]):
                    return_val = super().run(*args)
        except Exception as e:
            print(f"MemoryProfiling error {e}")
        finally:
            self.nz3.enable_profiling(False)
        return return_val

    def run_node(self, n: torch.fx.Node) -> Any:
        """Execute ``n`` and append its memory statistics to ``mem_record``."""
        get_accelerator().reset_peak_memory_stats()

        if n.op in {"placeholder", "output"}:
            ret = super().run_node(n)
        else:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            args = map_aggregate(args, lambda x: _to(x, self.device))
            kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))
            ret = getattr(self, n.op)(n.target, args, kwargs)
            # Drop local references to the inputs before memory is read, presumably
            # so that copies made by _to can be freed first.
            del args, kwargs

        current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch
        max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch

        # Take the worst case across ranks. Skipped when distributed is not
        # initialized, where the collective cannot run.
        if dist.is_initialized():
            vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device)
            dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX)
            current_alloc = vals_to_bcast[0].item()
            max_alloc = vals_to_bcast[1].item()

        delta = current_alloc - self.last_alloc
        self.mem_record.append((n.name, current_alloc, delta, max_alloc))
        self.node_counter += 1

        if self.debug_log and (not dist.is_initialized() or dist.get_rank() == 0):
            print(
                f"Mem prof Node {self.node_counter}/{self.node_num} {n.name} memory {current_alloc / 1024 / 1024:.2f}MB delta {delta / 1024 / 1024:.2f}MB"
            )

        self.last_alloc = current_alloc
        return ret

    def dump(self, path):
        """Write ``mem_record`` to ``path`` as CSV.

        The columns are node, memory, delta and max_mem; no index column is written.
        """
        import pandas as pd
        df = pd.DataFrame(self.mem_record, columns=["node", "memory", "delta", "max_mem"])
        df.to_csv(path, index=False)
|