File size: 1,839 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import torch
from torch.fx import GraphModule

from ..util import get_deepcompile_handle
from ..fx import add_postprocess, move_primals_to_head, _make_node_meta

# Module-level name constant for this pass; it is not referenced by the code visible in this file.
# NOTE(review): presumably the key under which this pass is registered/selected — confirm against callers.
NAME = "zero1_compile"


def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_manager) -> GraphModule:
    """Register a forward graph with the DeepCompile handle and return it unchanged.

    Args:
        gm: Graph module for the forward graph; returned as-is.
        graph_id: Key used to index ``profiling_results`` and passed to
            ``register_graph_z1``.
        profiling_results: Indexable by ``graph_id``; the selected entry must
            expose a ``param_indices`` iterable of indexable items.
        param_manager: Accepted but not used by this function.

    Returns:
        The same ``gm`` object that was passed in.
    """
    handle = get_deepcompile_handle()
    indexed_params = profiling_results[graph_id].param_indices

    register = handle.register_graph_z1
    # Only the element at position 1 of each entry is forwarded.
    # NOTE(review): presumably the DeepSpeed parameter id — confirm against the profiler.
    second_fields = [entry[1] for entry in indexed_params]
    register(graph_id, second_fields)  # Need this before profiling

    return gm


def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule:
    """Attach a ``reduce_grad`` postprocess node to each parameter gradient.

    For every name in the graph's parameter manager, the matching gradient
    node is looked up in the backward mapping and a
    ``torch.ops.dc.reduce_grad.default`` node is added through
    ``add_postprocess`` with ``[graph_id, ds_id]`` as extra arguments.
    Afterwards ``gm.graph`` is replaced by the result of
    ``move_primals_to_head``.

    Args:
        gm: Graph module holding the backward graph; modified in place.
        graph_id: Key selecting the entry of ``param_manager`` and forwarded
            to the reduce op.
        param_manager: Indexable by ``graph_id``; the selected entry provides
            ``get_bwd_mapping``, ``param_names`` and ``ds_ids``.

    Returns:
        The same ``gm`` object, with its ``graph`` attribute reassigned.

    Raises:
        AssertionError: If a parameter name is missing from ``ds_ids``.
    """
    fx_graph = gm.graph
    manager = param_manager[graph_id]
    # Only the second element of the returned pair is needed here.
    _unused, grad_by_param = manager.get_bwd_mapping(fx_graph)

    for param_name in manager.param_names:
        grad_node = grad_by_param[param_name]

        # Checked after the gradient lookup, exactly as before, so failure order is unchanged.
        assert param_name in manager.ds_ids, f"param_name={param_name} not in ds_ids"
        ds_id = manager.ds_ids[param_name]

        reduce_op = torch.ops.dc.reduce_grad.default
        reduce_args = [graph_id, ds_id]
        node_name = f"reduce_param_{param_name}"
        node_meta = _make_node_meta(grad_node, param_name, True)

        reduce_node = add_postprocess(fx_graph,
                                      grad_node,
                                      reduce_op,
                                      extra_args=reduce_args,
                                      name=node_name,
                                      meta=node_meta)
        # Overwrite any "val" entry in the new node's metadata with None.
        reduce_node.meta["val"] = None

    gm.graph = move_primals_to_head(fx_graph)
    return gm


def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                  mem_budget: float, param_manager, bwd: bool) -> GraphModule:
    """Dispatch to the forward or backward ZeRO-1 reduce pass.

    Args:
        gm: Graph module to process.
        graph_id: Identifier forwarded to the selected pass.
        graph_order: Accepted but not used by this function.
        profiling_results: Forwarded to the forward pass only.
        create_inputs_fn: Accepted but not used by this function.
        mem_budget: Accepted but not used by this function.
        param_manager: Forwarded to whichever pass is selected.
        bwd: Truthy selects ``add_z1_reduce_bw``; falsy selects
            ``add_z1_reduce_fw``.

    Returns:
        Whatever the selected pass returns.
    """
    if not bwd:
        return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager)
    return add_z1_reduce_bw(gm, graph_id, param_manager)