# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from functools import reduce

import torch
from torch.fx import Graph, Node

from .fx import get_output_node
from .util import get_param_nodes


@dataclass
class DSGraphParam:
    """Record describing one parameter tracked in a graph.

    ``numel`` is not a constructor argument; it is derived from ``shape`` in
    ``__post_init__``.
    """

    name: str
    shape: torch.Size
    dtype: torch.dtype
    device: torch.device
    node: Node
    # DSGraphParamManager constructs records with both of these set to None.
    allgather_node: Optional[Node]
    release_node: Optional[Node]
    param: torch.Tensor
    numel: int = field(init=False)

    def __post_init__(self):
        # The initial value 1 keeps an empty shape from raising TypeError in
        # reduce(); the element count of an empty shape is then 1.
        self.numel = reduce(lambda x, y: x * y, self.shape, 1)


class DSGraphParamManager:
    """Keeps per-parameter records for a forward graph and maps them onto a backward graph."""

    def __init__(self, fw_graph: Graph, sample_inputs: Any, index_to_ds_ids: List[Tuple[int, int, torch.Size]]):
        """Build one ``DSGraphParam`` per parameter node of ``fw_graph``.

        Args:
            fw_graph: Forward graph; passed to ``get_param_nodes`` to obtain the parameter nodes.
            sample_inputs: Indexable inputs. The entry at each parameter index supplies the
                record's ``dtype``, ``device`` and ``param``.
            index_to_ds_ids: ``(input index, ds_id, shape)`` triples, one per parameter.
        """
        self._fw_graph = fw_graph
        self._bw_graph = None
        self._params: Dict[str, DSGraphParam] = {}
        self._param_name_to_grad: Dict[str, Node] = {}
        self._ds_ids: Dict[str, int] = {}

        param_nodes = get_param_nodes(fw_graph, index_to_ds_ids)
        self._param_names = [pn.name for pn in param_nodes]
        self._param_indices = [i for i, _, _ in index_to_ds_ids]

        # NOTE(review): pairing by position assumes get_param_nodes returns the nodes in the
        # same order as index_to_ds_ids -- confirm against the helper.
        for pn, (input_index, ds_id, ds_shape) in zip(param_nodes, index_to_ds_ids):
            sample = sample_inputs[input_index]
            self._params[pn.name] = DSGraphParam(name=pn.name,
                                                 shape=ds_shape,
                                                 dtype=sample.dtype,
                                                 device=sample.device,
                                                 node=pn,
                                                 allgather_node=None,
                                                 release_node=None,
                                                 param=sample)
            self._ds_ids[pn.name] = ds_id

    def get_bwd_mapping(self, bw_graph: Graph) -> Tuple[List[Node], Dict[str, Node]]:
        """Record ``bw_graph`` and relate the known parameters to it.

        Args:
            bw_graph: Backward graph to inspect.

        Returns:
            A pair ``(param_nodes_bw, param_name_to_grad)``: the nodes of ``bw_graph`` whose name
            equals a known parameter name, and a dict from parameter name to the entry of the
            output node's first argument found at that parameter's input index.
        """
        self._bw_graph = bw_graph

        output_node = get_output_node(bw_graph)
        # Build the set once so the membership test is not a list scan per graph node.
        known_names = set(self.param_names)
        param_nodes_bw = [n for n in bw_graph.nodes if n.name in known_names]
        grad_outputs = [output_node.args[0][i] for i in self._param_indices]
        param_name_to_grad = dict(zip(self.param_names, grad_outputs))
        # Keep a private copy so get_grad_name() can answer lookups and later caller-side
        # mutation of the returned dict does not leak into the manager.
        self._param_name_to_grad = dict(param_name_to_grad)
        return param_nodes_bw, param_name_to_grad

    @property
    def param_names(self) -> List[str]:
        """Names of the parameter nodes, in the order returned by ``get_param_nodes``."""
        return self._param_names

    @property
    def params(self) -> Dict[str, DSGraphParam]:
        """Parameter records keyed by parameter node name."""
        return self._params

    @property
    def ds_ids(self) -> Dict[str, int]:
        """``ds_id`` values keyed by parameter node name."""
        return self._ds_ids

    def get_grad_name(self, param_name) -> Node:
        """Return the gradient entry stored for ``param_name`` by the last ``get_bwd_mapping`` call.

        Raises:
            AssertionError: If ``get_bwd_mapping`` has not been called yet.
            KeyError: If ``param_name`` is not a known parameter name.
        """
        # The mapping attribute starts as an empty dict, so a None check on it can never fail;
        # the backward graph attribute is what actually signals readiness.
        assert self._bw_graph is not None, "Backward graph is not added yet"
        return self._param_name_to_grad[param_name]