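"""Graph profilers for DeepCompile.

``ProfilingInterpreter`` runs an FX ``GraphModule`` with real inputs and stores
per-node device time, wall time, and memory statistics in each node's ``meta``
dict. ``MemoryProfilingInterpreter`` records a per-node memory timeline that
can be dumped to CSV.
"""
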
import time
import statistics
from typing import Any, Tuple, Dict

import torch
from torch.fx import GraphModule, Interpreter
from torch.fx.node import map_aggregate

try:
    from torch.utils._pytree import tree_all, tree_leaves
    from torch._subclasses.fake_tensor import unset_fake_temporarily, is_fake
except ImportError:
    pass

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from ..util import is_comm_op, is_release_node, get_deepcompile_handle


def _all_real_if_tensor(args):
    return tree_all(lambda x: not torch.is_tensor(x) or not is_fake(x), args)


def _to(v, device):
    if torch.is_tensor(v):
        with unset_fake_temporarily():
            return v.to(device)
    return v


def _args_to_key(v):

    def _tensor_to_key(v) -> str:
        if torch.is_tensor(v):
            if v.numel() == 1:
                return f"{v.dtype}{v.device}{v.item()}"
            else:
                return f"{v.dtype}{v.device}{v.shape}"
        return str(v)

    return map_aggregate(v, _tensor_to_key)


def _node_size(out):
    return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)])


def _get_mem_usage_out_of_torch():
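    """Estimate device memory allocated outside the torch caching allocator.

    Uses pynvml to compare the device's total used memory with torch's
    allocated memory; returns 0 if pynvml is unavailable or the query fails.
    """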
    adjust = 0
    try:
        import pynvml
        pynvml.nvmlInit()

        current_dev_id = get_accelerator().current_device()
        handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)

        torch_alloc = get_accelerator().memory_allocated()
        adjust = info.used - torch_alloc
    except Exception:
        # pynvml may be unavailable or NVML queries may fail on this device.
        pass

    return adjust


class ProfilingInterpreter(Interpreter):
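    """FX interpreter that profiles each node of a graph.

    Unless a cached measurement is available or the node is a release op, each
    node is run ``warmup`` times and then timed over ``iteration`` runs. Device
    time, wall time, allocated memory, peak memory, and output tensor size are
    stored in ``node.meta`` (averaged across ranks when distributed) and cached
    by operator and argument signature.
    """
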
    def __init__(self, gm: GraphModule, iteration: int = 10, warmup: int = 5, debug_log=False):
        super().__init__(gm)

        self.nz3 = get_deepcompile_handle()

        assert iteration > 0
        assert warmup >= 0
        self.iteration = iteration
        self.warmup = warmup
        self.device = torch.device(get_accelerator().current_device())
        self.cache: Dict[Tuple, Any] = {}
        self.distributed = dist.is_initialized()
        self.allgather_mem: Dict[int, int] = {}
        self.debug_log = debug_log
        self.mem_usage_out_of_torch = 0

    def run(self, *args) -> Any:
        """Run the graph with profiling enabled.

        args: Inputs to the graph. Tensors in the inputs must be real tensors, not fake tensors. args can contain DeepSpeed parameters.
        returns: The output of the graph. Tensors in the output are real tensors.
        """
        return_val = None
        try:
            assert _all_real_if_tensor(args), "Inputs must be real tensors"
            self.nz3.enable_profiling(True)

            with unset_fake_temporarily():
                with get_accelerator().random().fork_rng(devices=[self.device]):
                    self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()
                    return_val = super().run(*args)
        except Exception as e:
            msg = e.msg if hasattr(e, "msg") else str(e)
            print(f"Profiling error: {msg}")
        finally:
            self.nz3.clear_all_gathered_params()
            self.nz3.enable_profiling(False)

        return return_val

    def run_node(self, n: torch.fx.Node) -> Any:
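        """Execute one node and record its profiling results in ``n.meta``."""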
        if n.op in {"placeholder", "output"}:
            n.meta["device_time"] = 0.0
            n.meta["wall_time"] = 0.0
            n.meta["alloc_mem"] = 0
            n.meta["max_mem"] = 0
            n.meta["tensor_size"] = _node_size(n)
            return super().run_node(n)

        args, kwargs = self.fetch_args_kwargs_from_env(n)
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)

        def rebuild_param_if_necessary(v):
            # Gather partitioned DeepSpeed parameters so the op runs on full tensors.
            if hasattr(v, "ds_id"):
                v.all_gather(param_list=[v])
            return v

        args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x))

        args = map_aggregate(args, lambda x: _to(x, self.device))
        kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))

        cache_key = (n.target, _args_to_key(args), _args_to_key(kwargs))
        cache_hit = cache_key in self.cache

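        # All ranks must agree on whether to use the cached measurement so that
        # every rank runs the same number of iterations and collective ops stay
        # matched; treat it as a hit only if every rank hit the cache.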
        cache_hit_flag = torch.tensor([0 if cache_hit else 1], device=self.device, dtype=torch.int)
        if self.distributed:
            dist.all_reduce(cache_hit_flag, dist.ReduceOp.SUM)
            cache_hit = cache_hit_flag.item() == 0

        if cache_hit:
            device_time, wall_time, alloc_mem, max_mem, tensor_size = self.cache[cache_key]
            n.meta["device_time"] = device_time
            n.meta["wall_time"] = wall_time
            n.meta["alloc_mem"] = alloc_mem
            n.meta["max_mem"] = max_mem
            n.meta["tensor_size"] = tensor_size

        is_release_op = is_release_node(n)
        run_only_once = cache_hit or is_release_op
        iteration = 1 if run_only_once else self.iteration
        accelerator = get_accelerator()
        start_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]
        end_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]

        get_accelerator().reset_peak_memory_stats()
        alloc_mem_start = get_accelerator().memory_allocated()
        max_mem_start = get_accelerator().max_memory_allocated()

        if not run_only_once:
            for i in range(self.warmup):
                out = getattr(self, n.op)(n.target, args, kwargs)

        if is_comm_op(n):
            assert self.distributed, f"Distributed environment is not initialized but comm operator {n.name} {n.target} is used."
            dist.barrier()

        # Time the node with accelerator events for device time and a host-side
        # clock for wall time.
        start = time.time()
        for i in range(iteration):
            start_events[i].record()
            out = getattr(self, n.op)(n.target, args, kwargs)
            end_events[i].record()
        accelerator.synchronize()
        walltime_sum = time.time() - start

        if is_comm_op(n):
            dist.barrier()

        alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch
        max_memory = get_accelerator().max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch
        tensor_size = _node_size(out)

        def partition_param_if_necessary(v):
            # Re-partition gathered parameters that are not marked persistent.
            if hasattr(v, "ds_id") and not v.ds_persist:
                v.partition(param_list=[v], has_been_updated=False)
            return v

        args = map_aggregate(args, lambda x: partition_param_if_necessary(x))

        if not cache_hit:
            device_time = statistics.mean([s.elapsed_time(e) for s, e in zip(start_events, end_events)])
            wall_time = walltime_sum / iteration * 1000  # average wall-clock ms per iteration

            with unset_fake_temporarily():
                vals_to_bcast = torch.tensor([device_time, wall_time, alloc_mem, max_memory, tensor_size],
                                             device=self.device)
                if self.distributed:
                    dist.all_reduce(vals_to_bcast, dist.ReduceOp.AVG)
                n.meta["device_time"] = vals_to_bcast[0].item()
                n.meta["wall_time"] = vals_to_bcast[1].item()
                n.meta["alloc_mem"] = int(vals_to_bcast[2].item())
                n.meta["max_mem"] = int(vals_to_bcast[3].item())
                n.meta["tensor_size"] = int(vals_to_bcast[4].item())
                self.cache[cache_key] = (n.meta["device_time"], n.meta["wall_time"], n.meta["alloc_mem"],
                                         n.meta["max_mem"], n.meta["tensor_size"])

        if is_release_op:
            n.meta["alloc_mem"] = -self.allgather_mem.get(args[2], 0)

        if dist.get_rank() == 0 and self.debug_log:
            print(
                f"{n.target} {n.meta['device_time']:.2f}ms {n.meta['wall_time']:.2f}ms alloc_mem={n.meta['alloc_mem'] / 1024 / 1024:.2f}MB max_mem={n.meta['max_mem'] / 1024 / 1024:.2f}MB tensor_size={n.meta['tensor_size']}"
            )

        if n.target == torch.ops.dc.allgather_param.default:
            # Return the parameter itself (args[0]) and record how much memory
            # its allgather allocated so the matching release node can report
            # that amount as freed.
            out = args[0]
            assert hasattr(out, "ds_id")
            if not out.ds_persist:
                self.nz3.invalidate_gathered_param(args[2])
            self.allgather_mem[out.ds_id] = n.meta["alloc_mem"]

        return out


class MemoryProfilingInterpreter(Interpreter):
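    """FX interpreter that records a per-node memory timeline.

    After each node executes, the current and peak allocated memory (combined
    with a MAX all-reduce across ranks) are appended to ``mem_record``, which
    ``dump`` writes to CSV.
    """
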
    def __init__(self, gm: GraphModule, debug_log=False):
        super().__init__(gm)
        self.nz3 = get_deepcompile_handle()
        self.device = torch.device(get_accelerator().current_device())
        self.mem_record = []
        self.last_alloc = get_accelerator().memory_allocated()
        self.mem_usage_out_of_torch = 0

        self.node_counter = 0
        self.node_num = len(gm.graph.nodes)
        self.debug_log = debug_log

    def run(self, *args) -> Any:
        return_val = None
        try:
            assert _all_real_if_tensor(args), "Inputs must be real tensors"
            self.nz3.enable_profiling(True)
            self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()

            with unset_fake_temporarily():
                with get_accelerator().random().fork_rng(devices=[self.device]):
                    return_val = super().run(*args)
        except Exception as e:
            print(f"MemoryProfiling error: {e}")
        finally:
            self.nz3.enable_profiling(False)

        return return_val

    def run_node(self, n: torch.fx.Node) -> Any:
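        """Execute one node, then record current and peak allocated memory."""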
        get_accelerator().reset_peak_memory_stats()

        if n.op in {"placeholder", "output"}:
            ret = super().run_node(n)
        else:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            args = map_aggregate(args, lambda x: _to(x, self.device))
            kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))
            ret = getattr(self, n.op)(n.target, args, kwargs)

            del args, kwargs

        # Report the worst case across ranks so the timeline reflects the most
        # memory-constrained rank.
        current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch
        max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch
        vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device)
        dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX)
        current_alloc = vals_to_bcast[0].item()
        max_alloc = vals_to_bcast[1].item()

        self.mem_record.append((n.name, current_alloc, current_alloc - self.last_alloc, max_alloc))

        self.node_counter += 1
        if self.debug_log and dist.get_rank() == 0:
            print(
                f"Mem prof Node {self.node_counter}/{self.node_num} {n.name} memory {current_alloc / 1024 / 1024:.2f}MB delta {(current_alloc - self.last_alloc) / 1024 / 1024:.2f}MB"
            )

        self.last_alloc = current_alloc

        return ret

    def dump(self, path):
        import pandas as pd
        df = pd.DataFrame(self.mem_record, columns=["node", "memory", "delta", "max_mem"])
        df.to_csv(path, index=False)
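

# Example usage (sketch): how these interpreters are typically driven. The
# names `gm` and `real_inputs` are placeholders for a captured GraphModule and
# its real (non-fake) inputs; this is illustrative only and not part of the
# module API.
#
#   prof = ProfilingInterpreter(gm, iteration=5, warmup=2)
#   prof.run(*real_inputs)
#   for node in gm.graph.nodes:
#       print(node.name, node.meta.get("device_time"), node.meta.get("alloc_mem"))
#
#   mem_prof = MemoryProfilingInterpreter(gm)
#   mem_prof.run(*real_inputs)
#   mem_prof.dump("mem_profile.csv")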