# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

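# Inductor/AOTAutograd internals; these imports may fail on older torch versions.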
try:
    import torch.utils._pytree as pytree
    from torch._functorch.aot_autograd import create_aot_dispatcher_function
    from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
    from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
    from torch._inductor.virtualized import V
    from torch._inductor.scheduler import Scheduler

    original_create_aot_dispatcher_function = create_aot_dispatcher_function
except ImportError:
    pass

from .util import get_input_nodes
from .graph_param import DSGraphParamManager


def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
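    """Wrap an Inductor compiler (forward or backward) so that dc_compiler
    rewrites the graph before the original compiler runs. With ZeRO-3
    partitioning, parameter inputs are swapped for empty fake tensors first
    (see below)."""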

    def wrapped_compiler(gm, fake_inputs):
        mod_graph = dc_compiler(gm, fake_inputs)

        # dc_compiler returns None for the symint case; propagate it unchanged.
        if mod_graph is None:
            return None

        if z3_partition:
            # Inductor validates input sizes against those estimated in the first trace,
            # where the DeepSpeed (ZeRO-3) parameters were materialized. Patch the
            # parameter inputs with empty placeholders to avoid the validation error.
            patched_inputs = []
            if bwd:
                param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph)
                param_names = [n.name for n in param_nodes_bw]
            else:
                param_names = graph_param_manager[graph_id].param_names
            input_nodes = get_input_nodes(gm.graph)

            for in_node, in_v in zip(input_nodes, fake_inputs):
                ds_param = in_node.name in param_names
                if ds_param:
                    from torch._subclasses.fake_tensor import is_fake
                    from torch._dynamo.utils import to_fake_tensor
                    assert is_fake(in_v), f"Input {in_v} should be fake tensor"
                    patched_inputs.append(
                        to_fake_tensor(torch.empty([0], dtype=in_v.dtype, device=in_v.device), in_v.fake_mode))
                else:
                    patched_inputs.append(in_v)

            patched_inputs = tuple(patched_inputs)
        else:
            patched_inputs = fake_inputs

        return original_compiler(gm, patched_inputs)

    return wrapped_compiler


def wrap_partition_fn(partition_fn, real_inputs, param_indices):
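    """Wrap the AOTAutograd partition function so that, after partitioning,
    parameter placeholders in both the forward and backward graphs carry
    empty-tensor metadata, matching the patched inputs used at compile time."""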

    def wrapped_partition_fn(*args, **kwargs):

        fw_module, bw_module = partition_fn(*args, **kwargs)

        # get parameter names
        pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)

        def fix_placeholder_meta(graph):
            for n in graph.nodes:
                if n.op == "placeholder" and n.name in pm.param_names:
                    n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)

        fix_placeholder_meta(fw_module.graph)
        fix_placeholder_meta(bw_module.graph)

        return fw_module, bw_module

    return wrapped_partition_fn


def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
                                         param_indices, param_manager):
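    """Patch AotAutograd so that the forward/backward/inference compilers (and,
    with ZeRO-3 partitioning, the partition function) handed to AOTAutograd are
    wrapped with the DeepSpeed versions defined above."""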

    from torch._dynamo.backends.common import AotAutograd
    import functools

    def patch_aotautograd():
        # Unpatch if it was already patched
        if hasattr(AotAutograd, "__original_init"):
            AotAutograd.__init__ = AotAutograd.__original_init

        original_init = AotAutograd.__init__

        @functools.wraps(original_init)
        def patched_init(self, **kwargs):
            kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"],
                                                   make_fw_graph,
                                                   z3_partition,
                                                   graph_id,
                                                   param_manager,
                                                   bwd=False)
            kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"],
                                                   make_bw_graph,
                                                   z3_partition,
                                                   graph_id,
                                                   param_manager,
                                                   bwd=True)
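            # Inference graphs go through the same patched forward compiler.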
            kwargs["inference_compiler"] = kwargs["fw_compiler"]

            if z3_partition:
                kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)

            original_init(self, **kwargs)

        AotAutograd.__original_init = original_init
        AotAutograd.__init__ = patched_init

    patch_aotautograd()


def register_custom_ops():
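    """Register Inductor lowerings for the torch.ops.dc custom ops as fallback
    kernels whose buffers are never reused, and disable the scheduler's dead
    node elimination (see the end of this function)."""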

    def fallback_handler_no_reuse(kernel,
                                  never_reuse_input,
                                  never_reuse_output,
                                  force_free_input,
                                  add_to_fallback_set=True):
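        # Variant of Inductor's fallback handler that can mark the kernel's inputs
        # and/or outputs as never-reuse buffers, and optionally drop the generated
        # code's reference to the first input right after the call (force_free_input).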
        if add_to_fallback_set:
            fallbacks.add(kernel)

        def handler(*args, **kwargs):

            def wrap_tensors(x):
                out = TensorBox.create(x) if isinstance(x, IRNode) else x
                if out is not None and never_reuse_output:
                    V.graph.never_reuse_buffers.add(out.get_name())
                return out

            class CustomDCKernel(FallbackKernel):

                def __init__(self, op, *args, **kwargs):
                    super().__init__(op, *args, **kwargs)

                    def add_to_never_reuse(x):
                        if isinstance(x, IRNode):
                            assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}"
                            V.graph.never_reuse_buffers.add(x.get_name())

                    if never_reuse_input:
                        pytree.tree_map(add_to_never_reuse, args)

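                # Recover the buffer variable name from a codegen'd argument string,
                # which may be wrapped in a reinterpret_tensor(...) call.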
                def get_var_name_for_arg(self, arg: str):
                    if arg.isidentifier():
                        return arg

                    import re
                    match = re.match(r"reinterpret_tensor\((\w+),", arg)
                    if match:
                        return match.group(1)
                    return None

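                # Like the stock fallback codegen, but additionally sets the first
                # argument's variable to None after the call so the wrapper releases
                # that buffer immediately.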
                def codegen(self, wrapper):
                    if not force_free_input:
                        return super().codegen(wrapper)

                    kernel = self.op_overload
                    self.codegen_comment(wrapper)
                    args = [*self.codegen_args(), *self.codegen_kwargs()]

                    V.graph.wrapper_code.generate_fallback_kernel(self, args)
                    if isinstance(self.layout, Layout):
                        self.codegen_size_asserts(wrapper)

                    var_name = self.get_var_name_for_arg(args[0])
                    if var_name:
                        wrapper.writeline(f"{var_name} = None")

                    self.codegen_unbacked_symbol_defs(wrapper)

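            # Only force_free_input needs the customized kernel; otherwise the
            # stock FallbackKernel is sufficient.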
            kernel_cls = CustomDCKernel if force_free_input else FallbackKernel
            return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs))

        return handler

    def register_fallback_no_reuse(op_overload,
                                   never_reuse_input=False,
                                   never_reuse_output=False,
                                   force_free_input=False):
        add_needs_realized_inputs(op_overload)
        return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse(
            op_overload,
            never_reuse_input=never_reuse_input,
            never_reuse_output=never_reuse_output,
            force_free_input=force_free_input))

    # Inductor tries to reuse output buffers when possible. We need to disable this
    # behavior for some custom ops. Even then, the memory region still appears to be
    # reused in some cases, so we also clone the inputs for some ops.
    register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False)
    register_fallback_no_reuse(torch.ops.dc.reduce_grad.default,
                               never_reuse_input=True,
                               never_reuse_output=True,
                               force_free_input=True)
    register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)

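    # Dead node elimination could otherwise drop the custom ops above when their
    # outputs are unused, even though they have side effects (e.g. gathering or
    # releasing parameters).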
    if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched:
        Scheduler.is_dc_patched = True
        Scheduler.dead_node_elimination = lambda _: None