# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import functools
import operator
from typing import List, Tuple, Dict
from collections import defaultdict

import torch
from torch.fx import Node, Graph
from torch.fx.node import map_aggregate, Argument, map_arg

try:
    from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
    # Unsupported torch version
    pass

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.utils.torch import required_torch_version
from deepspeed.ops.op_builder.dc import DeepCompileBuilder


def is_deepcompile_supported() -> bool:
    """Return True when DeepCompile can be used: torch 2.6-2.7 on a CUDA accelerator."""
    return required_torch_version(min_version=2.6, max_version=2.7) and get_accelerator().device_name() == "cuda"


# Lazily initialized handle to the compiled DeepCompile op module (see get_deepcompile_handle).
dc_handle = None

# Ops that operate on symbolic sizes / produce scalar results rather than real
# tensor data; used by exclude_from_act_offload to skip activation offloading.
# NOTE: only defined when DeepCompile is supported.
if is_deepcompile_supported():
    sym_size_ops = {
        operator.ge,
        operator.le,
        operator.eq,
        operator.ne,
        operator.gt,
        operator.lt,
        torch.ops.aten.sym_size.int,
        operator.getitem,
    }


def get_deepcompile_handle():
    """Build/load the DeepCompile native extension once and return the cached handle."""
    global dc_handle
    if dc_handle is None:
        dc_handle = DeepCompileBuilder().load()
    return dc_handle


def is_backend_inductor(backend):
    """Return True when the given torch.compile backend name is "inductor"."""
    return backend == "inductor"


backward_started = False
# Callbacks invoked at the start of every backward pass (see add_pre_backward_hook).
pre_backward_hooks = []


def add_pre_backward_hook(hook):
    """Register a zero-argument callback to run before each backward pass."""
    pre_backward_hooks.append(hook)


def deepcompile_backward_prologue(is_gradient_accumulation_boundary):
    """Run registered pre-backward hooks, then notify the native handle that
    backward is starting (passing whether this step is an accumulation boundary)."""
    for hook in pre_backward_hooks:
        hook()

    dc = get_deepcompile_handle()
    dc.start_backward(is_gradient_accumulation_boundary)


def log_rank0(msg: str, enable: bool = False):
    """Print msg only on global rank 0, and only when enable is True."""
    if dist.get_rank() == 0 and enable:
        print(msg)


def get_no_copy_ops():
    """Return the set of ops treated as producing aliases/views of their input
    (no new storage) for the liveness analyses below."""
    # Need to compile custom ops
    get_deepcompile_handle()
    return {
        torch.ops.aten.t.default, torch.ops.aten.view.default, torch.ops.aten.detach.default,
        torch.ops.aten.permute.default, torch.ops.dc.wait_allgather.default
    }


def get_input_nodes(graph: Graph) -> List[Node]:
    """Return the graph's placeholder (input) nodes in graph order."""
    return [n for n in graph.nodes if n.op == "placeholder"]


def get_param_nodes(graph: Graph, index_to_ds_ids: List[Tuple[int, int]]) -> List[Node]:
    """Select the input nodes that correspond to DeepSpeed parameters.

    index_to_ds_ids holds 3-tuples whose first element is an index into the
    graph's placeholder list (remaining elements are ignored here).
    """
    all_input_nodes = get_input_nodes(graph)
    return [all_input_nodes[i] for i, _, _ in index_to_ds_ids]


def is_comm_op(node: Node) -> bool:
    """Return True when the node is tagged as a communication op via node.meta["comm"]."""
    return "comm" in node.meta and node.meta["comm"]


def exclude_from_act_offload(node: Node) -> bool:
    """Return True for symbolic-size/scalar ops that should never be offloaded."""
    return node.target in sym_size_ops


def dtype_to_elem_size(dtype: torch.dtype) -> int:
    """Return the element size in bytes for a float dtype.

    Raises ValueError for dtypes other than float32/float64/float16.
    """
    if dtype == torch.float32:
        elem_size = 4
    elif dtype == torch.float64:
        elem_size = 8
    elif dtype == torch.float16:
        elem_size = 2
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return elem_size


def tensor_meta_size(tensor_meta) -> int:
    """Return the byte size of a tensor described by an FX TensorMetadata
    (numel * element size). Scalars (rank 0) count as one element.

    Raises ValueError for unsupported dtypes.
    """
    numel = 1 if len(tensor_meta.shape) == 0 else functools.reduce(operator.mul, tensor_meta.shape)

    dtype = tensor_meta.dtype
    if dtype == torch.float32:
        elem_size = 4
    elif dtype == torch.float64 or dtype == torch.int64:
        elem_size = 8
    elif dtype == torch.float16 or dtype == torch.bfloat16:
        elem_size = 2
    elif dtype == torch.bool:
        elem_size = 1
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return numel * elem_size


class NodeValueOffloadHelper:
    """Store named values, optionally moving contained tensors to CPU on save
    and restoring them to their original device on load."""

    def __init__(self, device):
        # NOTE(review): `device` is stored but not read by any method in this
        # class as visible here.
        self.device = device
        # name -> saved (possibly CPU-offloaded) value
        self.env_values: Dict[str, Argument] = {}
        # offloaded CPU tensor -> device it originally lived on
        self.original_device: Dict[torch.Tensor, torch.device] = {}

    def _to_cpu(self, v):
        # Move a real tensor to CPU, remembering its original device so
        # _from_cpu can restore it. Non-tensors pass through unchanged.
        if torch.is_tensor(v):
            with unset_fake_temporarily():
                device = v.device
                offloaded = v.to('cpu').detach()
                self.original_device[offloaded] = device
                return offloaded
        return v

    def _from_cpu(self, v):
        # Restore a previously offloaded tensor to its recorded device.
        if torch.is_tensor(v) and v in self.original_device:
            return v.to(self.original_device[v])
        return v

    def save(self, name: str, v: Argument, offload) -> None:
        """Save value v under name; offload contained tensors to CPU when offload is truthy."""
        self.env_values[name] = map_aggregate(v, lambda x: self._to_cpu(x) if offload else x)

    def load(self, name: str) -> Argument:
        """Return the saved value with offloaded tensors moved back to their original devices."""
        return map_aggregate(self.env_values[name], lambda x: self._from_cpu(x))

    def get_offloaded_value(self, name: str) -> Argument:
        """Return the saved value as stored (without restoring devices)."""
        return self.env_values[name]

    def has_value(self, name: str) -> bool:
        """Return True when a value is saved under name."""
        return name in self.env_values

    def clear(self) -> None:
        """Drop all saved values and device records."""
        self.env_values.clear()
        self.original_device.clear()


def materialize_fake(v, device=None):
    """Replace any FakeTensors inside v with real tensors of matching metadata.

    Floating-point fakes become random tensors (preserving layout, grad flag and
    pin-memory status); other dtypes become zeros. When device is given it
    overrides each fake tensor's own device. Non-fake values pass through.
    """
    from torch._subclasses.fake_tensor import is_fake

    def convert(t):
        if is_fake(t):
            with unset_fake_temporarily():
                if t.is_floating_point():
                    return torch.randn(t.shape,
                                       dtype=t.dtype,
                                       device=t.device if device is None else device,
                                       layout=t.layout,
                                       requires_grad=t.requires_grad,
                                       pin_memory=t.is_pinned())
                else:
                    return torch.zeros(t.shape,
                                       dtype=t.dtype,
                                       device=t.device if device is None else device,
                                       requires_grad=t.requires_grad)
        return t

    return map_aggregate(v, lambda x: convert(x))


def get_last_uses(graph: Graph):
    """Compute last-use information for every value in the graph.

    Returns (node_to_last_use, user_to_last_uses) where node_to_last_use maps a
    node to the node at which it is last consumed, and user_to_last_uses maps a
    node to the list of values whose last use it is (i.e. values that may be
    freed after it runs). Uses by no-copy (aliasing) ops are attributed to the
    alias's own last user, since the alias keeps the storage alive.
    """
    position = {node: i for i, node in enumerate(graph.nodes)}

    node_to_last_use: Dict[Node, Node] = {}
    user_to_last_uses: Dict[Node, List[Node]] = {}
    no_copy_ops = get_no_copy_ops()

    def register_last_uses(n: Node, user: Node):
        update = False
        known_last_use = None
        # If the consumer is an aliasing op and n is already registered, check
        # whether the alias's last user is later than n's currently recorded
        # last use; if so the record must be extended.
        if user.target in no_copy_ops and n in node_to_last_use:
            last_user = node_to_last_use[user]
            last_use_position = position[last_user]
            known_last_use = node_to_last_use[n]
            known_last_use_position = position[known_last_use]
            update = last_use_position > known_last_use_position

        # Nodes are visited in reverse, so the first registration is the
        # latest use in program order; later registrations only win via `update`.
        if n not in node_to_last_use or update:
            if user.target in no_copy_ops:
                # Attribute the use to the alias's own last user.
                user = node_to_last_use[user]
            node_to_last_use[n] = user
            user_to_last_uses.setdefault(user, []).append(n)

            if known_last_use:
                # Remove the superseded record.
                user_to_last_uses[known_last_use].remove(n)

    for node in reversed(graph.nodes):
        map_arg(node.args, lambda n: register_last_uses(n, node))
        map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    return node_to_last_use, user_to_last_uses


def get_real_uses(graph: Graph):
    """Map each node to the list of nodes that actually consume its data.

    Uses through no-copy (aliasing) ops are transitively resolved to the
    alias's consumers; the graph's "output" node is not counted as a user.
    """
    node_to_uses: Dict[Node, List[Node]] = defaultdict(list)
    no_copy_ops = get_no_copy_ops()

    def register_last_uses(n: Node, user: Node):
        if user.target == "output":
            return

        if user.target in no_copy_ops:
            # Visiting in reverse guarantees the alias's own users are already known.
            users = node_to_uses[user]
            node_to_uses[n].extend(users)
        else:
            node_to_uses[n].append(user)

    for node in reversed(graph.nodes):
        map_arg(node.args, lambda n: register_last_uses(n, node))
        map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    return node_to_uses
def count_inflight_values(graph: Graph, file_path: str):
    """Profile live ("inflight") values at each node and write a CSV report.

    Walks the graph in execution order, maintaining the set of values that are
    still alive (produced but not yet past their last use, per get_last_uses),
    and records per-node sizes to ``file_path``. Every node must carry a
    ``tensor_size`` entry in ``node.meta``.
    """
    position = {node: i for i, node in enumerate(graph.nodes)}

    node_to_last_use, user_to_last_uses = get_last_uses(graph)

    max_inflight_size = 0
    inflight_values = set()

    # Output csv.
    csv_filename = file_path
    csv_data = []

    header = [
        'Node', 'tensor_size', 'inflight_size', 'inflight_size_in_output', 'args', 'users', 'node_to_last_use',
        'lifetime', 'user_to_last_uses', 'inflight_values'
    ]
    csv_data.append(header)

    from .fx import get_output_node
    output_node = get_output_node(graph)
    values_in_output = set([n for n in output_node.args[0] if isinstance(n, Node)])

    for node in graph.nodes:
        inflight_values.add(node)
        # Values whose last use is this node die once it has executed.
        if node in user_to_last_uses:
            for to_delete in user_to_last_uses[node]:
                inflight_values.remove(to_delete)

        assert "tensor_size" in node.meta, f"Node {node} does not have tensor_size"
        inflight_size = sum(n.meta["tensor_size"] for n in inflight_values)
        inflight_size_in_output = sum(n.meta["tensor_size"] for n in inflight_values if n in values_in_output)

        # Distance (in node positions) between creation and last use.
        lifetime = position[node_to_last_use[node]] - position[node] if node in node_to_last_use else 0

        row = [
            node.name, node.meta["tensor_size"], inflight_size, inflight_size_in_output,
            [a.name for a in node.args if isinstance(a, Node)],
            list(node.users.keys()), node_to_last_use[node] if node in node_to_last_use else 'NA', lifetime,
            user_to_last_uses[node] if node in user_to_last_uses else 'NA',
            list(inflight_values)
        ]
        csv_data.append(row)

        max_inflight_size = max(max_inflight_size, inflight_size)

    import csv
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(csv_data)

    print(f"Max inflight size: {max_inflight_size}")
    print(f"Data successfully written to {csv_filename}")


def get_activation_node_names(graph: Graph, param_nodes_bw: List[Node], fwd_output_names: List[str]):
    """Return names of backward-graph inputs that are forward activations.

    An activation is a placeholder whose name appears in the forward graph's
    outputs (fwd_output_names) but is not one of the parameter nodes.
    """
    input_nodes = get_input_nodes(graph)
    param_node_names = set([n.name for n in param_nodes_bw])

    activation_node_names = []
    for in_node in input_nodes:
        if in_node.name in fwd_output_names and in_node.name not in param_node_names:
            activation_node_names.append(in_node.name)

    return activation_node_names


class TensorOffloadHelper():
    """Offload a flat list of arguments to CPU and restore them later.

    Only "base" tensors — real tensors that are not views and do not carry a
    DeepSpeed ``ds_id`` attribute — are moved; views and everything else are
    kept as-is and returned unchanged by reload().
    """

    def __init__(self):
        self.devices = {}  # id(tensor) -> original device
        self.base_tensors = {}  # id(tensor) -> offloaded base tensor
        self.views = {}  # reserved for view reconstruction (currently unused)
        self.arg_list = []  # ids of all arguments, preserving input order
        self.offloaded = {}
        self.non_tensor = {}  # id(obj) -> non-tensor (or DS/view tensor) argument

    def offload(self, argument):
        """Record all items of ``argument`` and move base tensors' data to CPU."""

        def is_base_tensor(tensor):
            # Bug fix: this previously tested the enclosing loop variable `a`
            # instead of the `tensor` parameter (correct only by accident
            # because every call site passed the loop variable).
            return torch.is_tensor(tensor) and not tensor._is_view() and not hasattr(tensor, "ds_id")

        for a in argument:
            a_id = id(a)

            if is_base_tensor(a):
                # Base tensor: remember where it lived so reload() can restore it.
                self.devices[a_id] = a.device
                self.base_tensors[a_id] = a
            else:
                # Other types, views, or DeepSpeed-managed tensors: kept untouched.
                # NOTE(review): view reconstruction via as_strided was removed
                # (previously commented out); views are returned unchanged.
                self.non_tensor[a_id] = a

            self.arg_list.append(a_id)

        # Move the data after bookkeeping so device records stay accurate.
        for a in argument:
            if is_base_tensor(a):
                a.data = a.data.to("cpu")

    def reload(self, in_place):
        """Return the recorded arguments with base tensors moved back to their
        original devices. When ``in_place`` is truthy, the original tensor
        objects are updated (their .data is replaced); otherwise new tensors
        are returned.
        """
        loaded_base_tensors = {}
        for a_id in self.arg_list:
            if a_id in self.base_tensors:
                device = self.devices[a_id]

                if in_place:
                    self.base_tensors[a_id].data = self.base_tensors[a_id].to(device)
                    loaded_base_tensors[a_id] = self.base_tensors[a_id]
                else:
                    loaded_base_tensors[a_id] = self.base_tensors[a_id].to(device)

        results = []
        for a_id in self.arg_list:
            if a_id in self.base_tensors:
                results.append(loaded_base_tensors[a_id])
            elif a_id in self.non_tensor:
                results.append(self.non_tensor[a_id])

        return results


def add_mem_profile_nodes(graph: Graph, prefix: str):
    """Insert a call_function node after every node (except output) that prints
    current/max allocated accelerator memory on rank 0, tagged with ``prefix``."""

    def show_memory(label: str):
        if dist.get_rank() == 0:
            print(
                f"{prefix} {label} alloc_mem={get_accelerator().memory_allocated()} max_mem={get_accelerator().max_memory_allocated()}"
            )

    nodes = list(graph.nodes)
    for node in nodes:
        if node.op == "output":
            continue

        with graph.inserting_after(node):
            msg = f"Mem {node.name}"
            name = f"show_memory_{node.name}"
            graph.create_node('call_function', show_memory, (msg, ), {}, name=name)


def is_release_node(n: Node) -> bool:
    """Return True when the node is a DeepCompile release_param op."""
    return n.target == torch.ops.dc.release_param.default


def get_index_by_graph_id(graph_order, target_graph_id):
    """Return the position of target_graph_id in graph_order (a sequence whose
    items start with a graph id), or -1 when it is not present."""
    for index, (graph_id, _) in enumerate(graph_order):
        if graph_id == target_graph_id:
            return index
    return -1


def pad_tensors(specs: List[Tuple[torch.Tensor, int, int]]) -> List[torch.Tensor]:
    """Right-pad each tensor so its ``dim`` size matches the max across ranks.

    specs = [
        (input_ids, 1, pad_token_id),  # Example: Pad the right side with
        (attention_mask, 1, 0),  # Example: Pad the right side with 0
        ...
    ]

    - Share the "maximum length of the dim dimension" across ranks for all specs
    - Pad the right side for the missing parts and return
    - Communication (`all_reduce`) happens only once
    """
    assert len(specs) > 0, "specs is empty"

    device = specs[0][0].device

    # Vectorize local lengths
    local_sizes = torch.tensor(
        [tensor.size(dim) for tensor, dim, _ in specs],
        dtype=torch.long,
        device=device,
    )

    # Element-wise MAX across ranks
    dist.all_reduce(local_sizes, op=dist.ReduceOp.MAX)
    max_sizes = local_sizes.tolist()

    # Pad each tensor as needed
    padded: List[torch.Tensor] = []
    for (tensor, dim, pad_val), max_len in zip(specs, max_sizes):
        cur_len = tensor.size(dim)
        if cur_len < max_len:
            pad_len = max_len - cur_len
            # F.pad takes (left, right) pairs starting from the LAST dimension,
            # so the right-pad slot for dimension `dim` is index -(2*dim+1).
            pad_shape = [0] * (tensor.dim() * 2)  # F.pad specification
            pad_shape[-(2 * dim + 1)] = pad_len  # Pad the right side
            tensor = torch.nn.functional.pad(tensor, pad_shape, value=pad_val)
        padded.append(tensor)

    return padded