File size: 2,948 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple
from functools import reduce

import torch
from torch.fx import Graph, Node

from .fx import get_output_node
from .util import get_param_nodes


@dataclass
class DSGraphParam:
    """Metadata for one DeepSpeed-managed parameter that appears in an FX graph."""

    name: str  # name of the FX node representing this parameter
    shape: torch.Size  # shape supplied by the caller; ``numel`` is derived from it
    dtype: torch.dtype
    device: torch.device
    node: Node  # FX node for this parameter
    allgather_node: Node  # may be None until assigned (DSGraphParamManager passes None)
    release_node: Node  # may be None until assigned (DSGraphParamManager passes None)
    param: torch.Tensor
    numel: int = field(init=False)  # product of ``shape``; computed, not passed in

    def __post_init__(self):
        # Seed the product with 1 so an empty shape (a 0-d / scalar parameter) yields 1,
        # matching torch.Size([]).numel(), instead of reduce() raising TypeError.
        self.numel = reduce(lambda x, y: x * y, self.shape, 1)


class DSGraphParamManager:
    """Track DeepSpeed-managed parameters of a forward FX graph and their backward gradients.

    The manager is constructed from the forward graph. ``get_bwd_mapping`` later
    attaches the backward graph and records, per parameter name, the backward
    output entry holding that parameter's gradient; ``get_grad_name`` reads it back.
    """

    def __init__(self, fw_graph: Graph, sample_inputs: Any, index_to_ds_ids: List[Tuple[int, int, torch.Size]]):
        """Collect per-parameter metadata from the forward graph.

        Args:
            fw_graph: Forward FX graph containing the parameter nodes.
            sample_inputs: Indexable sample inputs for ``fw_graph``; entry ``i``
                supplies dtype, device and tensor for the parameter at input ``i``.
            index_to_ds_ids: One ``(input_index, ds_id, ds_shape)`` triple per
                parameter; ``ds_shape`` becomes ``DSGraphParam.shape``.
        """
        self._fw_graph = fw_graph
        self._bw_graph = None
        self._params: Dict[str, DSGraphParam] = {}
        self._param_name_to_grad: Dict[str, Node] = {}
        self._ds_ids: Dict[str, int] = {}

        param_nodes = get_param_nodes(fw_graph, index_to_ds_ids)
        self._param_names = [pn.name for pn in param_nodes]
        self._param_indices = [i for i, _, _ in index_to_ds_ids]

        # NOTE(review): pairing assumes get_param_nodes returns nodes in the same
        # order as index_to_ds_ids — confirm against get_param_nodes.
        for pn, (input_idx, ds_id, ds_shape) in zip(param_nodes, index_to_ds_ids):
            pi = sample_inputs[input_idx]
            self._params[pn.name] = DSGraphParam(name=pn.name,
                                                 shape=ds_shape,
                                                 dtype=pi.dtype,
                                                 device=pi.device,
                                                 node=pn,
                                                 allgather_node=None,
                                                 release_node=None,
                                                 param=pi)
            self._ds_ids[pn.name] = ds_id

    def get_bwd_mapping(self, bw_graph: Graph):
        """Attach the backward graph and map each parameter name to its gradient output.

        The mapping is also stored on the manager so ``get_grad_name`` can use it.

        Returns:
            ``(param_nodes_bw, param_name_to_grad)``. The first element lists the
            nodes of ``bw_graph`` whose name matches a tracked parameter name. The
            second maps each parameter name to the entry of the backward output
            tuple at that parameter's forward input index.
        """
        self._bw_graph = bw_graph

        output_node = get_output_node(bw_graph)
        # A set gives O(1) membership instead of rescanning the name list per node.
        param_name_set = set(self._param_names)
        param_nodes_bw = [n for n in bw_graph.nodes if n.name in param_name_set]
        # NOTE(review): assumes the backward output tuple is positionally aligned
        # with the forward graph's inputs — confirm.
        grad_outputs = [output_node.args[0][i] for i in self._param_indices]
        param_name_to_grad = dict(zip(self._param_names, grad_outputs))
        # Keep a private copy so callers mutating the returned dict don't affect lookups.
        self._param_name_to_grad = dict(param_name_to_grad)
        return param_nodes_bw, param_name_to_grad

    @property
    def param_names(self) -> List[str]:
        """Names of the tracked parameter nodes, in forward-graph discovery order."""
        return self._param_names

    @property
    def params(self) -> Dict[str, DSGraphParam]:
        """Mapping from parameter name to its ``DSGraphParam`` metadata."""
        return self._params

    @property
    def ds_ids(self) -> Dict[str, int]:
        """Mapping from parameter name to its DeepSpeed id."""
        return self._ds_ids

    def get_grad_name(self, param_name) -> Node:
        """Return the backward output entry recorded for ``param_name``.

        Raises:
            AssertionError: if ``get_bwd_mapping`` has not been called yet.
            KeyError: if ``param_name`` is not a tracked parameter.
        """
        # The mapping dict starts out empty (never None), so gate on the backward graph instead.
        assert self._bw_graph is not None, "Backward graph is not added yet"
        return self._param_name_to_grad[param_name]