# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import time
from typing import Any, Tuple, Dict
import statistics

import torch
from torch.fx import GraphModule, Interpreter
from torch.fx.node import map_aggregate

try:
    from torch.utils._pytree import tree_all, tree_leaves
    from torch._subclasses.fake_tensor import unset_fake_temporarily, is_fake
except ImportError:
    # Unsupported torch version
    pass

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

from ..util import is_comm_op, is_release_node, get_deepcompile_handle


def _all_real_if_tensor(args):
    return tree_all(lambda x: not torch.is_tensor(x) or not is_fake(x), args)


def _to(v, device):
    if torch.is_tensor(v):
        with unset_fake_temporarily():
            return v.to(device)
    return v


def _args_to_key(v):

    def _tensor_to_key(v) -> str:
        if torch.is_tensor(v):
            if v.numel() == 1:
                return f"{v.dtype}{v.device}{v.item()}"
            else:
                return f"{v.dtype}{v.device}{v.shape}"
        return str(v)

    return map_aggregate(v, _tensor_to_key)


def _node_size(out):
    return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)])


def _get_mem_usage_out_of_torch():
    adjust = 0
    try:
        import pynvml
        pynvml.nvmlInit()
        current_dev_id = get_accelerator().current_device()
        handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        torch_alloc = get_accelerator().memory_allocated()
        adjust = info.used - torch_alloc
    except Exception:
        # pynvml not available
        pass
    return adjust


# https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html
class ProfilingInterpreter(Interpreter):

    def __init__(self, gm: GraphModule, iteration: int = 10, warmup: int = 5, debug_log=False):
        super().__init__(gm)

        self.nz3 = get_deepcompile_handle()

        assert iteration > 0
        assert warmup >= 0
        self.iteration = iteration
        self.warmup = warmup

        self.device = torch.device(get_accelerator().current_device())
        self.cache: Dict[Tuple, Any] = {}
        self.distributed = dist.is_initialized()
        self.allgather_mem: Dict[int, int] = {}
        self.debug_log = debug_log
        self.mem_usage_out_of_torch = 0

    def run(self, *args) -> Any:
        """Run the graph with profiling enabled.

        args: inputs to the graph. Tensors in the inputs must be real tensors, not fake tensors. args can contain ds parameters.
        returns: The output of the graph. Tensors in the output are real tensors.
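
        Example (illustrative sketch; `gm` and `real_args` are placeholders for a traced
        GraphModule and real sample inputs on the current device):

            prof = ProfilingInterpreter(gm, iteration=10, warmup=5)
            prof.run(*real_args)
            for node in gm.graph.nodes:
                print(node.name, node.meta.get("device_time"), node.meta.get("alloc_mem"))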
""" try: assert _all_real_if_tensor(args), "Inputs must be real tensors" self.nz3.enable_profiling(True) with unset_fake_temporarily(): with get_accelerator().random().fork_rng(devices=[self.device]): self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch() return_val = super().run(*args) except Exception as e: msg = e.msg if "msg" in dir(e) else str(e) print(f"Profiling error {msg}") finally: self.nz3.clear_all_gathered_params() self.nz3.enable_profiling(False) return return_val def run_node(self, n: torch.fx.Node) -> Any: if n.op in {"placeholder", "output"}: n.meta["device_time"] = 0.0 n.meta["wall_time"] = 0.0 n.meta["alloc_mem"] = 0 n.meta["max_memory"] = 0 n.meta["tensor_size"] = _node_size(n) return super().run_node(n) args, kwargs = self.fetch_args_kwargs_from_env(n) assert isinstance(args, tuple) assert isinstance(kwargs, dict) def rebuild_param_if_necessary(v): if hasattr(v, "ds_id"): v.all_gather(param_list=[v]) return v args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x)) args = map_aggregate(args, lambda x: _to(x, self.device)) kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device)) cache_key = (n.target, _args_to_key(args), _args_to_key(kwargs)) cache_hit = cache_key in self.cache cache_hit_flag = torch.tensor([0 if cache_hit else 1], device=self.device, dtype=torch.int) if self.distributed: dist.all_reduce(cache_hit_flag, dist.ReduceOp.SUM) cache_hit = cache_hit_flag.item() == 0 if cache_hit: device_time, wall_time, alloc_mem, max_mem, tensor_size = self.cache[cache_key] n.meta["device_time"] = device_time n.meta["wall_time"] = wall_time n.meta["alloc_mem"] = alloc_mem n.meta["max_mem"] = max_mem n.meta["tensor_size"] = tensor_size is_release_op = is_release_node(n) run_only_once = cache_hit or is_release_op iteration = 1 if run_only_once else self.iteration accelerator = get_accelerator() start_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)] end_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)] get_accelerator().reset_peak_memory_stats() alloc_mem_start = get_accelerator().memory_allocated() max_mem_start = get_accelerator().max_memory_allocated() if not run_only_once: for i in range(self.warmup): out = getattr(self, n.op)(n.target, args, kwargs) if is_comm_op(n): assert self.distributed, f"Distributed environment is not initialized but comm operator {n.name} {n.target} is used." 
            dist.barrier()

        start = time.time()
        for i in range(iteration):
            start_events[i].record()
            out = getattr(self, n.op)(n.target, args, kwargs)
            end_events[i].record()
        accelerator.synchronize()
        walltime_sum = time.time() - start

        if is_comm_op(n):
            dist.barrier()

        alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch
        max_memory = get_accelerator().max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch
        tensor_size = _node_size(out)

        def partition_param_if_necessary(v):
            if hasattr(v, "ds_id") and not v.ds_persist:
                v.partition(param_list=[v], has_been_updated=False)
            return v

        args = map_aggregate(args, lambda x: partition_param_if_necessary(x))

        if not cache_hit:
            device_time = statistics.mean([s.elapsed_time(e) for s, e in zip(start_events, end_events)])
            wall_time = walltime_sum / iteration * 1000

            with unset_fake_temporarily():
                vals_to_bcast = torch.tensor([device_time, wall_time, alloc_mem, max_memory, tensor_size],
                                             device=self.device)
                if self.distributed:
                    dist.all_reduce(vals_to_bcast, dist.ReduceOp.AVG)
                n.meta["device_time"] = vals_to_bcast[0].item()
                n.meta["wall_time"] = vals_to_bcast[1].item()
                n.meta["alloc_mem"] = int(vals_to_bcast[2].item())
                n.meta["max_mem"] = int(vals_to_bcast[3].item())
                n.meta["tensor_size"] = int(vals_to_bcast[4].item())

                self.cache[cache_key] = (n.meta["device_time"], n.meta["wall_time"], n.meta["alloc_mem"],
                                         n.meta["max_mem"], n.meta["tensor_size"])

        if is_release_op:
            n.meta["alloc_mem"] = -self.allgather_mem.get(args[2], 0)

        if dist.get_rank() == 0 and self.debug_log:
            print(
                f"{n.target} {n.meta['device_time']:.2f}ms {n.meta['wall_time']:.2f}ms alloc_mem={n.meta['alloc_mem'] / 1024 / 1024:.2f}MB max_mem={n.meta['max_mem'] / 1024 / 1024:.2f}MB tensor_size={n.meta['tensor_size']}"
            )

        if n.target == torch.ops.dc.allgather_param.default:
            out = args[0]
            assert hasattr(out, "ds_id")
            if not out.ds_persist:
                self.nz3.invalidate_gathered_param(args[2])
            self.allgather_mem[out.ds_id] = n.meta["alloc_mem"]

        return out


class MemoryProfilingInterpreter(Interpreter):

    def __init__(self, gm: GraphModule, debug_log=False):
        super().__init__(gm)
        self.nz3 = get_deepcompile_handle()
        self.device = torch.device(get_accelerator().current_device())
        self.mem_record = []
        self.last_alloc = get_accelerator().memory_allocated()

        self.node_counter = 0
        self.node_num = len(gm.graph.nodes)

        self.debug_log = debug_log

    def run(self, *args) -> Any:
        return_val = None
        try:
            assert _all_real_if_tensor(args), "Inputs must be real tensors"
            self.nz3.enable_profiling(True)
            self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()
            with unset_fake_temporarily():
                with get_accelerator().random().fork_rng(devices=[self.device]):
                    return_val = super().run(*args)
        except Exception as e:
            print(f"MemoryProfiling error {e}")
        finally:
            self.nz3.enable_profiling(False)

        return return_val

    def run_node(self, n: torch.fx.Node) -> Any:
        get_accelerator().reset_peak_memory_stats()

        if n.op in {"placeholder", "output"}:
            ret = super().run_node(n)
        else:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            args = map_aggregate(args, lambda x: _to(x, self.device))
            kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))
            ret = getattr(self, n.op)(n.target, args, kwargs)
            del args, kwargs

        current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch
        max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch

        vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device)
        dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX)
        current_alloc = vals_to_bcast[0].item()
        max_alloc = vals_to_bcast[1].item()

        self.mem_record.append((n.name, current_alloc, current_alloc - self.last_alloc, max_alloc))
        self.node_counter += 1
        if self.debug_log and dist.get_rank() == 0:
            print(
                f"Mem prof Node {self.node_counter}/{self.node_num} {n.name} memory {current_alloc / 1024 / 1024:.2f}MB delta {(current_alloc - self.last_alloc) / 1024 / 1024:.2f}MB"
            )

        self.last_alloc = current_alloc

        return ret

    def dump(self, path):
        import pandas as pd
        df = pd.DataFrame(self.mem_record, columns=["node", "memory", "delta", "max_mem"])
        df.to_csv(path, index=False)
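
# Illustrative usage sketch (kept as a comment so it does not run on import; `gm` and
# `real_args` are placeholders for a traced GraphModule and real sample inputs):
#
#   mem_prof = MemoryProfilingInterpreter(gm, debug_log=True)
#   mem_prof.run(*real_args)
#   mem_prof.dump("mem_profile.csv")  # CSV columns: node, memory, delta, max_mem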