File size: 3,849 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List, Dict, Set, Tuple
import random
from collections import defaultdict

import torch
from torch.fx import Graph, Node

from ..fx import get_output_node, move_primals_to_head
from ..graph_param import DSGraphParamManager

# Per-graph registry of offloaded values: graph_id -> {name -> offload ID}.
# Filled in by offload_activation_fwd and read back by reload_activation_bwd.
value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict)
# Every ID already returned by get_random_id, kept so that no ID is issued twice.
used_ids: Set[int] = set()


def get_random_id() -> int:
    """Return a random integer ID that has not been issued before.

    Candidates are drawn with ``random.randint(10000, 2**31)`` and redrawn
    until one is found that is absent from the module-level ``used_ids`` set.
    The chosen ID is added to ``used_ids`` before it is returned.
    """
    while True:
        candidate = random.randint(10000, 2**31)
        if candidate not in used_ids:
            # Remember the ID so later calls never hand it out again.
            used_ids.add(candidate)
            return candidate


def _should_offload(node: Node) -> bool:
    if not hasattr(node, "meta"):
        return False
    if not "tensor_meta" in node.meta:
        return False

    return True


def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_names: List[Tuple[str, Node]],
                           graph_order: List[int], mem_budget: float, param_manager: DSGraphParamManager) -> Graph:
    """Insert offload ops after selected forward values and route them to the graph output.

    For every ``(name, node)`` pair whose node is not a parameter and carries
    ``tensor_meta`` (see ``_should_offload``), a ``dc.offload_tensor`` call
    followed by a ``dc.wait_offload`` call is inserted right after ``node``.
    The graph output is rewired to consume the wait node instead of ``node``,
    and the freshly drawn ID is stored in ``value_to_id[graph_id][name]`` so
    that ``reload_activation_bwd`` can find it.

    Args:
        graph: Forward graph; it is modified in place.
        graph_id: Key under which the IDs are recorded; also passed to the inserted ops.
        nodes_to_offload_with_names: Candidate ``(name, node)`` pairs; ``name`` is the key
            stored in ``value_to_id``.
        graph_order: Not used by this function.
        mem_budget: Not used by this function.
        param_manager: Supplies ``param_names``; nodes with one of those names are skipped.

    Returns:
        The graph returned by ``move_primals_to_head``, after ``lint()`` has been run on it.
    """
    param_names = set(param_manager.param_names)

    # The output node is the same on every iteration, so look it up once.
    output_node = get_output_node(graph)

    for name, node in nodes_to_offload_with_names:
        if node.name in param_names or not _should_offload(node):
            continue

        val_id = get_random_id()
        with graph.inserting_after(node):
            offload_node = graph.create_node('call_function',
                                             torch.ops.dc.offload_tensor.default, (node, graph_id, val_id), {},
                                             name=f"offload_{node.name}_{val_id}")
        with graph.inserting_after(offload_node):
            wait_node = graph.create_node('call_function',
                                          torch.ops.dc.wait_offload.default, (offload_node, graph_id, val_id), {},
                                          name=f"wait_copy_{node.name}_{val_id}")

        # The graph output now returns the waited-on value instead of the original one.
        output_node.replace_input_with(node, wait_node)

        # Item assignment on the module-level dict needs no `global` declaration.
        value_to_id[graph_id][name] = val_id

    graph = move_primals_to_head(graph)

    graph.lint()
    return graph


def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
                          param_manager: DSGraphParamManager) -> Graph:
    """Insert reload ops in front of the first consumer of every offloaded value.

    The names recorded in ``value_to_id[graph_id]`` are resolved to nodes of
    ``graph``. For each such node a ``dc.reload_tensor`` call followed by a
    ``dc.wait_reload`` call is inserted immediately before the node's earliest
    consumer, and every other consumer is rewired to read the wait node.
    Values that nothing in ``graph`` consumes are left untouched.

    Args:
        graph: Backward graph; it is modified in place.
        graph_id: Key used to look up the recorded IDs; also passed to the inserted ops.
        graph_order: Not used by this function.
        mem_budget: Not used by this function.
        param_manager: Not used by this function.

    Returns:
        The graph returned by ``move_primals_to_head``, after ``lint()`` has been run on it.

    Raises:
        KeyError: If a name recorded for ``graph_id`` has no node of that name in ``graph``.
    """
    graph_value_to_id = value_to_id[graph_id]
    name_to_node = {n.name: n for n in graph.nodes}
    act_nodes = [name_to_node[name] for name in graph_value_to_id]

    # Find each activation's earliest consumer in a single pass over the graph.
    # all_input_nodes also covers kwargs and nodes nested inside list/tuple
    # arguments, which a plain `act in node.args` test would miss; missing such
    # a use could place the reload after a node that already needs it.
    pending = set(act_nodes)
    node_to_first_user = {}
    for candidate in graph.nodes:
        if not pending:
            break
        for input_node in candidate.all_input_nodes:
            if input_node in pending:
                node_to_first_user[input_node] = candidate
                pending.discard(input_node)

    for node in act_nodes:
        first_user = node_to_first_user.get(node)
        if first_user is None:
            # Nothing in this graph consumes the value, so there is nothing to rewire.
            continue

        val_id = graph_value_to_id[node.name]

        with graph.inserting_before(first_user):
            reload_node = graph.create_node('call_function',
                                            torch.ops.dc.reload_tensor.default, (node, graph_id, val_id), {},
                                            name=f"reload_{node.name}_{val_id}")
        with graph.inserting_after(reload_node):
            wait_node = graph.create_node('call_function',
                                          torch.ops.dc.wait_reload.default, (reload_node, graph_id, val_id), {},
                                          name=f"wait_copy_{node.name}_{val_id}")

        # Redirect every consumer except the reload op itself to the reloaded value.
        # Iterate over a snapshot because replace_input_with updates node.users.
        for user in list(node.users):
            if user is not reload_node:
                user.replace_input_with(node, wait_node)

    graph = move_primals_to_head(graph)
    graph.lint()
    return graph