# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Callable, Any, List
from collections import defaultdict

import torch
from torch.fx import Node, Graph

from .util import get_last_uses


def get_output_node(graph: Graph):
    for v in graph.nodes:
        if v.target == "output":
            return v
    raise ValueError("No output node found")


def move_primals_to_head(graph: Graph):
    # Rebuild the graph so that all placeholder (primal) nodes come first,
    # followed by the remaining nodes in their original order. The input graph
    # is left untouched; the reordered copy is returned.
    primals = [n for n in graph.nodes if n.op == "placeholder"]
    non_primals = [n for n in graph.nodes if n.op != "placeholder"]
    all_nodes = primals + non_primals

    new_graph = Graph()
    env = {}
    for node in all_nodes:
        new_node = new_graph.node_copy(node, lambda n: env[n.name])
        env[node.name] = new_node
    new_graph.lint()

    return new_graph
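
# Illustrative sketch (not part of this module): how move_primals_to_head is
# typically applied to a traced GraphModule. The toy module ``M`` below is
# hypothetical and only demonstrates the call.
#
#   import torch
#   import torch.fx
#
#   class M(torch.nn.Module):
#       def forward(self, x, y):
#           return torch.relu(x) + y
#
#   gm = torch.fx.symbolic_trace(M())
#   gm.graph = move_primals_to_head(gm.graph)
#   gm.recompile()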


def add_args_process(graph: Graph,
                     node: Node,
                     fn: Callable[..., Any],
                     extra_args: List[int] = [],
                     name=None,
                     meta={}) -> List[Node]:
    # Insert a call to ``fn(arg, *extra_args)`` before ``node`` for every Node
    # argument of ``node`` and rewire ``node`` to consume the wrapped values.
    new_nodes = []
    with graph.inserting_before(node):
        target_args = [arg for arg in node.args if isinstance(arg, Node)]

        for arg in target_args:
            new_node = graph.create_node('call_function', fn, (arg, ) + tuple(extra_args), name=name)
            for k, v in meta.items():
                new_node.meta[k] = v
            node.replace_input_with(arg, new_node)
            new_nodes.append(new_node)

    return new_nodes
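
# Illustrative sketch (assumptions: ``gm`` is a torch.fx.GraphModule and
# ``_tag`` is a hypothetical helper; neither is defined in this module). Each
# Node argument of the selected node is wrapped by a call to ``_tag`` inserted
# just before it:
#
#   def _tag(t, ds_id):
#       # e.g. record that tensor ``t`` (owned by ds_id) is about to be used
#       return t
#
#   add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
#   add_args_process(gm.graph, add_node, _tag, extra_args=[0], name="tag")
#   gm.recompile()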


def add_postprocess(graph: Graph,
                    node: Node,
                    fn: Callable[..., Any],
                    extra_args: List[int] = [],
                    name=None,
                    meta={}) -> Node:
    # https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py
    with graph.inserting_after(node):
        # The new node consumes ``node`` first, followed by any extra arguments
        # (e.g. a ds_id).
        args = (node, ) + tuple(extra_args)

        node_users = node.users.keys()
        new_node = graph.create_node('call_function', fn, args, {}, name=name)

        # Collect the rewiring first and apply it afterwards, so that
        # ``node.users`` is not mutated while we iterate over its view.
        users = {}
        for u in node_users:
            if u != new_node:
                users[u] = (node, new_node)
        for u, (old_in, new_in) in users.items():
            u.replace_input_with(old_in, new_in)

    for k, v in meta.items():
        new_node.meta[k] = v

    return new_node
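
# Illustrative sketch (assumptions: ``gm`` is a torch.fx.GraphModule and
# ``_post_hook`` is a hypothetical function; neither is defined in this module).
# The hook is inserted right after the node, and all of the node's existing
# users are rewired to consume the hook's output:
#
#   def _post_hook(t, ds_id):
#       return t
#
#   target_node = next(n for n in gm.graph.nodes if n.op == "call_function")
#   add_postprocess(gm.graph, target_node, _post_hook, extra_args=[0],
#                   meta=_make_node_meta(target_node, ds_id=0, comm=False))
#   gm.recompile()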


def _make_node_meta(node: Node, ds_id: int, comm: bool):
    meta = {"param_name": node.name, "ds_id": ds_id, "comm": comm}
    if "tensor_meta" in node.meta:
        meta["tensor_meta"] = node.meta["tensor_meta"]
    return meta


def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]):
    # After the last use of each listed activation, insert a node that frees the
    # activation's underlying storage.
    node_to_last_use, _ = get_last_uses(graph)
    activation_nodes_set = {n for n in graph.nodes if n.op == "placeholder" and n.name in activation_node_names}

    # Map each offloaded activation to the wait_reload node that re-materializes
    # it, so that freeing can target the reloaded tensor instead of the original
    # placeholder.
    offload_id_to_node = {}
    node_to_wait_reload = {}
    for node in graph.nodes:
        if node.target == torch.ops.dc.reload_tensor.default:
            offload_act = node.args[0]
            offload_id_to_node[node.args[2]] = offload_act
        elif node.target == torch.ops.dc.wait_reload.default:
            offload_id = node.args[2]
            node_to_wait_reload[offload_id_to_node[offload_id]] = node

    activation_nodes_set = {node_to_wait_reload.get(n, n) for n in activation_nodes_set}

    last_user_to_uses = defaultdict(list)
    for node, last_user in node_to_last_use.items():
        last_user_to_uses[last_user].append(node)

    def _should_free(node: Node) -> bool:
        if not hasattr(node, "meta"):
            return False
        if "tensor_meta" not in node.meta:
            return False
        return True

    def free_tensors(tensors: List[torch.Tensor]):
        # Python counterpart of torch.ops.dc.free_tensors, kept for debugging:
        # drop the storage of large activations (> 10M elements) only.
        for a in tensors:
            if a.numel() > 10_000_000:
                a.data = torch.empty([0], device=a.device, dtype=a.dtype)

    for last_user, used_nodes in last_user_to_uses.items():
        activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)]

        if len(activation_args) == 0:
            continue

        node_name = f"free_activations_{[n.name for n in used_nodes]}"
        with graph.inserting_after(last_user):
            args = (activation_args, )
            graph.create_node('call_function', torch.ops.dc.free_tensors.default, args, {}, name=node_name)

            # Python version for debugging
            # graph.create_node('call_function', free_tensors, args, {}, name=node_name)
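
# Illustrative sketch (assumptions: ``gm`` is a torch.fx.GraphModule produced by
# the compile pipeline and the placeholder names below are hypothetical). After
# the pass, the last user of each listed activation is followed by a
# torch.ops.dc.free_tensors node that releases the activation's storage:
#
#   add_free_activations(graph_id=0,
#                        graph=gm.graph,
#                        activation_node_names=["primals_2", "primals_3"])
#   gm.recompile()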