# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# This file was copied from PyTorch and modified for DeepSpeed.
from typing import Tuple, List
import operator

import torch
from torch.fx import GraphModule, Graph, Node

try:
    from torch._functorch.partitioners import is_sym_node, _is_primal, _is_fwd_seed_offset, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, _extract_fwd_bwd_modules, has_recomputable_ops, min_cut_rematerialization_partition, choose_saved_values_set
except ImportError:
    pass

from .util import get_no_copy_ops

_recompute_ops = {torch.ops.aten.t.default}
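# Descriptive note: aten.t.default is a no-copy transpose, so its output is
# cheap to rematerialize in the backward pass instead of being saved. The
# helpers below query get_no_copy_ops() from .util for the full set of such
# view-like ops.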


def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> List[Node]:
    """
    Given a graph and a node that represents a parameter that was allgathered,
    find all nodes that use the parameter and require recomputation.
    """
    no_copy_ops = get_no_copy_ops()

    recompute_nodes = set()
    for node in graph.nodes:
        if node.target in no_copy_ops:
            if ds_param_node in node.args:
                recompute_nodes.add(node)
            if any(a in recompute_nodes for a in node.args):
                recompute_nodes.add(node)

    return recompute_nodes
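
# Illustrative example (comment only, not executed), assuming get_no_copy_ops()
# includes view-like ops such as aten.t.default and aten.view.default. For a
# joint graph containing
#
#     gathered = <allgathered DS param placeholder>
#     t_1      = aten.t.default(gathered)
#     view_1   = aten.view.default(t_1, ...)
#     mm_1     = aten.mm.default(x, view_1)
#
# _find_recompute_nodes(graph, gathered) marks {t_1, view_1}: both are no-copy
# ops reachable from the gathered parameter and are cheap to recompute in the
# backward pass, while mm_1 does real work and is left alone.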


def _get_values_from_ds_params(joint_graph, param_indices):
    """
    Map every node that is derived from a DeepSpeed parameter through no-copy
    (view-like) ops back to that parameter's primal input node.
    """
    primal_inputs = list(filter(_is_primal, joint_graph.nodes))
    ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]

    no_copy_ops = get_no_copy_ops()

    ds_param_inputs = set(ds_param_inputs)
    ds_param_users = {}
    for node in joint_graph.nodes:
        if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args):
            for a in node.args:
                if a in ds_param_inputs:
                    ds_param_users[node] = a
                elif a in ds_param_users:
                    ds_param_users[node] = ds_param_users[a]

    return ds_param_users
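
# Illustrative example (comment only, not executed): with a graph like the one
# in the example above and primals_2 being a DS parameter,
#
#     t_1    = aten.t.default(primals_2)
#     view_1 = aten.view.default(t_1, ...)
#
# _get_values_from_ds_params returns {t_1: primals_2, view_1: primals_2}, so a
# caller can swap any saved view of a parameter for the parameter itself.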


def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):

    def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]:
        saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget)
        ds_param_users = _get_values_from_ds_params(joint_graph, param_indices)

        # Replace any saved value that was computed from a DS parameter with the
        # parameter's primal itself (deduplicated), so the derived views are
        # recomputed in the backward pass instead of being kept alive.
        new_saved_values = []
        for v in saved_values:
            if v in ds_param_users:
                ds_val = ds_param_users[v]
                if ds_val not in new_saved_values:
                    new_saved_values.append(ds_val)
            else:
                new_saved_values.append(v)

        return new_saved_values

    return ds_choose_saved_values_set
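
# Usage sketch (comment only, not executed): ds_choose_saved_values_set has the
# same signature as torch._functorch.partitioners.choose_saved_values_set, so
# one way to activate it is to patch that symbol before the min-cut partitioner
# runs, e.g.
#
#     import torch._functorch.partitioners as partitioners
#     partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)
#
# How DeepSpeed's compile pipeline actually installs it is not shown in this
# file; the patching above is only an assumption for illustration.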


def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]):

    def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
                                      num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
        """
        This is basically the same as the default_partition function, but
        it doesn't save the gathered params and values computed from them.
        """
        if has_recomputable_ops(joint_module):
            return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
        inputs = primal_inputs + fwd_seed_offset_inputs

        fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
        forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward")
        forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"}

        saved_values = []
        saved_sym_nodes = []

        fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes))
        ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices]
        ds_param_input_names = {node.name for node in ds_param_inputs}

        ds_param_recompute_nodes = set()

        for node in joint_module.graph.nodes:
            if node.name not in forward_node_names:
                continue

            if is_sym_node(node):
                # Symints must be kept separate from tensors so that PythonFunction only calls
                # save_for_backward on tensors and stashes symints in autograd .ctx
                saved_sym_nodes.append(node)
            elif "tensor_meta" not in node.meta and node.op == "call_function":
                # Since we can't save tuple of tensor values, we need to flatten out what we're saving
                users = node.users
                assert all(user.target == operator.getitem for user in users)
                saved_values.extend(users)
            else:
                backward_usages = [n for n in node.users if n.name not in forward_node_names]

                if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
                    # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                    # and not the actual tensor data,
                    # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                    #
                    # Note that saving the tensor could also cause compilation problems:
                    # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                    # then we would be obligated to clone the input before saving it to appease autograd.
                    # (This is how we originally found this bug).
                    saved_sym_nodes.extend(backward_usages)
                else:
                    if node.name in ds_param_input_names:
                        # DS parameter primals are kept; values computed from them via
                        # no-copy ops are marked for recomputation in the backward pass.
                        saved_values.append(node)

                        recompute_nodes = _find_recompute_nodes(joint_module.graph, node)
                        recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names]
                        for recompute_node in recompute_nodes:
                            ds_param_recompute_nodes.add(recompute_node)

                        if len(recompute_nodes) > 0:
                            saved_values.append(node)
                    else:
                        # Skip values that will be recomputed from a DS parameter.
                        if node not in ds_param_recompute_nodes:
                            saved_values.append(node)

        saved_values = list(dict.fromkeys(saved_values).keys())
        saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())

        f_gm, b_gm = _extract_fwd_bwd_modules(
            joint_module,
            saved_values,
            saved_sym_nodes=saved_sym_nodes,
            num_fwd_outputs=num_fwd_outputs,
        )

        return f_gm, b_gm

    return partition_recompute_ds_params
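
# Usage sketch (comment only, not executed): partition_recompute_ds_params has
# the same interface as torch._functorch.partitioners.default_partition, so the
# factory can be handed to AOTAutograd as a partition_fn, e.g.
#
#     from torch._dynamo.backends.common import aot_autograd
#
#     backend = aot_autograd(fw_compiler=my_fw_compiler,
#                            bw_compiler=my_bw_compiler,
#                            partition_fn=get_wrapped_partitioner(param_indices))
#     compiled = torch.compile(model, backend=backend)
#
# my_fw_compiler, my_bw_compiler, model, and param_indices are placeholders;
# the way DeepSpeed's compile backend wires this in lives outside this file.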