File size: 4,717 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.fx import Graph, Node

from .util import get_last_uses
def get_output_node(graph: Graph):
    """Return the first node of ``graph`` whose target is ``"output"``.

    Args:
        graph: The FX graph to search, in node order.

    Returns:
        The first matching node.

    Raises:
        ValueError: If no node in the graph has the target ``"output"``.
    """
    found = next((candidate for candidate in graph.nodes if candidate.target == "output"), None)
    if found is None:
        raise ValueError("No output node found")
    return found
def move_primals_to_head(graph: Graph):
    """Build a copy of ``graph`` in which every placeholder node comes first.

    The relative order among placeholders, and among all remaining nodes, is
    kept as in the input graph. The input graph itself is not modified.

    Args:
        graph: The FX graph to reorder.

    Returns:
        A new, linted ``Graph`` containing copies of all nodes.
    """
    # Split the nodes in a single pass, keeping the original order per group.
    placeholders, others = [], []
    for candidate in graph.nodes:
        bucket = placeholders if candidate.op == "placeholder" else others
        bucket.append(candidate)

    reordered = Graph()
    # Maps an original node name to its copy so that arguments of later
    # nodes can be remapped into the new graph.
    copies_by_name = {}
    for original in placeholders + others:
        copies_by_name[original.name] = reordered.node_copy(original, lambda arg: copies_by_name[arg.name])

    reordered.lint()
    return reordered
def add_args_process(graph: Graph,
                     node: Node,
                     fn: Callable[..., Any],
                     extra_args: Optional[List[int]] = None,
                     name=None,
                     meta: Optional[Dict[str, Any]] = None) -> List[Node]:
    """Route every ``Node`` positional argument of ``node`` through ``fn``.

    For each distinct ``Node`` that appears directly in ``node.args``, a
    ``call_function`` node computing ``fn(arg, *extra_args)`` is inserted
    before ``node`` and ``node`` is rewired to consume it instead of ``arg``.
    Only direct entries of ``node.args`` are selected; nodes nested inside
    containers are not picked up.

    Args:
        graph: Graph that owns ``node``; it is modified in place.
        node: Node whose arguments are wrapped.
        fn: Call target of the inserted nodes.
        extra_args: Extra positional arguments appended after the wrapped
            argument. ``None`` means no extra arguments.
        name: Name hint passed to ``graph.create_node`` for each new node.
        meta: Entries copied into ``meta`` of each new node. ``None`` means
            nothing is copied.

    Returns:
        The newly created nodes, ordered by first appearance of the argument
        they wrap.
    """
    trailing_args = tuple(extra_args) if extra_args is not None else ()
    meta_entries = meta if meta is not None else {}

    # replace_input_with() rewires *every* occurrence of an argument at once,
    # so an argument passed several times must be handled only once.
    # Processing the repeats would create call nodes that nothing consumes.
    unique_args = list(dict.fromkeys(arg for arg in node.args if isinstance(arg, Node)))

    new_nodes = []
    with graph.inserting_before(node):
        for arg in unique_args:
            new_node = graph.create_node('call_function', fn, (arg, ) + trailing_args, name=name)
            new_node.meta.update(meta_entries)
            node.replace_input_with(arg, new_node)
            new_nodes.append(new_node)

    return new_nodes
def add_postprocess(graph: Graph,
                    node: Node,
                    fn: Callable[..., Any],
                    extra_args: Optional[List[int]] = None,
                    name=None,
                    meta: Optional[Dict[str, Any]] = None) -> Node:
    """Insert ``fn(node, *extra_args)`` after ``node`` and redirect its users.

    Every node that consumed ``node`` before the call is rewired to consume
    the newly inserted node instead; the new node is the only remaining user
    of ``node``.

    Follows the pattern shown in
    https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py

    Args:
        graph: Graph that owns ``node``; it is modified in place.
        node: Node whose output is post-processed.
        fn: Call target of the inserted node.
        extra_args: Extra positional arguments appended after ``node``
            (the original code notes this is used to add ``ds_id``).
            ``None`` means no extra arguments.
        name: Name hint passed to ``graph.create_node``.
        meta: Entries copied into ``meta`` of the new node. ``None`` means
            nothing is copied.

    Returns:
        The newly created node.
    """
    call_args = (node, ) + (tuple(extra_args) if extra_args is not None else ())

    # Snapshot the consumers first: node.users is a live mapping, and the new
    # node registers itself as a user as soon as it is created.
    previous_users = list(node.users)

    with graph.inserting_after(node):
        new_node = graph.create_node('call_function', fn, call_args, {}, name=name)

    for user in previous_users:
        user.replace_input_with(node, new_node)

    if meta is not None:
        new_node.meta.update(meta)

    return new_node
def _make_node_meta(node: Node, ds_id: int, comm: bool):
    """Assemble the metadata dict for a node derived from ``node``.

    Args:
        node: Source node; its name is stored under ``"param_name"``.
        ds_id: Value stored under ``"ds_id"``.
        comm: Value stored under ``"comm"``.

    Returns:
        A dict with ``"param_name"``, ``"ds_id"`` and ``"comm"``, plus
        ``"tensor_meta"`` copied from ``node.meta`` when that key is present.
    """
    inherited = {"tensor_meta": node.meta["tensor_meta"]} if "tensor_meta" in node.meta else {}
    return {"param_name": node.name, "ds_id": ds_id, "comm": comm, **inherited}
def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]):
    """Insert ``dc.free_tensors`` calls after the last use of activation nodes.

    Activation nodes are the placeholders of ``graph`` whose name is listed in
    ``activation_node_names``. When a ``dc.reload_tensor`` / ``dc.wait_reload``
    pair sharing the same id (``args[2]``) exists for an activation, the
    ``wait_reload`` node is treated as the activation instead. For every node
    that ``get_last_uses`` reports as a last user, the activation nodes it uses
    last (and that carry ``"tensor_meta"``) are passed as one list to a
    ``torch.ops.dc.free_tensors.default`` node inserted right after it.

    Args:
        graph_id: Not used by this function.
        graph: Graph to modify in place.
        activation_node_names: Names of the placeholder nodes to treat as
            activations.
    """
    node_to_last_use, _ = get_last_uses(graph)

    activations = {n for n in graph.nodes if n.op == "placeholder" and n.name in activation_node_names}

    # reload_tensor nodes give (source node, id); wait_reload nodes are matched
    # back to that source through the same id.
    source_by_offload_id = {}
    wait_reload_by_source = {}
    for candidate in graph.nodes:
        if candidate.target == torch.ops.dc.reload_tensor.default:
            source_by_offload_id[candidate.args[2]] = candidate.args[0]
        elif candidate.target == torch.ops.dc.wait_reload.default:
            wait_reload_by_source[source_by_offload_id[candidate.args[2]]] = candidate

    # An activation that is reloaded is represented by its wait_reload node.
    activations = {wait_reload_by_source.get(act, act) for act in activations}

    # Invert the mapping: last user -> nodes whose final use is that user.
    uses_by_last_user = defaultdict(list)
    for used, final_user in node_to_last_use.items():
        uses_by_last_user[final_user].append(used)

    def _should_free(node: Node) -> bool:
        # Only nodes that carry tensor metadata are eligible.
        return hasattr(node, "meta") and "tensor_meta" in node.meta

    def free_tensors(tensors: List[torch.Tensor]):
        # Pure-Python stand-in kept for debugging (see the note at the end):
        # swaps the storage of tensors above 10,000,000 elements for an
        # empty tensor.
        for t in tensors:
            if t.numel() > 10_000_000:
                t.data = torch.empty([0], device=t.device, dtype=t.dtype)

    for final_user, used_nodes in uses_by_last_user.items():
        freeable = [u for u in used_nodes if u in activations and _should_free(u)]
        if not freeable:
            continue

        # The name is built from all nodes last used here, not only the
        # freeable ones.
        label = f"free_activations_{[n.name for n in used_nodes]}"
        with graph.inserting_after(final_user):
            graph.create_node('call_function', torch.ops.dc.free_tensors.default, (freeable, ), {}, name=label)

            # Python version for debugging
            # graph.create_node('call_function', free_tensors, (freeable, ), {}, name=label)
|