|
|
|
import builtins |
|
import contextlib |
|
import copy |
|
import enum |
|
import functools |
|
import inspect |
|
import keyword |
|
import math |
|
import os |
|
import re |
|
import typing |
|
import warnings |
|
from collections import defaultdict |
|
from collections.abc import Iterable |
|
from contextlib import contextmanager |
|
from dataclasses import dataclass |
|
from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING |
|
|
|
import torch |
|
import torch.utils._pytree as pytree |
|
from torch._C import _fx_map_arg as map_arg, _NodeIter |
|
|
|
from . import _pytree as fx_pytree |
|
from ._compatibility import compatibility |
|
from .immutable_collections import immutable_dict |
|
from .node import _get_qualified_name, _type_repr, Argument, Node, Target |
|
|
|
|
|
__all__ = ["PythonCode", "CodeGen", "Graph"] |
|
|
|
if TYPE_CHECKING: |
|
from ._symbolic_trace import Tracer |
|
from .graph_module import GraphModule |
|
|
|
|
|
|
|
|
|
_origin_type_map = { |
|
list: typing.List, |
|
dict: typing.Dict, |
|
set: typing.Set, |
|
frozenset: typing.FrozenSet, |
|
tuple: typing.Tuple, |
|
} |
|
|
|
_legal_ops = dict.fromkeys( |
|
["call_function", "call_method", "get_attr", "call_module", "placeholder", "output"] |
|
) |
|
|
|
|
|
|
|
|
|
TransformCodeFunc = Callable[[list[str]], list[str]] |
|
|
|
|
|
class _CustomBuiltin(NamedTuple): |
|
"""Additional objs that we add to every graph's globals. |
|
|
|
The repr() for some standard library objects is not valid Python code without |
|
an import. For common objects of this sort, we bundle them in the globals of |
|
every FX graph. |
|
""" |
|
|
|
|
|
import_str: str |
|
|
|
obj: Any |
|
|
|
|
|
|
|
_illegal_names = {k: object() for k in keyword.kwlist} |
|
_illegal_names.update(builtins.__dict__) |
|
|
|
_custom_builtins: dict[str, _CustomBuiltin] = {} |
|
|
|
|
|
def _register_custom_builtin(name: str, import_str: str, obj: Any): |
|
_custom_builtins[name] = _CustomBuiltin(import_str, obj) |
|
_illegal_names[name] = obj |
|
|
|
|
|
_register_custom_builtin("inf", "from math import inf", math.inf) |
|
_register_custom_builtin("nan", "from math import nan", math.nan) |
|
_register_custom_builtin("NoneType", "NoneType = type(None)", type(None)) |
|
_register_custom_builtin("torch", "import torch", torch) |
|
_register_custom_builtin("device", "from torch import device", torch.device) |
|
_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree) |
|
_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree) |
|
|
|
|
|
def _is_magic(x: str) -> bool: |
|
return x.startswith("__") and x.endswith("__") |
|
|
|
|
|
def _snake_case(s: str) -> str: |
|
""" |
|
Transforms the given string ``s`` to a Python-style variable name |
|
|
|
Examples: |
|
``mod.snake_case`` -> ``mod.snake_case`` |
|
``mod.pascalCase``-> ``mod.pascal_case`` |
|
``mod.ALL_CAPS`` -> ``mod.all_caps`` |
|
""" |
|
return _snake_case_sub(s).lower() |
|
|
|
|
|
|
|
_snake_case_sub = functools.partial(re.compile(r"(?<=[a-z])([A-Z])").sub, r"_\1") |
|
|
|
|
|
_illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") |
|
|
|
|
|
|
|
|
|
|
|
|
|
_name_regex = re.compile(r"^([a-zA-Z_][0-9a-zA-Z_]*?)(?:_(\d+))?$") |
|
|
|
|
|
_torch_but_not_dynamo = re.compile( |
|
r"^torch(?:\.(?!_dynamo\.|_inductor\.)[^.]+)*$" |
|
).fullmatch |
|
|
|
|
|
def _is_from_torch(obj: Any) -> bool: |
|
module_name = getattr(obj, "__module__", None) |
|
if module_name is not None: |
|
return _torch_but_not_dynamo(module_name) is not None |
|
|
|
name = getattr(obj, "__name__", None) |
|
|
|
if name is not None and name != "torch": |
|
for guess in [torch, torch.nn.functional]: |
|
if getattr(guess, name, None) is obj: |
|
return True |
|
|
|
return False |
|
|
|
|
|
class _Namespace: |
|
"""A context for associating names uniquely with objects. |
|
|
|
The following invariants are enforced: |
|
- Each object gets a single name. |
|
- Each name is unique within a given namespace. |
|
- Names generated do not shadow builtins, unless the object is indeed that builtin. |
|
""" |
|
|
|
def __init__(self): |
|
self._obj_to_name: dict[Any, str] = {} |
|
self._used_names: set[str] = set() |
|
self._base_count: dict[str, int] = {} |
|
|
|
def create_name(self, candidate: str, obj: Optional[Any]) -> str: |
|
"""Create a unique name. |
|
|
|
Arguments: |
|
candidate: used as the basis for the unique name, relevant to the user. |
|
obj: If not None, an object that will be associated with the unique name. |
|
""" |
|
if obj is not None and obj in self._obj_to_name: |
|
return self._obj_to_name[obj] |
|
|
|
|
|
match = _name_regex.match(candidate) |
|
if match is None: |
|
|
|
candidate = _illegal_char_regex.sub("_", candidate) |
|
|
|
if not candidate: |
|
candidate = "_unnamed" |
|
|
|
if candidate[0].isdigit(): |
|
candidate = f"_{candidate}" |
|
|
|
match = _name_regex.match(candidate) |
|
assert match is not None |
|
|
|
base, num = match.group(1, 2) |
|
if num is None or candidate in self._used_names: |
|
num = self._base_count.get(candidate, 0) |
|
if _illegal_names.get(candidate, obj) is not obj: |
|
num += 1 |
|
candidate = f"{base}_{num}" |
|
|
|
else: |
|
num = int(num) |
|
|
|
while candidate in self._used_names: |
|
num += 1 |
|
candidate = f"{base}_{num}" |
|
|
|
self._used_names.add(candidate) |
|
self._base_count[base] = num |
|
if obj is not None: |
|
self._obj_to_name[obj] = candidate |
|
return candidate |
|
|
|
def associate_name_with_obj(self, name: str, obj: Any): |
|
"""Associate a unique name with an object. |
|
|
|
Neither `name` nor `obj` should be associated already. |
|
""" |
|
maybe_existing = self._obj_to_name.setdefault(obj, name) |
|
assert maybe_existing is name, "obj is already associated" |
|
|
|
def _rename_object(self, obj: Any, name: str): |
|
assert obj in self._obj_to_name |
|
self._obj_to_name[obj] = name |
|
self._used_names.add(name) |
|
|
|
|
|
dtype_abbrs = { |
|
torch.bfloat16: "bf16", |
|
torch.float64: "f64", |
|
torch.float32: "f32", |
|
torch.float16: "f16", |
|
torch.float8_e4m3fn: "f8e4m3fn", |
|
torch.float8_e5m2: "f8e5m2", |
|
torch.float8_e4m3fnuz: "f8e4m3fnuz", |
|
torch.float8_e5m2fnuz: "f8e5m2fnuz", |
|
torch.float8_e8m0fnu: "f8e8m0fnu", |
|
torch.complex32: "c32", |
|
torch.complex64: "c64", |
|
torch.complex128: "c128", |
|
torch.int8: "i8", |
|
torch.int16: "i16", |
|
torch.int32: "i32", |
|
torch.int64: "i64", |
|
torch.bool: "b8", |
|
torch.uint8: "u8", |
|
torch.uint16: "u16", |
|
torch.uint32: "u32", |
|
torch.uint64: "u64", |
|
torch.bits16: "b16", |
|
torch.bits1x8: "b1x8", |
|
} |
|
|
|
|
|
@compatibility(is_backward_compatible=True) |
|
@dataclass |
|
class PythonCode: |
|
""" |
|
Represents all the information necessary to exec or save a graph as Python code. |
|
""" |
|
|
|
|
|
src: str |
|
|
|
globals: dict[str, Any] |
|
|
|
|
|
_lineno_map: Optional[dict[int, Optional[int]]] |
|
|
|
|
|
def _format_target(base: str, target: str) -> str: |
|
elems = target.split(".") |
|
r = base |
|
for e in elems: |
|
if not e.isidentifier(): |
|
r = f'getattr({r}, "{e}")' |
|
else: |
|
r = f"{r}.{e}" |
|
return r |
|
|
|
|
|
class _InsertPoint: |
|
def __init__(self, graph, new_insert): |
|
self.graph = graph |
|
self.orig_insert, graph._insert = graph._insert, new_insert |
|
|
|
def __enter__(self): |
|
pass |
|
|
|
def __exit__(self, type, value, tb): |
|
self.graph._insert = self.orig_insert |
|
|
|
|
|
class _node_list: |
|
def __init__(self, graph: "Graph", direction: Literal["_prev", "_next"] = "_next"): |
|
assert direction in ("_next", "_prev") |
|
self.graph = graph |
|
self.direction = direction |
|
|
|
def __len__(self): |
|
return self.graph._len |
|
|
|
def __iter__(self): |
|
return _NodeIter(self.graph._root, self.direction == "_prev") |
|
|
|
def __reversed__(self): |
|
return _node_list(self.graph, "_next" if self.direction == "_prev" else "_prev") |
|
|
|
|
|
class _PyTreeInfo(NamedTuple): |
|
""" |
|
Contains extra info stored when we're using Pytrees |
|
""" |
|
|
|
orig_args: list[str] |
|
in_spec: pytree.TreeSpec |
|
out_spec: Optional[pytree.TreeSpec] |
|
|
|
|
|
@dataclass(frozen=True) |
|
class _ParsedStackTrace: |
|
""" |
|
Represents the top-most frame of a parsed stack trace |
|
""" |
|
|
|
file: str |
|
lineno: str |
|
name: str |
|
code: str |
|
|
|
def get_summary_str(self): |
|
return f"File: {self.file}:{self.lineno} in {self.name}, code: {self.code}" |
|
|
|
|
|
|
|
def _parse_stack_trace(stack_trace: str): |
|
if stack_trace is None: |
|
return None |
|
pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") |
|
lines = stack_trace.strip().split("\n") |
|
|
|
|
|
|
|
for idx in range(len(lines) - 2, -1, -1): |
|
line = lines[idx].strip() |
|
matches = pattern.match(line) |
|
if matches: |
|
file = matches.group(1) |
|
lineno = matches.group(2) |
|
name = matches.group(3) |
|
|
|
code = lines[idx + 1].strip() |
|
return _ParsedStackTrace(file, lineno, name, code) |
|
return None |
|
|
|
|
|
@compatibility(is_backward_compatible=False) |
|
class CodeGen: |
|
def __init__(self): |
|
self._body_transformer: Optional[TransformCodeFunc] = None |
|
self._func_name: str = "forward" |
|
|
|
def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str: |
|
""" |
|
Given the free variables and a return annotation, generates the beginning of the FX function. |
|
By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` |
|
""" |
|
|
|
|
|
if len(free_vars) == 0 or free_vars[0] != "self": |
|
free_vars.insert(0, "self") |
|
return ( |
|
f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" |
|
) |
|
|
|
def generate_output(self, output_args: Argument) -> str: |
|
""" |
|
Given the output arguments, generates the return statement of the FX function. |
|
Note: The returned statement should not be indented. |
|
""" |
|
return f"return {repr(output_args)}" |
|
|
|
def process_inputs(self, *args: Any) -> Any: |
|
""" |
|
Transforms the inputs so that the graph can take them as arguments, as |
|
non-default codegen may result in the inputs to the function being |
|
different from the inputs to the graph. |
|
|
|
If the graph was directly runnable, this invariant should hold true |
|
`f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` |
|
""" |
|
return args |
|
|
|
def process_outputs(self, outputs: Any) -> Any: |
|
""" |
|
Transforms the outputs of the graph to be identical to the codegen. |
|
|
|
See ``process_inputs`` for more details. |
|
""" |
|
return outputs |
|
|
|
def additional_globals(self) -> list[tuple[str, Any]]: |
|
""" |
|
If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. |
|
For example, return ['List', typing.List] if you need ``List`` in the global context. |
|
""" |
|
return [] |
|
|
|
def _gen_python_code( |
|
self, |
|
nodes, |
|
root_module: str, |
|
namespace: _Namespace, |
|
*, |
|
verbose: bool = False, |
|
include_stride: bool = False, |
|
include_device: bool = False, |
|
colored: bool = False, |
|
) -> PythonCode: |
|
free_vars: list[str] = [] |
|
body: list[str] = [] |
|
globals_: dict[str, Any] = {} |
|
wrapped_fns: dict[str, None] = {} |
|
|
|
|
|
maybe_return_annotation: list[str] = [""] |
|
include_stride = include_stride or ( |
|
os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1" |
|
) |
|
include_device = include_device or ( |
|
os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1" |
|
) |
|
|
|
def add_global(name_hint: str, obj: Any): |
|
"""Add an obj to be tracked as a global. |
|
|
|
We call this for names that reference objects external to the |
|
Graph, like functions or types. |
|
|
|
Returns: the global name that should be used to reference 'obj' in generated source. |
|
""" |
|
if ( |
|
_is_from_torch(obj) and obj != torch.device |
|
): |
|
|
|
|
|
|
|
return _get_qualified_name(obj) |
|
|
|
|
|
global_name = namespace.create_name(name_hint, obj) |
|
|
|
if global_name in globals_: |
|
assert globals_[global_name] is obj |
|
return global_name |
|
globals_[global_name] = obj |
|
return global_name |
|
|
|
|
|
for name, (_, obj) in _custom_builtins.items(): |
|
add_global(name, obj) |
|
|
|
def type_repr(o: Any): |
|
if o == (): |
|
|
|
return "()" |
|
|
|
typename = _type_repr(o) |
|
|
|
if origin_type := getattr(o, "__origin__", None): |
|
|
|
|
|
if isinstance(o, typing._GenericAlias): |
|
|
|
origin_type = _origin_type_map.get(origin_type, origin_type) |
|
|
|
origin_typename = add_global(_type_repr(origin_type), origin_type) |
|
|
|
if hasattr(o, "__args__"): |
|
|
|
args = [type_repr(arg) for arg in o.__args__] |
|
|
|
if len(args) == 0: |
|
|
|
|
|
return origin_typename |
|
|
|
return f'{origin_typename}[{",".join(args)}]' |
|
else: |
|
|
|
|
|
return origin_typename |
|
|
|
|
|
return add_global(typename, o) |
|
|
|
if colored: |
|
red = _color_fns["red"] |
|
dim_green = _color_fns["dim_green"] |
|
dim = _color_fns["dim"] |
|
dim_blue = _color_fns["dim_blue"] |
|
blue = _color_fns["blue"] |
|
else: |
|
red = _identity |
|
dim_green = _identity |
|
dim = _identity |
|
dim_blue = _identity |
|
blue = _identity |
|
|
|
def _get_repr(arg: Any) -> str: |
|
if isinstance(arg, Node): |
|
return repr(arg) |
|
elif isinstance(arg, tuple) and hasattr(arg, "_fields"): |
|
|
|
qualified_name = _get_qualified_name(type(arg)) |
|
global_name = add_global(qualified_name, type(arg)) |
|
return f"{global_name}{repr(tuple(arg))}" |
|
elif isinstance( |
|
arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) |
|
): |
|
qualified_name = _get_qualified_name(arg) |
|
global_name = add_global(qualified_name, arg) |
|
return f"{global_name}" |
|
elif isinstance(arg, enum.Enum): |
|
cls = arg.__class__ |
|
clsname = add_global(cls.__name__, cls) |
|
return f"{clsname}.{arg.name}" |
|
elif isinstance(arg, torch.Tensor): |
|
size = list(arg.size()) |
|
dtype = str(arg.dtype).split(".")[-1] |
|
return f"torch.Tensor(size={size}, dtype={dtype})" |
|
elif isinstance(arg, tuple): |
|
if len(arg) == 1: |
|
return f"({_get_repr(arg[0])},)" |
|
else: |
|
return "(" + ", ".join(_get_repr(a) for a in arg) + ")" |
|
elif isinstance(arg, list): |
|
return "[" + ", ".join(_get_repr(a) for a in arg) + "]" |
|
elif isinstance(arg, slice): |
|
return f"slice({_get_repr(arg.start)}, {_get_repr(arg.stop)}, {_get_repr(arg.step)})" |
|
else: |
|
return blue(repr(arg)) |
|
|
|
def _format_args( |
|
args: tuple[Argument, ...], kwargs: dict[str, Argument] |
|
) -> str: |
|
res = [_get_repr(a) for a in args] |
|
res.extend([f"{k} = {_get_repr(v)}" for k, v in kwargs.items()]) |
|
return ", ".join(res) |
|
|
|
|
|
|
|
|
|
|
|
node_to_last_use: dict[Node, Node] = {} |
|
user_to_last_uses: dict[Node, list[Node]] = {} |
|
|
|
def register_last_uses(n: Node, user: Node): |
|
if n not in node_to_last_use: |
|
node_to_last_use[n] = user |
|
user_to_last_uses.setdefault(user, []).append(n) |
|
|
|
for node in reversed(nodes): |
|
for input_node in node._input_nodes: |
|
register_last_uses(input_node, node) |
|
|
|
def delete_unused_values(user: Node): |
|
""" |
|
Delete values after their last use. This ensures that values that are |
|
not used in the remainder of the code are freed and the memory usage |
|
of the code is optimal. |
|
""" |
|
if user.op == "placeholder": |
|
return |
|
if user.op == "output": |
|
body.append("\n") |
|
return |
|
nodes_to_delete = user_to_last_uses.get(user, []) |
|
|
|
if len(user.users.keys()) == 0: |
|
|
|
|
|
|
|
nodes_to_delete.append(user) |
|
|
|
if len(nodes_to_delete): |
|
to_delete_str = " = ".join( |
|
[repr(n) for n in nodes_to_delete] + ["None"] |
|
) |
|
body.append(f"; {dim(to_delete_str)}\n") |
|
else: |
|
body.append("\n") |
|
|
|
prev_stacktrace = None |
|
|
|
def append_stacktrace_summary(node: Node): |
|
""" |
|
Append a summary of the stacktrace to the generated code. This is |
|
useful for debugging. |
|
""" |
|
nonlocal prev_stacktrace |
|
|
|
if node.op not in {"placeholder", "output"}: |
|
stack_trace = node.stack_trace |
|
if stack_trace: |
|
if stack_trace != prev_stacktrace: |
|
prev_stacktrace = stack_trace |
|
if parsed_stack_trace := _parse_stack_trace(stack_trace): |
|
summary_str = parsed_stack_trace.get_summary_str() |
|
else: |
|
summary_str = "" |
|
body.append(f'\n {dim(f"# {summary_str}")}\n') |
|
elif prev_stacktrace != "": |
|
prev_stacktrace = "" |
|
no_stacktrace_msg = "# No stacktrace found for following nodes" |
|
body.append(f"\n{dim(no_stacktrace_msg)}\n") |
|
|
|
def stringify_shape(shape: Iterable) -> str: |
|
return f"[{', '.join([str(x) for x in shape])}]" |
|
|
|
def emit_node(node: Node): |
|
maybe_type_annotation = ( |
|
"" if node.type is None else f" : {type_repr(node.type)}" |
|
) |
|
|
|
if verbose: |
|
|
|
from torch.fx.experimental.proxy_tensor import py_sym_types |
|
from torch.fx.passes.shape_prop import TensorMetadata |
|
|
|
meta_val = node.meta.get( |
|
"val", |
|
node.meta.get("tensor_meta", node.meta.get("example_value", None)), |
|
) |
|
|
|
if isinstance(meta_val, torch.Tensor) and meta_val.layout not in ( |
|
torch.sparse_csc, |
|
torch.sparse_csr, |
|
): |
|
stride_annotation = ( |
|
f"{stringify_shape(meta_val.stride())}" |
|
if include_stride |
|
else "" |
|
) |
|
device_annotation = f"{meta_val.device}" if include_device else "" |
|
maybe_type_annotation = ( |
|
f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' |
|
f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"' |
|
) |
|
elif isinstance(meta_val, py_sym_types): |
|
maybe_type_annotation = f': "Sym({meta_val})"' |
|
elif isinstance(meta_val, TensorMetadata): |
|
maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' |
|
|
|
if node.op == "placeholder": |
|
assert isinstance(node.target, str) |
|
maybe_default_arg = ( |
|
"" if not node.args else f" = {_get_repr(node.args[0])}" |
|
) |
|
free_vars.append( |
|
f"{node.target}{maybe_type_annotation}{maybe_default_arg}" |
|
) |
|
raw_name = node.target.replace("*", "") |
|
if raw_name != repr(node): |
|
body.append(f"{repr(node)} = {raw_name}\n") |
|
return |
|
elif node.op == "call_method": |
|
assert isinstance(node.target, str) |
|
body.append( |
|
f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}" |
|
f"({_format_args(node.args[1:], node.kwargs)})" |
|
) |
|
return |
|
elif node.op == "call_function": |
|
assert callable(node.target) |
|
|
|
if ( |
|
getattr(node.target, "__module__", "") == "_operator" |
|
and node.target.__name__ in magic_methods |
|
): |
|
assert isinstance(node.args, tuple) |
|
body.append( |
|
f"{repr(node)}{maybe_type_annotation} = " |
|
f"{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}" |
|
) |
|
return |
|
|
|
|
|
|
|
if ( |
|
getattr(node.target, "__module__", "") == "_operator" |
|
and node.target.__name__ in inplace_methods |
|
): |
|
body.append( |
|
f"{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; " |
|
f"{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}" |
|
) |
|
return |
|
|
|
qualified_name = _get_qualified_name(node.target) |
|
global_name = add_global(qualified_name, node.target) |
|
|
|
|
|
if ( |
|
global_name == "getattr" |
|
and isinstance(node.args, tuple) |
|
and isinstance(node.args[1], str) |
|
and node.args[1].isidentifier() |
|
and len(node.args) == 2 |
|
): |
|
body.append( |
|
f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}" |
|
) |
|
return |
|
body.append( |
|
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" |
|
) |
|
if node.meta.get("is_wrapped", False): |
|
wrapped_fns.setdefault(global_name) |
|
return |
|
elif node.op == "call_module": |
|
assert isinstance(node.target, str) |
|
body.append( |
|
f"{repr(node)}{maybe_type_annotation} = " |
|
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" |
|
) |
|
return |
|
elif node.op == "get_attr": |
|
assert isinstance(node.target, str) |
|
body.append( |
|
f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" |
|
) |
|
return |
|
elif node.op == "output": |
|
if node.type is not None: |
|
maybe_return_annotation[0] = f" -> {type_repr(node.type)}" |
|
body.append(self.generate_output(node.args[0])) |
|
return |
|
raise NotImplementedError(f"node: {node.op} {node.target}") |
|
|
|
for i, node in enumerate(nodes): |
|
|
|
|
|
if verbose: |
|
append_stacktrace_summary(node) |
|
|
|
|
|
|
|
body.append(f"# COUNTER: {i}\n") |
|
emit_node(node) |
|
delete_unused_values(node) |
|
|
|
if len(body) == 0: |
|
|
|
|
|
|
|
body.append("pass\n") |
|
|
|
if len(wrapped_fns) > 0: |
|
wrap_name = add_global("wrap", torch.fx.wrap) |
|
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) |
|
else: |
|
wrap_stmts = "" |
|
|
|
if self._body_transformer: |
|
body = self._body_transformer(body) |
|
|
|
for name, value in self.additional_globals(): |
|
add_global(name, value) |
|
|
|
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) |
|
|
|
|
|
lineno_map: dict[int, Optional[int]] = {} |
|
prologue_len = prologue.count("\n") + 1 |
|
new_lines: list[str] = [] |
|
cur_idx = None |
|
for line in "".join(body).split("\n"): |
|
counter = _counter_regexp.search(line) |
|
if counter is not None: |
|
cur_idx = int(counter.group(1)) |
|
else: |
|
lineno_map[len(new_lines) + prologue_len] = cur_idx |
|
new_lines.append(line) |
|
|
|
code = "\n".join(new_lines).lstrip("\n") |
|
code = "\n".join(" " + line for line in code.split("\n")) |
|
|
|
fn_code = f""" |
|
{wrap_stmts} |
|
|
|
{prologue} |
|
{code}""" |
|
return PythonCode(fn_code, globals_, _lineno_map=lineno_map) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _PyTreeCodeGen(CodeGen): |
|
def __init__(self, pytree_info: _PyTreeInfo): |
|
super().__init__() |
|
self.pytree_info: _PyTreeInfo = pytree_info |
|
|
|
def process_inputs(self, *inputs: Any) -> Any: |
|
flat_args = pytree.arg_tree_leaves(*inputs) |
|
return flat_args |
|
|
|
def process_outputs(self, out: Any) -> Any: |
|
if self.pytree_info is None or self.pytree_info.out_spec is None: |
|
return out |
|
if not isinstance(out, (list, tuple)): |
|
out = [out] |
|
assert self.pytree_info.out_spec is not None |
|
return pytree.tree_unflatten(out, self.pytree_info.out_spec) |
|
|
|
def gen_fn_def(self, free_vars, maybe_return_annotation): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.pytree_info is None: |
|
return super().gen_fn_def(free_vars, maybe_return_annotation) |
|
|
|
fn_args = self.pytree_info.orig_args |
|
has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False |
|
if has_orig_self: |
|
free_vars.insert(0, "self") |
|
fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) |
|
|
|
if len(free_vars) > 0: |
|
|
|
has_args_kwargs_tuple = ( |
|
self.pytree_info.in_spec.type == tuple |
|
and self.pytree_info.in_spec.num_children == 2 |
|
and self.pytree_info.in_spec.children_specs[0].type == tuple |
|
and self.pytree_info.in_spec.children_specs[1].type == dict |
|
) |
|
fn_kwargs = "{}" |
|
fn_signature = f"[{', '.join(fn_args)}], self._in_spec" |
|
if has_args_kwargs_tuple: |
|
count_args = self.pytree_info.in_spec.children_specs[0].num_children |
|
fn_args = self.pytree_info.orig_args[:count_args] |
|
fn_kwargs = ( |
|
"{" |
|
+ ", ".join( |
|
f"'{k}':{v}" |
|
for k, v in zip( |
|
self.pytree_info.in_spec.children_specs[1].context, |
|
self.pytree_info.orig_args[count_args:], |
|
) |
|
) |
|
+ "}" |
|
) |
|
fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" |
|
|
|
|
|
|
|
|
|
|
|
without_annotation = [x.split(":")[0] for x in free_vars] |
|
has_annotation = [x + "; " for x in free_vars if ":" in x] |
|
if len(has_annotation) > 0: |
|
fn_definition += "\n " + "".join(has_annotation) + "\n" |
|
fn_definition += f""" |
|
{', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" |
|
return fn_definition |
|
|
|
def generate_output(self, output_args): |
|
if self.pytree_info and self.pytree_info.out_spec: |
|
return f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)" |
|
else: |
|
return super().generate_output(output_args) |
|
|
|
|
|
class _FindNodesLookupTable: |
|
""" |
|
Side table for the graph for the purpose of doing fast queries |
|
""" |
|
|
|
def __init__(self): |
|
self.table: dict[tuple[str, Optional[Target]], dict[Node, None]] = defaultdict( |
|
dict |
|
) |
|
|
|
def _key(self, node) -> tuple[str, Optional[Target]]: |
|
return (node.op, node.target if node.op == "call_function" else None) |
|
|
|
def __contains__(self, node) -> bool: |
|
return node in self.table[self._key(node)] |
|
|
|
def insert(self, node: Node) -> None: |
|
self.table[self._key(node)][node] = None |
|
|
|
def remove(self, node: Node) -> None: |
|
self.table[self._key(node)].pop(node) |
|
|
|
def find_nodes(self, *, op: str, target: Optional["Target"] = None): |
|
if op == "call_function": |
|
assert target is not None |
|
return [*self.table[(op, target)].keys()] |
|
|
|
if target is None: |
|
return [*self.table[(op, None)].keys()] |
|
|
|
|
|
return [node for node in self.table[(op, None)].keys() if node.target == target] |
|
|
|
|
|
@compatibility(is_backward_compatible=True) |
|
class Graph: |
|
""" |
|
``Graph`` is the main data structure used in the FX Intermediate Representation. |
|
It consists of a series of ``Node`` s, each representing callsites (or other |
|
syntactic constructs). The list of ``Node`` s, taken together, constitute a |
|
valid Python function. |
|
|
|
For example, the following code |
|
|
|
.. code-block:: python |
|
|
|
import torch |
|
import torch.fx |
|
|
|
|
|
class MyModule(torch.nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
self.param = torch.nn.Parameter(torch.rand(3, 4)) |
|
self.linear = torch.nn.Linear(4, 5) |
|
|
|
def forward(self, x): |
|
return torch.topk( |
|
torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 |
|
) |
|
|
|
|
|
m = MyModule() |
|
gm = torch.fx.symbolic_trace(m) |
|
|
|
Will produce the following Graph:: |
|
|
|
print(gm.graph) |
|
|
|
.. code-block:: text |
|
|
|
graph(x): |
|
%linear_weight : [num_users=1] = self.linear.weight |
|
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) |
|
%linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) |
|
%relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) |
|
%sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) |
|
%topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) |
|
return topk_1 |
|
|
|
For the semantics of operations represented in the ``Graph``, please see :class:`Node`. |
|
""" |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def __init__( |
|
self, |
|
owning_module: Optional["GraphModule"] = None, |
|
tracer_cls: Optional[type["Tracer"]] = None, |
|
tracer_extras: Optional[dict[str, Any]] = None, |
|
): |
|
""" |
|
Construct an empty Graph. |
|
""" |
|
self._root: Node = Node(self, "", "root", "", (), {}) |
|
self._used_names: dict[str, int] = {} |
|
self._insert = self._root.prepend |
|
self._len = 0 |
|
self._graph_namespace = _Namespace() |
|
self._owning_module = owning_module |
|
self._tracer_cls = tracer_cls |
|
self._tracer_extras = tracer_extras |
|
self._codegen = CodeGen() |
|
self._co_fields: dict[str, Any] = {} |
|
self._find_nodes_lookup_table = _FindNodesLookupTable() |
|
|
|
@property |
|
def owning_module(self): |
|
return self._owning_module |
|
|
|
@owning_module.setter |
|
def owning_module(self, mod: Optional["GraphModule"]): |
|
self._owning_module = mod |
|
|
|
@property |
|
def nodes(self) -> _node_list: |
|
""" |
|
Get the list of Nodes that constitute this Graph. |
|
|
|
Note that this ``Node`` list representation is a doubly-linked list. Mutations |
|
during iteration (e.g. delete a Node, add a Node) are safe. |
|
|
|
Returns: |
|
|
|
A doubly-linked list of Nodes. Note that ``reversed`` can be called on |
|
this list to switch iteration order. |
|
""" |
|
return _node_list(self) |
|
|
|
@compatibility(is_backward_compatible=False) |
|
def output_node(self) -> Node: |
|
output_node = next(iter(reversed(self.nodes))) |
|
assert output_node.op == "output" |
|
return output_node |
|
|
|
@compatibility(is_backward_compatible=False) |
|
def find_nodes( |
|
self, *, op: str, target: Optional["Target"] = None, sort: bool = True |
|
): |
|
""" |
|
Allows for fast query of nodes |
|
|
|
Args: |
|
|
|
op (str): the name of the operation |
|
|
|
target (Optional[Target]): the target of the node. For call_function, |
|
the target is required. For other ops, the target is optional. |
|
|
|
sort (bool): whether to return nodes in the order they appear on |
|
on the graph. |
|
|
|
Returns: |
|
|
|
Iteratable of nodes with the requested op and target. |
|
""" |
|
node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) |
|
if sort: |
|
return sorted(node_list) |
|
return node_list |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def graph_copy( |
|
self, g: "Graph", val_map: dict[Node, Node], return_output_node=False |
|
) -> "Optional[Argument]": |
|
""" |
|
Copy all nodes from a given graph into ``self``. |
|
|
|
Args: |
|
|
|
g (Graph): The source graph from which to copy Nodes. |
|
|
|
val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping |
|
from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed |
|
in with values in it already to override copying of certain values. |
|
|
|
Returns: |
|
|
|
The value in ``self`` that is now equivalent to the output value in ``g``, |
|
if ``g`` had an ``output`` node. ``None`` otherwise. |
|
""" |
|
for node in g.nodes: |
|
if node in val_map: |
|
continue |
|
if node.op == "output": |
|
rv = map_arg(node.args[0], lambda n: val_map[n]) |
|
return rv if not return_output_node else (rv, node) |
|
val_map[node] = self.node_copy(node, lambda n: val_map[n]) |
|
return None |
|
|
|
def __deepcopy__(self, memo=None) -> "Graph": |
|
""" |
|
Explicitly implement __deepcopy__ to prevent excessive recursion depth |
|
from the default implementation. This uses graph_copy to copy the nodes |
|
in an iterative way, rather than recursive. It also populates the |
|
memoization table to prevent unnecessary copies (e.g. references to |
|
nodes or other parts of the Graph from a custom GraphModule implementation. |
|
""" |
|
memo = memo if memo else {} |
|
g = Graph(tracer_cls=self._tracer_cls) |
|
output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) |
|
g._codegen = copy.deepcopy(self._codegen) |
|
if output_vals is not None: |
|
assert isinstance(output_vals, tuple) |
|
output_val, old_output_node = output_vals |
|
new_output_node = g.output( |
|
output_val, type_expr=getattr(old_output_node, "type", None) |
|
) |
|
new_output_node.meta = copy.copy(old_output_node.meta) |
|
return g |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def create_node( |
|
self, |
|
op: str, |
|
target: "Target", |
|
args: Optional[tuple["Argument", ...]] = None, |
|
kwargs: Optional[dict[str, "Argument"]] = None, |
|
name: Optional[str] = None, |
|
type_expr: Optional[Any] = None, |
|
) -> Node: |
|
""" |
|
Create a ``Node`` and add it to the ``Graph`` at the current insert-point. |
|
Note that the current insert-point can be set via :meth:`Graph.inserting_before` |
|
and :meth:`Graph.inserting_after`. |
|
|
|
Args: |
|
op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', |
|
'call_module', 'placeholder', or 'output'. The semantics of these opcodes are |
|
described in the ``Graph`` docstring. |
|
|
|
args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. |
|
|
|
kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node |
|
|
|
name (Optional[str]): an optional string name for the ``Node``. |
|
This will influence the name of the value assigned to in the |
|
Python generated code. |
|
|
|
type_expr (Optional[Any]): an optional type annotation representing the |
|
Python type the output of this node will have. |
|
|
|
Returns: |
|
|
|
The newly-created and inserted node. |
|
""" |
|
|
|
if not args: |
|
args = () |
|
else: |
|
assert isinstance(args, tuple), "args must be a tuple" |
|
if not kwargs: |
|
kwargs = immutable_dict() |
|
else: |
|
assert isinstance(kwargs, dict), "kwargs must be a dict" |
|
|
|
candidate = name if name is not None else self._target_to_str(target) |
|
name = self._graph_namespace.create_name(candidate, None) |
|
n = Node(self, name, op, target, args, kwargs, type_expr) |
|
|
|
if ( |
|
self.owning_module is not None |
|
and getattr(self.owning_module, "_create_node_hooks", None) is not None |
|
): |
|
for f in self.owning_module._create_node_hooks: |
|
f(n) |
|
|
|
self._graph_namespace.associate_name_with_obj(name, n) |
|
|
|
self._insert(n) |
|
self._find_nodes_lookup_table.insert(n) |
|
self._len += 1 |
|
return n |
|
|
|
@compatibility(is_backward_compatible=False) |
|
def process_inputs(self, *args): |
|
""" |
|
Processes args so that they can be passed to the FX graph. |
|
""" |
|
return self._codegen.process_inputs(*args) |
|
|
|
@compatibility(is_backward_compatible=False) |
|
def process_outputs(self, out): |
|
return self._codegen.process_outputs(out) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def erase_node(self, to_erase: Node) -> None: |
|
""" |
|
Erases a ``Node`` from the ``Graph``. Throws an exception if |
|
there are still users of that node in the ``Graph``. |
|
|
|
Args: |
|
|
|
to_erase (Node): The ``Node`` to erase from the ``Graph``. |
|
""" |
|
if len(to_erase.users) > 0: |
|
raise RuntimeError( |
|
f"Tried to erase Node {to_erase} but it still had {len(to_erase.users)} " |
|
f"users in the graph: {to_erase.users}!" |
|
) |
|
if to_erase.graph != self: |
|
raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") |
|
if to_erase._erased: |
|
warnings.warn(f"erase_node({to_erase}) on an already erased node") |
|
return |
|
|
|
if ( |
|
self.owning_module is not None |
|
and getattr(self.owning_module, "_erase_node_hooks", None) is not None |
|
): |
|
for f in self.owning_module._erase_node_hooks: |
|
f(to_erase) |
|
|
|
self._find_nodes_lookup_table.remove(to_erase) |
|
to_erase._remove_from_list() |
|
to_erase._erased = True |
|
self._len -= 1 |
|
|
|
|
|
|
|
to_erase._update_args_kwargs( |
|
map_arg(to_erase._args, lambda n: None), |
|
map_arg(to_erase._kwargs, lambda n: None), |
|
) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def inserting_before(self, n: Optional[Node] = None): |
|
"""Set the point at which create_node and companion methods will insert into the graph. |
|
When used within a 'with' statement, this will temporary set the insert point and |
|
then restore it when the with statement exits:: |
|
|
|
with g.inserting_before(n): |
|
... # inserting before node n |
|
... # insert point restored to what it was previously |
|
g.inserting_before(n) # set the insert point permanently |
|
|
|
Args: |
|
|
|
n (Optional[Node]): The node before which to insert. If None this will insert before |
|
the beginning of the entire graph. |
|
|
|
Returns: |
|
A resource manager that will restore the insert point on ``__exit__``. |
|
""" |
|
if n is None: |
|
return self.inserting_after(self._root) |
|
assert n.graph == self, "Node to insert before is not in graph." |
|
return _InsertPoint(self, n.prepend) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def inserting_after(self, n: Optional[Node] = None): |
|
"""Set the point at which create_node and companion methods will insert into the graph. |
|
When used within a 'with' statement, this will temporary set the insert point and |
|
then restore it when the with statement exits:: |
|
|
|
with g.inserting_after(n): |
|
... # inserting after node n |
|
... # insert point restored to what it was previously |
|
g.inserting_after(n) # set the insert point permanently |
|
|
|
Args: |
|
|
|
n (Optional[Node]): The node before which to insert. If None this will insert after |
|
the beginning of the entire graph. |
|
|
|
Returns: |
|
A resource manager that will restore the insert point on ``__exit__``. |
|
""" |
|
if n is None: |
|
return self.inserting_before(self._root) |
|
assert n.graph == self, "Node to insert after is not in graph." |
|
return _InsertPoint(self, n.append) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def placeholder( |
|
self, |
|
name: str, |
|
type_expr: Optional[Any] = None, |
|
default_value: Any = inspect.Signature.empty, |
|
) -> Node: |
|
""" |
|
Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents |
|
a function input. |
|
|
|
Args: |
|
|
|
name (str): A name for the input value. This corresponds to the name |
|
of the positional argument to the function this ``Graph`` represents. |
|
|
|
type_expr (Optional[Any]): an optional type annotation representing the |
|
Python type the output of this node will have. This is needed in some |
|
cases for proper code generation (e.g. when the function is used |
|
subsequently in TorchScript compilation). |
|
|
|
default_value (Any): The default value this function argument should take |
|
on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` |
|
should be passed as this argument to specify that the parameter does _not_ |
|
have a default value. |
|
|
|
.. note:: |
|
The same insertion point and type expression rules apply for this method |
|
as ``Graph.create_node``. |
|
""" |
|
args = () if default_value is inspect.Signature.empty else (default_value,) |
|
return self.create_node("placeholder", name, args=args, type_expr=type_expr) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: |
|
""" |
|
Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the |
|
fetch of an attribute from the ``Module`` hierarchy. |
|
|
|
Args: |
|
|
|
qualified_name (str): the fully-qualified name of the attribute to be retrieved. |
|
For example, if the traced Module has a submodule named ``foo``, which has a |
|
submodule named ``bar``, which has an attribute named ``baz``, the qualified |
|
name ``foo.bar.baz`` should be passed as ``qualified_name``. |
|
|
|
type_expr (Optional[Any]): an optional type annotation representing the |
|
Python type the output of this node will have. |
|
|
|
|
|
Returns: |
|
|
|
The newly-created and inserted ``get_attr`` node. |
|
|
|
.. note:: |
|
The same insertion point and type expression rules apply for this method |
|
as ``Graph.create_node``. |
|
""" |
|
|
|
def _get_attr_reference_exists( |
|
mod: torch.nn.Module, qualified_name: str |
|
) -> bool: |
|
module_path, _, name = qualified_name.rpartition(".") |
|
|
|
try: |
|
submod: torch.nn.Module = mod.get_submodule(module_path) |
|
except AttributeError: |
|
warnings.warn(f"Failed to fetch module {module_path}!") |
|
return False |
|
|
|
if not hasattr(submod, name): |
|
return False |
|
|
|
res = getattr(submod, name) |
|
|
|
if ( |
|
not isinstance(res, torch.nn.Module) |
|
and not isinstance(res, torch.nn.Parameter) |
|
and name not in submod._buffers |
|
): |
|
return False |
|
|
|
return True |
|
|
|
if self.owning_module and not _get_attr_reference_exists( |
|
self.owning_module, qualified_name |
|
): |
|
warnings.warn( |
|
"Attempted to insert a get_attr Node with no " |
|
"underlying reference in the owning " |
|
"GraphModule! Call " |
|
"GraphModule.add_submodule to add the " |
|
"necessary submodule, " |
|
"GraphModule.add_parameter to add the " |
|
"necessary Parameter, or " |
|
"nn.Module.register_buffer to add the " |
|
"necessary buffer", |
|
stacklevel=2, |
|
) |
|
return self.create_node("get_attr", qualified_name, type_expr=type_expr) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def call_module( |
|
self, |
|
module_name: str, |
|
args: Optional[tuple["Argument", ...]] = None, |
|
kwargs: Optional[dict[str, "Argument"]] = None, |
|
type_expr: Optional[Any] = None, |
|
) -> Node: |
|
""" |
|
Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node |
|
represents a call to the forward() function of a ``Module`` in the ``Module`` |
|
hierarchy. |
|
|
|
Args: |
|
|
|
module_name (str): The qualified name of the ``Module`` in the ``Module`` |
|
hierarchy to be called. For example, if the traced ``Module`` has a |
|
submodule named ``foo``, which has a submodule named ``bar``, the |
|
qualified name ``foo.bar`` should be passed as ``module_name`` to |
|
call that module. |
|
|
|
args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed |
|
to the called method. Note that this should *not* include a ``self`` argument. |
|
|
|
kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed |
|
to the called method |
|
|
|
type_expr (Optional[Any]): an optional type annotation representing the |
|
Python type the output of this node will have. |
|
|
|
Returns: |
|
|
|
The newly-created and inserted ``call_module`` node. |
|
|
|
.. note:: |
|
The same insertion point and type expression rules apply for this method |
|
as :meth:`Graph.create_node`. |
|
""" |
|
if self.owning_module and self.owning_module.get_submodule(module_name) is None: |
|
warnings.warn( |
|
"Attempted to insert a call_module Node with " |
|
"no underlying reference in the owning " |
|
"GraphModule! Call " |
|
"GraphModule.add_submodule to add the " |
|
"necessary submodule" |
|
) |
|
return self.create_node( |
|
"call_module", module_name, args, kwargs, type_expr=type_expr |
|
) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def call_method( |
|
self, |
|
method_name: str, |
|
args: Optional[tuple["Argument", ...]] = None, |
|
kwargs: Optional[dict[str, "Argument"]] = None, |
|
type_expr: Optional[Any] = None, |
|
) -> Node: |
|
""" |
|
Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node |
|
represents a call to a given method on the 0th element of ``args``. |
|
|
|
Args: |
|
|
|
method_name (str): The name of the method to apply to the self argument. |
|
For example, if args[0] is a ``Node`` representing a ``Tensor``, |
|
then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. |
|
|
|
args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed |
|
to the called method. Note that this *should* include a ``self`` argument. |
|
|
|
kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed |
|
to the called method |
|
|
|
type_expr (Optional[Any]): an optional type annotation representing the |
|
Python type the output of this node will have. |
|
|
|
Returns: |
|
|
|
The newly created and inserted ``call_method`` node. |
|
|
|
.. note:: |
|
The same insertion point and type expression rules apply for this method |
|
as :meth:`Graph.create_node`. |
|
""" |
|
return self.create_node( |
|
"call_method", method_name, args, kwargs, type_expr=type_expr |
|
) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def call_function( |
|
self, |
|
the_function: Callable[..., Any], |
|
args: Optional[tuple["Argument", ...]] = None, |
|
kwargs: Optional[dict[str, "Argument"]] = None, |
|
type_expr: Optional[Any] = None, |
|
) -> Node: |
|
""" |
|
Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node |
|
represents a call to a Python callable, specified by ``the_function``. |
|
|
|
Args: |
|
|
|
the_function (Callable[..., Any]): The function to be called. Can be any PyTorch |
|
operator, Python function, or member of the ``builtins`` or ``operator`` |
|
namespaces. |
|
|
|
args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed |
|
to the called function. |
|
|
|
kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed |
|
to the called function |
|
|
|
type_expr (Optional[Any]): an optional type annotation representing the |
|
Python type the output of this node will have. |
|
|
|
Returns: |
|
|
|
The newly created and inserted ``call_function`` node. |
|
|
|
.. note:: |
|
The same insertion point and type expression rules apply for this method |
|
as :meth:`Graph.create_node`. |
|
""" |
|
return self.create_node( |
|
"call_function", the_function, args, kwargs, type_expr=type_expr |
|
) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def node_copy( |
|
self, node: Node, arg_transform: Callable[[Node], "Argument"] = lambda x: x |
|
) -> Node: |
|
""" |
|
Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from |
|
the graph of node to the graph of self. Example:: |
|
|
|
# Copying all the nodes in `g` into `new_graph` |
|
g: torch.fx.Graph = ... |
|
new_graph = torch.fx.graph() |
|
value_remap = {} |
|
for node in g.nodes: |
|
value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) |
|
|
|
Args: |
|
|
|
node (Node): The node to copy into ``self``. |
|
|
|
arg_transform (Callable[[Node], Argument]): A function that transforms |
|
``Node`` arguments in node's ``args`` and ``kwargs`` into the |
|
equivalent argument in ``self``. In the simplest case, this should |
|
retrieve a value out of a table mapping Nodes in the original |
|
graph to ``self``. |
|
""" |
|
args = map_arg(node.args, arg_transform) |
|
kwargs = map_arg(node.kwargs, arg_transform) |
|
assert isinstance(args, tuple) |
|
assert isinstance(kwargs, dict) |
|
result_node = self.create_node( |
|
node.op, node.target, args, kwargs, node.name, node.type |
|
) |
|
result_node.meta = copy.copy(node.meta) |
|
return result_node |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def output(self, result: "Argument", type_expr: Optional[Any] = None): |
|
""" |
|
Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents |
|
a ``return`` statement in Python code. ``result`` is the value that should |
|
be returned. |
|
|
|
Args: |
|
|
|
result (Argument): The value to be returned. |
|
|
|
type_expr (Optional[Any]): an optional type annotation representing the |
|
Python type the output of this node will have. |
|
|
|
.. note:: |
|
|
|
The same insertion point and type expression rules apply for this method |
|
as ``Graph.create_node``. |
|
""" |
|
return self.create_node( |
|
op="output", target="output", args=(result,), type_expr=type_expr |
|
) |
|
|
|
def _target_to_str(self, target: Target) -> str: |
|
if callable(target): |
|
op = target.__name__ |
|
else: |
|
assert isinstance(target, str) |
|
op = target |
|
if _is_magic(op): |
|
op = op[2:-2] |
|
op = _snake_case(op) |
|
return op |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def python_code( |
|
self, |
|
root_module: str, |
|
*, |
|
verbose: bool = False, |
|
include_stride: bool = False, |
|
include_device: bool = False, |
|
colored: bool = False, |
|
) -> PythonCode: |
|
""" |
|
Turn this ``Graph`` into valid Python code. |
|
|
|
Args: |
|
|
|
root_module (str): The name of the root module on which to look-up |
|
qualified name targets. This is usually 'self'. |
|
|
|
Returns: |
|
|
|
A PythonCode object, consisting of two fields: |
|
src: the Python source code representing the object |
|
globals: a dictionary of global names in `src` -> the objects that they reference. |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace = _Namespace() |
|
|
|
|
|
|
|
|
|
|
|
|
|
def node_repr(n: Node): |
|
return namespace.create_name(n.name, n) |
|
|
|
@contextmanager |
|
def override_node_repr(graph: Graph): |
|
orig_repr_fns = {} |
|
for node in graph.nodes: |
|
orig_repr_fns[node] = node._repr_fn |
|
node._repr_fn = node_repr |
|
try: |
|
yield None |
|
finally: |
|
|
|
for node in graph.nodes: |
|
node._repr_fn = orig_repr_fns[node] |
|
|
|
with override_node_repr(self): |
|
return self._python_code( |
|
root_module, |
|
namespace, |
|
verbose=verbose, |
|
include_stride=include_stride, |
|
include_device=include_device, |
|
colored=colored, |
|
) |
|
|
|
def _python_code( |
|
self, |
|
root_module: str, |
|
namespace: _Namespace, |
|
*, |
|
verbose: bool = False, |
|
include_stride: bool = False, |
|
include_device: bool = False, |
|
colored: bool = False, |
|
) -> PythonCode: |
|
return self._codegen._gen_python_code( |
|
self.nodes, |
|
root_module, |
|
namespace, |
|
verbose=verbose, |
|
include_stride=include_stride, |
|
include_device=include_device, |
|
colored=colored, |
|
) |
|
|
|
def __str__(self) -> str: |
|
""" |
|
Return a human-readable (not machine-readable) string representation |
|
of this Graph |
|
""" |
|
placeholder_names: list[str] = [] |
|
|
|
|
|
maybe_return_typename: list[str] = [""] |
|
|
|
node_strs = [node.format_node(placeholder_names) for node in self.nodes] |
|
param_str = ", ".join(placeholder_names) |
|
s = f"graph({param_str}){maybe_return_typename[0]}:" |
|
for node_str in node_strs: |
|
if node_str: |
|
s += "\n " + node_str |
|
return s |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def print_tabular(self): |
|
""" |
|
Prints the intermediate representation of the graph in tabular |
|
format. Note that this API requires the ``tabulate`` module to be |
|
installed. |
|
""" |
|
try: |
|
from tabulate import tabulate |
|
except ImportError: |
|
print( |
|
"`print_tabular` relies on the library `tabulate`, " |
|
"which could not be found on this machine. Run `pip " |
|
"install tabulate` to install the library." |
|
) |
|
raise |
|
|
|
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes] |
|
print( |
|
tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"]) |
|
) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def lint(self): |
|
""" |
|
Runs various checks on this Graph to make sure it is well-formed. In |
|
particular: |
|
- Checks Nodes have correct ownership (owned by this graph) |
|
- Checks Nodes appear in topological order |
|
- If this Graph has an owning GraphModule, checks that targets |
|
exist in that GraphModule |
|
""" |
|
|
|
|
|
def check_arg(arg: Node, n: Optional[Node] = None) -> None: |
|
context_str = f" of Node '{n}' " if n else " " |
|
if arg.graph is not self: |
|
raise RuntimeError( |
|
f"Argument '{arg}'{context_str}does not belong to this Graph, " |
|
f"but was used as an argument! If you are copying nodes from another graph, make " |
|
f"sure to use ``arg_transform`` on node_copy() to remap values\n{self}" |
|
) |
|
if arg not in seen_values: |
|
raise RuntimeError( |
|
f"Argument '{arg}'{context_str}was used before it has been " |
|
f"defined! Please check that Nodes in the graph are topologically ordered\n{self}" |
|
) |
|
|
|
seen_names: set[str] = set() |
|
seen_values: set[Node] = set() |
|
for node in self.nodes: |
|
if node.op not in _legal_ops: |
|
raise RuntimeError(f"Node {node} had unknown opcode {node.op}!") |
|
if node.graph is not self: |
|
raise RuntimeError(f"Node '{node}' does not belong to this Graph!") |
|
if node not in self._find_nodes_lookup_table: |
|
raise RuntimeError(f"Node '{node}' is not added to the side table") |
|
for arg in node._input_nodes: |
|
check_arg(arg, node) |
|
seen_values.add(node) |
|
|
|
if node.name in seen_names: |
|
raise RuntimeError(f"Node redefined name {node.name}!") |
|
seen_names.add(node.name) |
|
|
|
|
|
if self.owning_module: |
|
num_warnings = 0 |
|
MAX_WARNINGS = 5 |
|
for node in self.nodes: |
|
if node.op == "call_function": |
|
if not callable(node.target): |
|
raise ValueError( |
|
f"Node {node} target {node.target} has type {torch.typename(node.target)} but " |
|
"a Callable is expected" |
|
) |
|
else: |
|
if not isinstance(node.target, str): |
|
raise ValueError( |
|
f"Node {node} target {node.target} has type {torch.typename(node.target)} but " |
|
"a str is expected" |
|
) |
|
if node.op in ["get_attr", "call_module"]: |
|
target_atoms = node.target.split(".") |
|
m_itr = self.owning_module |
|
for i, atom in enumerate(target_atoms): |
|
new_m_itr = getattr(m_itr, atom, None) |
|
seen_qualname = ".".join(target_atoms[:i]) |
|
if new_m_itr is None: |
|
raise RuntimeError( |
|
f"Node {node} target {node.target} references nonexistent attribute " |
|
f"{atom} of {seen_qualname}" |
|
) |
|
if node.op == "call_module" and not isinstance( |
|
new_m_itr, torch.nn.Module |
|
): |
|
raise RuntimeError( |
|
f"Node {node} target {node.target} {atom} of {seen_qualname} does " |
|
"not reference an nn.Module" |
|
) |
|
elif ( |
|
node.op == "get_attr" |
|
and not isinstance(new_m_itr, torch.nn.Module) |
|
and not isinstance(new_m_itr, torch.nn.Parameter) |
|
and atom not in m_itr._buffers |
|
): |
|
if num_warnings < MAX_WARNINGS: |
|
|
|
|
|
|
|
warnings.warn( |
|
f"Node {node} target {node.target} {atom} of {seen_qualname} does " |
|
"not reference an nn.Module, nn.Parameter, or buffer, which is " |
|
"what 'get_attr' Nodes typically target" |
|
) |
|
num_warnings += 1 |
|
else: |
|
m_itr = new_m_itr |
|
if num_warnings > MAX_WARNINGS: |
|
warnings.warn( |
|
f"Additional {num_warnings - MAX_WARNINGS} warnings " |
|
"suppressed about get_attr references" |
|
) |
|
|
|
@compatibility(is_backward_compatible=True) |
|
def eliminate_dead_code( |
|
self, is_impure_node: Optional[Callable[[Node], bool]] = None |
|
) -> bool: |
|
""" |
|
Remove all dead code from the graph, based on each node's number of |
|
users, and whether the nodes have any side effects. The graph must be |
|
topologically sorted before calling. |
|
|
|
Args: |
|
is_impure_node (Optional[Callable[[Node], bool]]): A function that returns |
|
whether a node is impure. If this is None, then the default behavior is to |
|
use Node.is_impure. |
|
|
|
Returns: |
|
bool: Whether the graph was changed as a result of the pass. |
|
|
|
Example: |
|
|
|
Before dead code is eliminated, `a` from `a = x + 1` below has no users |
|
and thus can be eliminated from the graph without having an effect. |
|
|
|
.. code-block:: python |
|
|
|
def forward(self, x): |
|
a = x + 1 |
|
return x + self.attr_1 |
|
|
|
After dead code is eliminated, `a = x + 1` has been removed, and the rest |
|
of `forward` remains. |
|
|
|
.. code-block:: python |
|
|
|
def forward(self, x): |
|
return x + self.attr_1 |
|
|
|
.. warning:: |
|
|
|
Dead code elimination has some heuristics to avoid removing |
|
side-effectful nodes (see Node.is_impure) but in general coverage |
|
is very bad, so you should assume that this method is not sound |
|
to call unless you know that your FX graph consists entirely |
|
of functional operations or you supply your own custom |
|
function for detecting side-effectful nodes. |
|
""" |
|
|
|
|
|
self.lint() |
|
|
|
def has_side_effect(node): |
|
if is_impure_node is not None: |
|
return is_impure_node(node) |
|
return node.is_impure() |
|
|
|
|
|
|
|
|
|
changed = False |
|
for node in reversed(self.nodes): |
|
if not has_side_effect(node) and len(node.users) == 0: |
|
self.erase_node(node) |
|
changed = True |
|
|
|
return changed |
|
|
|
@compatibility(is_backward_compatible=False) |
|
def set_codegen(self, codegen: CodeGen): |
|
self._codegen = codegen |
|
|
|
@compatibility(is_backward_compatible=False) |
|
def on_generate_code( |
|
self, |
|
make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc], |
|
): |
|
"""Register a transformer function when python code is generated |
|
|
|
Args: |
|
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): |
|
a function that returns a code transformer to be registered. |
|
This function is called by `on_generate_code` to obtain the |
|
code transformer. |
|
|
|
This function is also given as its input the currently |
|
registered code transformer (or None if nothing is registered), |
|
in case it is not desirable to overwrite it. This is useful to |
|
chain code transformers together. |
|
|
|
Returns: |
|
a context manager that when used in a `with` statement, to automatically |
|
restore the previously registered code transformer. |
|
|
|
Example: |
|
|
|
.. code-block:: python |
|
|
|
|
|
gm: fx.GraphModule = ... |
|
|
|
|
|
# This is a code transformer we want to register. This code |
|
# transformer prepends a pdb import and trace statement at the very |
|
# beginning of the generated torch.fx code to allow for manual |
|
# debugging with the PDB library. |
|
def insert_pdb(body): |
|
return ["import pdb; pdb.set_trace()\\n", *body] |
|
|
|
|
|
# Registers `insert_pdb`, and overwrites the current registered |
|
# code transformer (given by `_` to the lambda): |
|
gm.graph.on_generate_code(lambda _: insert_pdb) |
|
|
|
# Or alternatively, registers a code transformer which first |
|
# runs `body` through existing registered transformer, then |
|
# through `insert_pdb`: |
|
gm.graph.on_generate_code( |
|
lambda current_trans: ( |
|
lambda body: insert_pdb(current_trans(body) if current_trans else body) |
|
) |
|
) |
|
|
|
gm.recompile() |
|
gm(*inputs) # drops into pdb |
|
|
|
|
|
This function can also be used as a context manager, with the benefit to |
|
automatically restores the previously registered code transformer: |
|
|
|
.. code-block:: python |
|
|
|
# ... continue from previous example |
|
|
|
with gm.graph.on_generate_code(lambda _: insert_pdb): |
|
# do more stuff with `gm`... |
|
gm.recompile() |
|
gm(*inputs) # drops into pdb |
|
|
|
# now previous code transformer is restored (but `gm`'s code with pdb |
|
# remains - that means you can run `gm` with pdb here too, until you |
|
# run next `recompile()`). |
|
""" |
|
on_gen_code_old = self._codegen._body_transformer |
|
self._codegen._body_transformer = make_transformer(on_gen_code_old) |
|
|
|
@contextlib.contextmanager |
|
def on_generate_code_context_manager(): |
|
try: |
|
yield |
|
finally: |
|
self._codegen._body_transformer = on_gen_code_old |
|
|
|
return on_generate_code_context_manager() |
|
|
|
|
|
def _identity(x): |
|
return x |
|
|
|
|
|
def _make_color_fn(code): |
|
def f(s): |
|
reset = "\033[0m" |
|
return f"{code}{s}{reset}" |
|
|
|
return f |
|
|
|
|
|
_color_codes = { |
|
"yellow": "\033[33m", |
|
"cyan": "\033[36m", |
|
"green": "\033[32m", |
|
"blue": "\033[34m", |
|
"red": "\033[31m", |
|
"dim": "\033[2m", |
|
"dim_blue": "\033[2m\033[34m", |
|
"dim_green": "\033[2m\033[32m", |
|
} |
|
_color_fns = {k: _make_color_fn(v) for k, v in _color_codes.items()} |
|
_counter_regexp = re.compile(r"# COUNTER: (\d+)") |
|
|
|
|
|
reflectable_magic_methods = { |
|
"add": "{} + {}", |
|
"sub": "{} - {}", |
|
"mul": "{} * {}", |
|
"floordiv": "{} // {}", |
|
"truediv": "{} / {}", |
|
"div": "{} / {}", |
|
"mod": "{} % {}", |
|
"pow": "{} ** {}", |
|
"lshift": "{} << {}", |
|
"rshift": "{} >> {}", |
|
"and_": "{} & {}", |
|
"or_": "{} | {}", |
|
"xor": "{} ^ {}", |
|
"getitem": "{}[{}]", |
|
"matmul": "{} @ {}", |
|
} |
|
|
|
magic_methods = { |
|
"eq": "{} == {}", |
|
"ne": "{} != {}", |
|
"lt": "{} < {}", |
|
"gt": "{} > {}", |
|
"le": "{} <= {}", |
|
"ge": "{} >= {}", |
|
"pos": "+{}", |
|
"neg": "-{}", |
|
"invert": "~{}", |
|
**reflectable_magic_methods, |
|
} |
|
|
|
inplace_methods = { |
|
"iadd": "{} += {}", |
|
"iand": "{} &= {}", |
|
"ifloordiv": "{} //= {}", |
|
"ilshift": "{} <<= {}", |
|
"imod": "{} %= {}", |
|
"imul": "{} *= {}", |
|
"imatmul": "{} @= {}", |
|
"ior": "{} |= {}", |
|
"ipow": "{} **= {}", |
|
"irshift": "{} >>= {}", |
|
"isub": "{} -= {}", |
|
"itruediv": "{} /= {}", |
|
"ixor": "{} ^= {}", |
|
"setitem": "{}[{}] = {}", |
|
} |
|
|