# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# This file was copied from PyTorch and modified for DeepSpeed.

from typing import Tuple, List, Set
import operator

import torch
from torch.fx import GraphModule, Graph, Node

try:
    # These are private functorch helpers; older PyTorch builds may not provide them,
    # in which case the partitioners defined below cannot be used.
    from torch._functorch.partitioners import (is_sym_node, _is_primal, _is_fwd_seed_offset,
                                               _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs,
                                               _extract_fwd_bwd_modules, has_recomputable_ops,
                                               min_cut_rematerialization_partition, choose_saved_values_set)
except ImportError:
    pass

from .util import get_no_copy_ops

# Ops whose results are cheap to recompute rather than save (aten.t only creates a
# transposed view of its input).
_recompute_ops = {torch.ops.aten.t.default}


def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> Set[Node]:
    """
    Given a graph and a node that represents a parameter that was allgathered,
    find all nodes that use the parameter and require recomputation.
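
    Example (illustrative sketch, assuming ``torch.ops.aten.t.default`` is among the
    ops returned by ``get_no_copy_ops``): for a graph computing
    ``w_t = aten.t(w); out = aten.mm(x, w_t)`` with ``w`` the gathered parameter,
    this returns ``{w_t}``. The transpose is a no-copy view that is cheap to
    recompute in the backward pass, while ``mm`` is not a no-copy op and is not
    collected.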
    """
    no_copy_ops = get_no_copy_ops()
    recompute_nodes = set()
    for node in graph.nodes:
        if node.target in no_copy_ops:
            if ds_param_node in node.args:
                recompute_nodes.add(node)
            if any(a in recompute_nodes for a in node.args):
                recompute_nodes.add(node)

    return recompute_nodes


def _get_values_from_ds_params(joint_graph, param_indices):
    """
    Map each node that is derived from a gathered DeepSpeed parameter through
    no-copy ops back to the parameter's primal input node in ``joint_graph``.
    """
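    # Illustration (hypothetical node names): if the primal ``primals_2`` is a gathered
    # parameter and the graph computes ``t_1 = aten.t(primals_2)`` and then
    # ``t_2 = aten.view(t_1, ...)``, and both ops are returned by ``get_no_copy_ops``,
    # the result is {t_1: primals_2, t_2: primals_2}.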
    primal_inputs = list(filter(_is_primal, joint_graph.nodes))
    ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]

    no_copy_ops = get_no_copy_ops()

    ds_param_inputs = set(ds_param_inputs)
    ds_param_users = {}

    for node in joint_graph.nodes:
        if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args):
            for a in node.args:
                if a in ds_param_inputs:
                    ds_param_users[node] = a
                elif a in ds_param_users:
                    ds_param_users[node] = ds_param_users[a]

    return ds_param_users


def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):

    def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]:
        saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget)
        ds_param_users = _get_values_from_ds_params(joint_graph, param_indices)

        # Replace each saved value that was computed from a gathered parameter with the
        # parameter's primal input (deduplicated); the derived value is then recomputed
        # in the backward pass instead of being kept alive.
        new_saved_values = []
        for v in saved_values:
            if v in ds_param_users:
                ds_val = ds_param_users[v]
                if ds_val not in new_saved_values:
                    new_saved_values.append(ds_val)
            else:
                new_saved_values.append(v)

        return new_saved_values

    return ds_choose_saved_values_set
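

# Sketch of one possible integration point (an assumption, not necessarily how
# DeepSpeed wires this up): the wrapper can stand in for functorch's
# ``choose_saved_values_set`` so that the min-cut partitioner saves gathered
# parameters instead of values derived from them, e.g.:
#
#   import torch._functorch.partitioners as partitioners
#
#   # param_indices entries are (arg_idx, ..., shape) tuples; the values are placeholders.
#   partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(
#       [(0, 0, torch.Size([1024, 1024]))])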


def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]):

    def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
                                      num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
        """
        This is basically the same as the default_partition function, but
        it doesn't save the gathered params and values computed from them.
        """
        if has_recomputable_ops(joint_module):
            return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
        inputs = primal_inputs + fwd_seed_offset_inputs
        fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
        forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward")
        forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"}
        saved_values = []
        saved_sym_nodes = []

        fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes))
        ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices]
        ds_param_input_names = {node.name for node in ds_param_inputs}

        # Nodes derived from gathered parameters that will be recomputed in the backward
        # pass rather than saved.
        ds_param_recompute_nodes = set()

        for node in joint_module.graph.nodes:
            if node.name not in forward_node_names:
                continue

            if is_sym_node(node):
                # Symints must be kept separate from tensors so that PythonFunction only calls
                # save_for_backward on tensors and stashes symints in autograd .ctx
                saved_sym_nodes.append(node)
            elif "tensor_meta" not in node.meta and node.op == "call_function":
                # Since we can't save tuple of tensor values, we need to flatten out what we're saving
                users = node.users
                assert all(user.target == operator.getitem for user in users)
                saved_values.extend(users)
            else:
                backward_usages = [n for n in node.users if n.name not in forward_node_names]

                if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
                    # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                    # and not the actual tensor data,
                    # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                    #
                    # Note that saving the tensor could also cause compilation problems:
                    # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                    # then we would be obligated to clone the input before saving it to appease autograd.
                    # (This is how we originally found this bug).
                    saved_sym_nodes.extend(backward_usages)

                    if node.name in ds_param_input_names:
                        # Values derived from this gathered parameter through no-copy ops
                        # are recomputed in the backward rather than saved.
                        recompute_nodes = _find_recompute_nodes(joint_module.graph, node)
                        recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names]
                        for recompute_node in recompute_nodes:
                            ds_param_recompute_nodes.add(recompute_node)

                        if len(recompute_nodes) > 0:
                            # Save the parameter itself so those values can be recomputed from it.
                            saved_values.append(node)
                else:
                    # Skip values that will be recomputed from gathered parameters.
                    if node not in ds_param_recompute_nodes:
                        saved_values.append(node)
        saved_values = list(dict.fromkeys(saved_values).keys())
        saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())

        f_gm, b_gm = _extract_fwd_bwd_modules(
            joint_module,
            saved_values,
            saved_sym_nodes=saved_sym_nodes,
            num_fwd_outputs=num_fwd_outputs,
        )

        return f_gm, b_gm

    return partition_recompute_ds_params
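

# Usage sketch. ``my_fw_compiler``/``my_bw_compiler`` and the ``param_indices`` value
# are placeholders, and this is not necessarily how DeepSpeed itself registers the
# partitioner; it only shows that the returned function is a ``partition_fn``
# compatible with AOTAutograd:
#
#   from torch._dynamo.backends.common import aot_autograd
#
#   param_indices = [(0, 0, torch.Size([1024, 1024]))]
#   backend = aot_autograd(fw_compiler=my_fw_compiler,
#                          bw_compiler=my_bw_compiler,
#                          partition_fn=get_wrapped_partitioner(param_indices))
#   compiled_model = torch.compile(model, backend=backend)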