File size: 2,046 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import torch
from torch.fx import Node, GraphModule
from deepspeed.compile.util import get_last_uses
from ..graph_param import DSGraphParamManager


def add_offload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int):
    new_node = None
    with gm.graph.inserting_after(node):
        args = (node, )
        for a in [graph_id, ds_id]:  # To add ds_id
            args += (a, )
        new_node = gm.graph.create_node('call_function',
                                        torch.ops.dc.offload_parameter.default,
                                        args, {},
                                        name="offload_parameter")

    return new_node


def add_reload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int):
    new_node = None
    with gm.graph.inserting_after(node):
        args = (node, )
        for a in [graph_id, ds_id]:  # To add ds_id
            args += (a, )
        new_node = gm.graph.create_node('call_function',
                                        torch.ops.dc.reload_parameter.default,
                                        args, {},
                                        name=f"reload_parameter")
    return new_node


def get_ds_id(node: Node):
    assert node.target == torch.ops.dc.allgather_param.default
    return node.args[2]


def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                          mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
    """Surround every ``dc.allgather_param`` node in ``gm`` with reload/offload calls.

    For each node targeting ``torch.ops.dc.allgather_param.default``:

    * a ``reload_parameter`` node is inserted right after the allgather's
      first argument, and
    * an ``offload_parameter`` node is inserted right after the node that
      ``get_last_uses`` maps the allgather to (presumably its last user;
      confirm against ``deepspeed.compile.util.get_last_uses``).

    The graph is mutated in place and linted before returning.

    Args:
        gm: Graph module to rewrite in place.
        graph_id: Forwarded to the inserted reload/offload nodes.
        graph_order, profiling_results, create_inputs_fn, mem_budget,
        param_manager, bwd: Not read by this function.

    Returns:
        The same ``gm`` instance that was passed in.
    """
    # Only the node -> last-use mapping is needed; the second mapping is discarded.
    last_use_of, _ = get_last_uses(gm.graph)

    for candidate in gm.graph.nodes:
        is_allgather = isinstance(candidate, Node) and candidate.target == torch.ops.dc.allgather_param.default
        if not is_allgather:
            continue

        param_id = get_ds_id(candidate)
        add_reload_parameter(graph_id, gm, candidate.args[0], param_id)
        add_offload_parameter(graph_id, gm, last_use_of[candidate], param_id)

    gm.graph.lint()
    return gm