|
|
|
import _operator
import itertools
from collections import defaultdict
from enum import Enum

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map_only
|
|
|
|
|
__all__ = ["reinplace"] |
|
|
|
|
|
class _ViewType(Enum):
    NonView = 0
    SingleOutputView = 1
    MultiOutputView = 2
|
|
|
|
|
def _is_view_op(tgt):
    # An op is a view op if its first argument is annotated as aliasing the
    # output without being written to, e.g. "diagonal(Tensor(a) self, ...)".
    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
        schema = tgt._schema
        if len(schema.arguments) > 0:
            first_arg = schema.arguments[0]
            return (
                first_arg.alias_info is not None and not first_arg.alias_info.is_write
            )
    return False
|
|
|
|
|
def _get_view_type(tgt) -> _ViewType:
    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
        schema = tgt._schema
        if len(schema.arguments) > 0:
            first_arg = schema.arguments[0]
            # Check if the op is a view.
            if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
                # Check if it's a multi-output view, e.g. unbind's
                # "Tensor(a -> *) self" annotation.
                if "*" in first_arg.alias_info.after_set:
                    return _ViewType.MultiOutputView
                else:
                    return _ViewType.SingleOutputView
    return _ViewType.NonView
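
# For example (a sketch; these are standard ATen schema annotations):
#   _get_view_type(torch.ops.aten.diagonal.default)  # SingleOutputView: "Tensor(a) self"
#   _get_view_type(torch.ops.aten.unbind.int)        # MultiOutputView: "Tensor(a -> *) self"
#   _get_view_type(torch.ops.aten.add.Tensor)        # NonView: no alias annotation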
|
|
@compatibility(is_backward_compatible=False)
class _FunctionalizationMetadataProp(torch.fx.Interpreter):
    """
    Runs the graph with fake tensors and stamps functionalization-related
    metadata onto every node:
      - node.meta['fake_result']: the FakeTensor output of running the node
      - node.meta['node_idx']: the node's index in execution order
      - node.meta['view_of']: for view ops, the Node this node is a view of
    """

    def run_node(self, node: Node):
        self.node_counter += 1
        result = super().run_node(node)
        node.meta["fake_result"] = result
        node.meta["node_idx"] = self.node_counter
|
|
|
|
|
|
|
|
|
        # copy_() doesn't read from its first argument; it only overwrites it,
        # so don't treat that argument as "used as an input".
        node_args = node.args
        if node.target is torch.ops.aten.copy_.default:
            node_args = node_args[1:]
|
|
|
|
|
        # Track aliasing info for view tensor nodes. Single-output views
        # record their base directly; multi-output views (e.g. split) are
        # remembered so the getitem calls that extract each view can be
        # mapped back to the base below.
        if node.op == "call_function":
            view_type = _get_view_type(node.target)
            if view_type == _ViewType.SingleOutputView:
                assert isinstance(node.args[0], Node)
                node.meta["view_of"] = node.args[0]
            elif view_type == _ViewType.MultiOutputView:
                self.multi_output_view_nodes[node] = node.args[0]
|
|
            # Check if we're grabbing an individual view out of a
            # multi-output view op. In FX IR this spans two nodes;
            # e.g. "a, b = x.split(2)" becomes:
            #   %split_tensor = call_function[target=aten.split.Tensor](args=(%x, 2))
            #   %getitem = call_function[target=operator.getitem](args=(%split_tensor, 0))
            # and we want %getitem.meta['view_of'] to point at %x.
            elif node.target is _operator.getitem:
                list_arg = node.args[0]
                maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
                if maybe_base_of_view is not None:
                    assert isinstance(maybe_base_of_view, Node)
                    node.meta["view_of"] = maybe_base_of_view
|
|
|
        # Sanity check: a view and its base must share the same storage.
        if "view_of" in node.meta:
            assert isinstance(node.meta["fake_result"], FakeTensor)
            assert isinstance(node.meta["view_of"].meta["fake_result"], FakeTensor)
            view_storage = StorageWeakRef(node.meta["fake_result"]._typed_storage())
            base_storage = StorageWeakRef(
                node.meta["view_of"].meta["fake_result"]._typed_storage()
            )
            assert view_storage == base_storage
|
return result |
|
|
|
    def propagate(self, *args):
        self.multi_output_view_nodes = {}
        self.node_counter = -1

        # Run the graph with fake tensors: we only need metadata
        # (sizes, dtypes, storages), not real compute.
        with FakeTensorMode() as mode:
            fake_args = [
                mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
            ]
            return super().run(*fake_args)
|
|
|
|
|
def _schemas_match(functional_schema, inplace_schema):
    names_match = (
        inplace_schema.name.endswith("_")
        and inplace_schema.name[:-1] == functional_schema.name
    )
    arg_types_match = len(functional_schema.arguments) == len(
        inplace_schema.arguments
    ) and all(
        a1.type == a2.type
        for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)
    )
    # An inplace op should mutate its first argument ("self") and nothing else.
    assert (
        inplace_schema.arguments[0].alias_info is not None
        and inplace_schema.arguments[0].alias_info.is_write
    )
    assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
    return names_match and arg_types_match
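
# For example (a sketch, using ATen's schemas):
#   aten::add.Tensor (Tensor     self, Tensor other, *, Scalar alpha=1) -> Tensor
#   aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
# _schemas_match(torch.ops.aten.add.Tensor._schema,
#                torch.ops.aten.add_.Tensor._schema)  # True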
|
|
def _maybe_get_inplace_op(op):
    # Only ATen OpOverloads can have inplace variants.
    if not isinstance(op, torch._ops.OpOverload):
        return None
    # Re-inplacing a view op is never sound: mutating the view's output in
    # place would mutate its base, which the functional view op doesn't do.
    if _is_view_op(op):
        return None
    op_namespace = op.__module__.split(".")[-1]
    op_base_name = op.overloadpacket.__name__
    maybe_namespace_module = getattr(torch.ops, op_namespace)
    maybe_inplace_op = (
        None
        if maybe_namespace_module is None
        else getattr(maybe_namespace_module, f"{op_base_name}_", None)
    )
    if maybe_inplace_op is None:
        return None

    inplace_overloads = [
        getattr(maybe_inplace_op, overload_name)
        for overload_name in maybe_inplace_op.overloads()
    ]
    inplace_overloads_with_matching_schemas = [
        f for f in inplace_overloads if _schemas_match(op._schema, f._schema)
    ]
    # "foo" having an overload packet named "foo_" doesn't guarantee that the
    # signatures line up; if no inplace overload matches the functional op's
    # schema, bail out.
    if len(inplace_overloads_with_matching_schemas) == 0:
        return None
    assert len(inplace_overloads_with_matching_schemas) == 1
    inplace_op = inplace_overloads_with_matching_schemas[0]
    return inplace_op
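
# For example (a sketch):
#   _maybe_get_inplace_op(torch.ops.aten.add.Tensor)        # torch.ops.aten.add_.Tensor
#   _maybe_get_inplace_op(torch.ops.aten.diagonal.default)  # None (it's a view op)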
|
|
|
|
|
_VIEW_INVERSE_MAP = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
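
# A {view}_scatter op takes (base, mutated_view, view_args...) and returns a
# new base tensor with the viewed region overwritten by mutated_view. When
# re-inplacing lets us mutate the view's buffer directly, the scatter call
# becomes redundant; _get_view_inverse_node_usages below detects such calls so
# that reinplace can delete them (step (3a) in its docstring).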
|
|
|
|
|
|
|
|
|
|
|
def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int):
    nodes_used_after = set()
    for t in tensor_aliases:
        # Get all nodes that use the current alias.
        usage_nodes = t.users
        for n in usage_nodes:
            # We only care about usages after the current node in the graph.
            if "node_idx" not in n.meta or n.meta["node_idx"] <= op_index:
                continue
            # We also don't care about intermediate view ops: they only
            # matter if their output is later used by a normal (non-view)
            # op, or returned as the program output.
            if n in tensor_aliases:
                if (
                    isinstance(n.target, torch._ops.OpOverload)
                    or n.target == _operator.getitem
                ):
                    continue
            nodes_used_after.add(n)
    return nodes_used_after
|
|
# Given an op that we're trying to re-inplace, "b = foo(a)", and a later node
# "y = {view}_scatter(base, x, args...)", re-inplacing foo() lets us delete
# the scatter op entirely if there is some alias of "a" such that:
# (1) the alias's base has the same size/stride/offset metadata as "base", and
# (2) replaying {view}(alias_base, args...) reproduces the alias's metadata.
def _get_view_inverse_node_usages(
    later_node_usages: set[Node], self_aliases: set[Node]
) -> set[Node]:
    def matching_view_metadata(a, b):
        return (
            a.size() == b.size()
            and a.stride() == b.stride()
            and a.storage_offset() == b.storage_offset()
        )

    view_inverse_nodes = set()
    # Walk the usages in node order so we can see chains of view_scatter ops.
    for n in sorted(later_node_usages, key=lambda x: x.meta["node_idx"]):
        if n.target not in _VIEW_INVERSE_MAP:
            continue
        base = n.args[0]
        mutated_view = n.args[1]
        assert isinstance(base, Node)
        assert isinstance(base.meta["fake_result"], FakeTensor)
        assert isinstance(mutated_view, Node)
        assert isinstance(mutated_view.meta["fake_result"], FakeTensor)
        # Check whether replaying the corresponding view op over some alias
        # of "self" reproduces this scatter op's view metadata.
        original_view = _VIEW_INVERSE_MAP[n.target]
        for self_alias in self_aliases:
            # We're looking for an alias of the self arg that was itself
            # generated by a view; the 'view_of' metadata tells us its base.
            if "view_of" not in self_alias.meta:
                continue
            self_alias_base = self_alias.meta["view_of"]
            try:
                # Replaying the view with the scatter call's args can throw;
                # that just means this scatter op isn't a valid inverse of
                # the alias we're currently looking at.
                view_replay_metadata = original_view(
                    self_alias_base.meta["fake_result"], *n.args[2:], **n.kwargs
                )
                expected_metadata = self_alias.meta["fake_result"]
                # If the base and the replayed view both have matching
                # metadata, this view_scatter op is safe to eliminate.
                if matching_view_metadata(
                    self_alias_base.meta["fake_result"], base.meta["fake_result"]
                ) and matching_view_metadata(view_replay_metadata, expected_metadata):
                    view_inverse_nodes.add(n)
            except Exception:
                continue

    return view_inverse_nodes
|
|
|
|
|
@compatibility(is_backward_compatible=True) |
|
def reinplace(gm, *sample_args): |
|
""" |
|
Given an fx.GraphModule, modifies it to perform "reinplacing", |
|
mutating the nodes of the graph. |
|
We look for out-of-place op call sites like `b = a.add(...)`, |
|
and convert them to be inplace (`b = a.add_(...)`), |
|
as long as the input to the current operator ("a") isn't re-used |
|
anywhere later in the graph. |
|
|
|
This pass currently expects to operate on a **functional, ATen** graph. |
|
This can be obtained by running `make_fx(functionalize(f))`. |
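
    For example, a sketch of typical usage (assuming `make_fx` from

    torch.fx.experimental.proxy_tensor and `functionalize` from torch.func):

        def f(x):

            a = x.clone()

            return a.add(1)

        gm = make_fx(functionalize(f))(torch.ones(2))

        gm = reinplace(gm, torch.ones(2))  # add() becomes add_() on a's buffer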
|
|
|
Sample inputs are needed to determine aliasing relationships of the inputs. |
|
In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the |
|
inputs to the program. |
|
|
|
    Given a node "b = foo(a, args...)", the algorithm for re-inplacing is as follows:
|
|
|
(1) Perform some initial checks on the metadata of "a" and "args..." |
|
that can disqualify them from being reinplaced. |
|
|
|
(1a) Check that the self argument we're attempting to reinplace |
|
has acceptable dtype/size metadata to reinplace with. |
|
|
|
For example, if we have: |
|
a = torch.ones(1) |
|
b = torch.ones(10) |
|
out = torch.add(a, b) |
|
We can't turn that into |
|
a.add_(b) |
|
Because that would require resizing "a". |
|
|
|
Similarly, we can't convert torch.ge(a, b) into a.ge_(b), |
|
because that would require changing a's dtype (from e.g. float32 to bool). |
|
             Note that in this specific example, we could technically do better...
|
|
|
If we see the pattern: |
|
a_1 = a.ge(b) |
|
a_2 = aten._to_copy(a_1, a.dtype) |
|
             Then it should be valid to completely re-inplace
|
(this is exactly what functionalization will emit when it sees a.ge_(b)). |
|
|
|
This optimization is only really important for user programs |
|
that directly use inplace comparison ops though. |
|
|
|
We also cannot re-inplace on tensors that have overlapping memory, |
|
e.g. torch.ones(1).expand(4, 4).add_(1) |
|
|
|
(1b) Check if "a" is an alias of any of the program inputs. |
|
|
|
If it is, skip and move to the next node. |
|
             Inplace'ing an op that would cause it to mutate a program input is not sound,
|
because that would be a side effect visible to the user. |
|
|
|
NOTE: there's a future optimization that we should make: |
|
             if "a" is (an alias of) a program input, but later in the program
|
there is a node that looks like "a.copy_(...)", |
|
             then re-inplacing is ok to do - we are temporarily re-using a's buffer,
|
which will later be overwritten by the copy_() call. |
|
|
|
This will be an important optimization to have for programs that mutate |
|
their inputs. It currently isn't implemented though. |
|
|
|
(1c) Check if "a" and "args..." alias |
|
|
|
For example, re-inplacing to create code like the below |
|
isn't guaranteed to be sound: |
|
|
|
aten.mul_(a, a) |
|
|
|
(2) Check that "a" and all of its outstanding aliases are not used anywhere |
|
later in the graph. If this is the case, then it's safe to re-inplace |
|
to "b = foo_(a)". |
|
|
|
There are a few caveats to this, explained in more detail below: |
|
(a) If "a" is used later as an argument to a view op, that is okay. |
|
It's only a problem if "a" (or that view) is later passed |
|
into a normal operator, or if it is returned as the program output. |
|
(b) If "a" is a repeat argument in `foo()`, then don't reinplace. |
|
Most ATen kernels don't make any guarantees that this is sound, |
|
e.g. if you do aten.mul_(a, a). |
|
So we'll just ban re-inplacing in this case. |
|
|
(c) If "a" is used as an input into a view "inverse" / "scatter" |
|
operator, it is potentially fine to re-inplace |
|
(and remove that scatter operator from the graph). |
|
See below for a more detailed example. |
|
|
|
NOTE: there is an optimization in this step that is crucial |
|
to fully recovering performance from functionalization. |
|
|
|
Given this program: |
|
def f(x): |
|
a = torch.ops.aten.add(x, x) |
|
b = torch.ops.aten.diagonal(a) |
|
torch.ops.aten.fill_(b, 0) |
|
                return a
|
|
|
Functionalization will emit the following: |
|
def f(x): |
|
a = torch.ops.aten.add(x, x) |
|
b = torch.ops.aten.diagonal(a, 0, 1) |
|
b_updated = torch.ops.aten.fill(b, 0) |
|
a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) |
|
return a_updated |
|
|
|
Ordinarily, we would not be able to reinplace the fill, |
|
because "b" aliases with "a" which is used by the diagonal_scatter call. |
|
|
|
"re-inplacing" is on the hook for figuring out that it is ok to |
|
        completely remove the expensive diagonal_scatter call, if we re-inplace the add().
|
|
|
So, for every `alias in alias_set(a)`, instead of checking |
|
that "alias" is not used anywhere later in the graph, |
|
we check that |
|
EITHER: |
|
(a) alias is not used anywhere later in the graph |
|
OR: |
|
(b) alias is used exactly once later on in the graph, |
|
in the following op: |
|
|
|
out = foo_scatter(alias, x, args...) |
|
|
|
where the following must hold: |
|
(i) "foo_scatter" is the "inverse" operator for foo. |
|
This only applies to "foo" ops that are view operators, |
|
which view into a subset of the original tensor's memory. |
|
In practice, there are ~4 operators where this applies: |
|
diagonal -> diagonal_scatter |
|
slice -> slice_scatter |
|
select -> select_scatter |
|
as_strided -> as_strided_scatter |
|
(ii) "args..." are the same between the foo() and foo_scatter() calls. |
|
|
|
(3) Perform the actual re-inplacing on foo! |
|
|
|
(3b) is the common case, but special care is needed for {view}_scatter (3a) |
|
|
|
(3a) {view}_scatter ops. |
|
|
|
Consider this program: |
|
a = torch.zeros(2, 2) |
|
b = torch.ones(2) |
|
a[0] = b |
|
|
|
Post functionalization, that will look like: |
|
            a = torch.zeros(2, 2)

            b = torch.ones(2)
|
a_updated = torch.select_scatter(a, b, 0, 0) |
|
|
|
In this case though, there is no "functional" op to re-inplace! |
|
        Instead, we'd like to directly remove the select_scatter call.
|
We already know from (3) that this is valid, |
|
because "a" has no later usages in the graph. |
|
|
|
We perform the re-inplacing on the {view}_scatter op like so |
|
Before: |
|
a_updated = torch.select_scatter(a, b, args...) |
|
After: |
|
            a_slice = a.select(args...)
|
a_slice.copy_(b) |
|
|
|
(3b) Otherwise, replace the functional op with its inplace variant. |
|
Before: |
|
b = foo(a, args...) |
|
After: |
|
a.foo_(args...) |
|
|
|
(4) Finally, after converting either: |
|
Before: |
|
b = foo(a) |
|
After: |
|
foo_(a) |
|
or |
|
Before: |
|
b = {slice}_scatter(a, mutated_slice, args...) |
|
After: |
|
slice = {slice}(a, args...) |
|
slice.copy_(mutated_slice) |
|
|
|
We now need to find all later nodes that use "b" as an argument |
|
and update them to take in "a" instead. |
|
|
|
Note that for the majority of inplace ops, this isn't actually necessary |
|
(because most inplace ops return "self" as their output). |
|
This isn't generally true for all mutable ops though, which is why |
|
we need to actually replace all of the arguments. |
|
|
|
        We also need to update our metadata Dict[StorageWeakRef, Set[Node]],

        which maps a given tensor storage to the set of all nodes that take in that storage
|
as an input. |
|
Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused |
|
together. |
|
|
|
(5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" |
|
during step (3) get manually deleted from the graph. |
|
Their outputs are no longer used, so technically standard DCE would be able |
|
to do this, but we can no longer run FX's DCE pass now that we have mutable |
|
ops in the graph. |
|
""" |
|
_FunctionalizationMetadataProp(gm).propagate(*sample_args) |
|
|
    # Record the storages of all tensor inputs to the program. We can't
    # re-inplace an op whose "self" argument aliases a program input, since
    # the mutation would be visible to the caller.
    input_storages = {
        StorageWeakRef(node.meta["fake_result"]._typed_storage())
        for node in gm.graph.nodes
        if (
            node.op == "placeholder"
            and isinstance(node.meta["fake_result"], torch.Tensor)
        )
    }
|
|
|
|
|
    # We also need to know, for every node, which other nodes alias it
    # (i.e. share its storage).
    storage_to_nodes: dict[StorageWeakRef, set[Node]] = defaultdict(set)
    for n in gm.graph.nodes:
        if "fake_result" in n.meta:
            # Tree-map in case the node's output is a tensor list (e.g. split).
            def _add_to_map(x):
                if isinstance(x, FakeTensor):
                    storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)

            pytree.tree_map_(_add_to_map, n.meta["fake_result"])
|
|
|
|
|
    # The pass currently acts on two kinds of ops:
    # - functional ops that have an inplace variant
    # - {view}_scatter ops that can potentially be removed from the graph
    all_later_view_inverse_nodes_to_delete = set()
    for node in gm.graph.nodes:
        if node.op == "call_function":
            # Step 1: check whether this op is eligible for re-inplacing.
            # It must be an ATen OpOverload whose first argument ("self",
            # the buffer we would write into) is a Tensor.
            if not isinstance(node.target, torch._ops.OpOverload):
                continue
            if len(node.target._schema.arguments) < 1:
                continue
            if type(node.target._schema.arguments[0].type) != torch.TensorType:
                continue
|
|
            # Step 1a: "self" must have metadata we can safely write into:
            # the same number of elements and dtype as the output, and no
            # internal memory overlap (e.g. from expand()).
            self_arg = node.args[0]
            self_flattened = pytree.tree_leaves(self_arg.meta["fake_result"])
            node_flattened = pytree.tree_leaves(node.meta["fake_result"])
            self_has_wrong_metadata = False
            if len(self_flattened) == len(node_flattened):
                for self_meta, node_meta in zip(self_flattened, node_flattened):
                    if self_meta.numel() != node_meta.numel():
                        self_has_wrong_metadata = True
                    if self_meta.dtype != node_meta.dtype:
                        self_has_wrong_metadata = True
                    if torch._debug_has_internal_overlap(self_meta) == 1:
                        self_has_wrong_metadata = True
            # resize.default is exempt: resizing "self" is the whole point.
            if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
                continue
|
|
|
|
|
            # Step 1b: skip if "self" aliases a program input; mutating it
            # would be a side effect visible to the caller.
            self_arg_storage = StorageWeakRef(
                self_arg.meta["fake_result"]._typed_storage()
            )
            if self_arg_storage in input_storages:
                continue
            # Step 1c: skip if "self" is passed to this op more than once
            # (e.g. aten.mul(a, a)); most ATen kernels don't guarantee that
            # re-inplacing is sound in that case.
            if len([x for x in node.args if x is self_arg]) > 1:
                continue

            self_aliases = storage_to_nodes[self_arg_storage]

            # Step 2: find all nodes that use "self" (or any of its aliases)
            # later in the graph.
            later_node_usages = _get_all_later_node_usages(
                self_aliases, node.meta["node_idx"]
            )
|
|
|
|
|
            # Of those later usages, find the {view}_scatter ops that we can
            # eliminate (see _get_view_inverse_node_usages above).
            later_view_inverse_node_usages = _get_view_inverse_node_usages(
                later_node_usages, self_aliases
            )

            # We can re-inplace only if "self" and all of its aliases have no
            # later usages, other than those eliminable scatter ops.
            can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
            if not can_reinplace:
                continue
|
|
            # Step 3: perform the re-inplacing.
            # Step 3a: special handling for {view}_scatter ops. There is no
            # functional op to swap for an inplace variant; instead we
            # re-materialize the view and copy_ the mutated slice into it:
            #   Before: a_updated = {view}_scatter(a, mutated_slice, args...)
            #   After:  slice = {view}(a, args...); slice.copy_(mutated_slice)
            if (
                node.target in _VIEW_INVERSE_MAP
                and node not in all_later_view_inverse_nodes_to_delete
            ):
                view_op = _VIEW_INVERSE_MAP[node.target]
                with gm.graph.inserting_before(node):
                    mutated_slice_node = node.args[1]
                    remaining_slice_args = node.args[2:]
                    slice_node = gm.graph.create_node(
                        "call_function",
                        view_op,
                        (self_arg,) + tuple(remaining_slice_args),
                        node.kwargs,
                    )
                    gm.graph.create_node(
                        "call_function",
                        torch.ops.aten.copy_.default,
                        (
                            slice_node,
                            mutated_slice_node,
                        ),
                        {},
                    )

                # The scatter node itself is now dead; mark it for deletion.
                all_later_view_inverse_nodes_to_delete.add(node)

            else:
                # Step 3b: replace the functional op with its inplace variant,
                # if one with a matching schema exists.
                maybe_inplace_op = _maybe_get_inplace_op(node.target)
                if maybe_inplace_op is None:
                    continue

                node.target = maybe_inplace_op
|
|
            # "node"'s output now lives in "self"'s storage, so the two alias
            # sets in storage_to_nodes need to be fused together.
            curr_node_storage = StorageWeakRef(
                node.meta["fake_result"]._typed_storage()
            )
            storage_to_nodes[self_arg_storage].update(
                storage_to_nodes[curr_node_storage]
            )
            storage_to_nodes[curr_node_storage].update(
                storage_to_nodes[self_arg_storage]
            )

            # Mark the scatter ops we decided to eliminate for deletion at the
            # end of the pass (step 5).
            all_later_view_inverse_nodes_to_delete.update(
                later_view_inverse_node_usages
            )
|
|
|
|
|
|
|
|
|
            # Step 4: rewire later usages. Any node that used "old" (the
            # functional op's output, or an eliminated scatter op's output)
            # as an argument should now take "new" ("self") instead.
            for old in itertools.chain([node], later_view_inverse_node_usages):
                new = old.args[0]
                nodes_to_update = [
                    n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"]
                ]
                for node_to_update in nodes_to_update:

                    def replace_arg(a):
                        if a == old:
                            return new
                        return a

                    # Replace "old" with "new" in both args and kwargs.
                    node_to_update.args = tree_map_only(
                        Node, replace_arg, node_to_update.args
                    )
                    node_to_update.kwargs = tree_map_only(
                        Node, replace_arg, node_to_update.kwargs
                    )
|
|
|
|
|
                    # Keep storage_to_nodes up to date for the rewired node.
                    old_flattened_res = pytree.tree_leaves(old.meta["fake_result"])
                    node_flattened_res = pytree.tree_leaves(
                        node_to_update.meta["fake_result"]
                    )

                    old_res_storage = {
                        StorageWeakRef(x._typed_storage())
                        for x in old_flattened_res
                        if isinstance(x, FakeTensor)
                    }
                    node_res_storage = {
                        StorageWeakRef(x._typed_storage())
                        for x in node_flattened_res
                        if isinstance(x, FakeTensor)
                    }
|
|
                    # If "node_to_update" shares "old"'s storage (e.g. it is a
                    # view of old), then after the rewrite it aliases "new"
                    # instead, so fuse its alias set with new's.
                    if (
                        len(old_res_storage) == 1
                        and len(node_res_storage) == 1
                        and old_res_storage == node_res_storage
                    ):
                        new_flattened_res = pytree.tree_leaves(new.meta["fake_result"])
                        new_res_storage = {
                            StorageWeakRef(x._typed_storage())
                            for x in new_flattened_res
                            if isinstance(x, FakeTensor)
                        }
                        assert len(new_res_storage) == 1
                        (new_ref,) = new_res_storage
                        (node_ref,) = node_res_storage
                        storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
                        storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])
|
|
|
|
|
|
|
|
|
    # Step 5: delete the {view}_scatter nodes we eliminated. Their outputs are
    # no longer used, but FX's DCE can't run now that the graph contains
    # mutable ops, so erase them manually.
    for to_delete in all_later_view_inverse_nodes_to_delete:
        gm.graph.erase_node(to_delete)
|
|
|
gm.recompile() |
|
return gm |
|
|