# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This file was copied from PyTorch and modified for DeepSpeed.

from typing import List, Set, Tuple

import operator

import torch
from torch.fx import GraphModule, Graph, Node

try:
    from torch._functorch.partitioners import is_sym_node, _is_primal, _is_fwd_seed_offset, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, _extract_fwd_bwd_modules, has_recomputable_ops, min_cut_rematerialization_partition, choose_saved_values_set
except ImportError:
    # These private helpers are not exposed by every PyTorch version; the wrappers below require them.
    pass

from .util import get_no_copy_ops

# Ops considered cheap to recompute rather than save (currently only aten.t, 2-D transpose).
_recompute_ops = {torch.ops.aten.t.default}


def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> Set[Node]:
    """
    Given a graph and a node that represents a parameter that was all-gathered,
    find all nodes that use the parameter and require recomputation.
    """
    no_copy_ops = get_no_copy_ops()

    recompute_nodes = set()
    # graph.nodes is topologically ordered, so marks propagate transitively:
    # a user of an already-marked node is always visited after that node.
    for node in graph.nodes:
        if node.target in no_copy_ops:
            if ds_param_node in node.args:
                recompute_nodes.add(node)
            if any(a in recompute_nodes for a in node.args):
                recompute_nodes.add(node)

    return recompute_nodes


def _get_values_from_ds_params(joint_graph, param_indices):
    """
    Map each node that is derived from a gathered DeepSpeed parameter input through a
    chain of no-copy ops back to the parameter input it comes from.
    """
    primal_inputs = list(filter(_is_primal, joint_graph.nodes))
    ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]

    no_copy_ops = get_no_copy_ops()

    ds_param_inputs = set(ds_param_inputs)
    ds_param_users = {}
    for node in joint_graph.nodes:
        if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args):
            for a in node.args:
                if a in ds_param_inputs:
                    ds_param_users[node] = a
                elif a in ds_param_users:
                    ds_param_users[node] = ds_param_users[a]

    return ds_param_users


def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):

    def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]:
        saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget)
        ds_param_users = _get_values_from_ds_params(joint_graph, param_indices)

        new_saved_values = []
        for v in saved_values:
            if v in ds_param_users:
                # Save the gathered parameter this value derives from instead of the
                # value itself, so the value can be recomputed in the backward.
                ds_val = ds_param_users[v]
                if ds_val not in new_saved_values:
                    new_saved_values.append(ds_val)
            else:
                new_saved_values.append(v)

        return new_saved_values

    return ds_choose_saved_values_set
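

# The wrapper above is only defined here; how it gets installed is outside this file.
# A hypothetical sketch of one way to activate it: swap it in for the stock
# choose_saved_values_set inside torch._functorch.partitioners, so that values derived
# from gathered parameters are exchanged for the parameters themselves whenever the
# partitioner decides what to save.
def _example_install_ds_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):
    """Illustrative sketch only; the function name and the monkey-patching approach are assumptions."""
    import torch._functorch.partitioners as functorch_partitioners

    # The wrapper captured the original function via the from-import above, so no recursion.
    functorch_partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)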


def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]):

    def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
                                      num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
        """
        This is basically the same as the default_partition function, but
        it does not save the gathered params or values computed from them.
        """
        # Graphs with ops already tagged for recomputation (e.g. activation checkpointing)
        # are handled by the min-cut partitioner instead.
        if has_recomputable_ops(joint_module):
            return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
        inputs = primal_inputs + fwd_seed_offset_inputs
        fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
        forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward")
        forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"}

        saved_values = []
        saved_sym_nodes = []

        fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes))
        ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices]
        ds_param_input_names = {node.name for node in ds_param_inputs}

        ds_param_recompute_nodes = set()

        for node in joint_module.graph.nodes:
            if node.name not in forward_node_names:
                continue

            if is_sym_node(node):
                # Symints must be kept separate from tensors so that PythonFunction only calls
                # save_for_backward on tensors and stashes symints in autograd .ctx
                saved_sym_nodes.append(node)
            elif "tensor_meta" not in node.meta and node.op == "call_function":
                # Since we can't save tuple of tensor values, we need to flatten out what we're saving
                users = node.users
                assert all(user.target == operator.getitem for user in users)
                saved_values.extend(users)
            else:
                backward_usages = [n for n in node.users if n.name not in forward_node_names]

                if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
                    # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                    # and not the actual tensor data,
                    # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                    #
                    # Note that saving the tensor could also cause compilation problems:
                    # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                    # then we would be obligated to clone the input before saving it to appease autograd.
                    # (This is how we originally found this bug).
                    saved_sym_nodes.extend(backward_usages)

                if node.name in ds_param_input_names:
                    # This is a gathered DeepSpeed parameter: save the parameter input itself
                    # and mark the forward nodes computed from it (via no-copy ops) as
                    # recompute nodes so they are not saved below.
                    saved_values.append(node)
                    recompute_nodes = _find_recompute_nodes(joint_module.graph, node)
                    recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names]
                    for recompute_node in recompute_nodes:
                        ds_param_recompute_nodes.add(recompute_node)
                    if len(recompute_nodes) > 0:
                        saved_values.append(node)
                else:
                    if node not in ds_param_recompute_nodes:
                        saved_values.append(node)

        # Deduplicate while preserving order.
        saved_values = list(dict.fromkeys(saved_values).keys())
        saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())

        f_gm, b_gm = _extract_fwd_bwd_modules(
            joint_module,
            saved_values,
            saved_sym_nodes=saved_sym_nodes,
            num_fwd_outputs=num_fwd_outputs,
        )

        return f_gm, b_gm

    return partition_recompute_ds_params
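

# A minimal usage sketch (hypothetical, with placeholder compilers): the partitioner
# returned by get_wrapped_partitioner matches AOTAutograd's partition_fn interface
# (joint_module, joint_inputs, *, num_fwd_outputs), so it can be passed to the
# aot_autograd backend helper; param_indices is assumed to be collected elsewhere.
def _example_backend(param_indices: List[Tuple[int, int, torch.Size]]):
    """Illustrative sketch only; DeepSpeed's actual backend wiring lives outside this file."""
    from torch._dynamo.backends.common import aot_autograd

    return aot_autograd(
        fw_compiler=lambda gm, example_inputs: gm,  # run the forward graph as traced
        bw_compiler=lambda gm, example_inputs: gm,  # run the backward graph as traced
        partition_fn=get_wrapped_partitioner(param_indices),
    )


# Usage: torch.compile(model, backend=_example_backend(param_indices))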