File size: 5,468 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import ast
import copy
import functools
import inspect
import textwrap
from types import FunctionType
from typing import Any, Callable, cast, Optional, Union
import torch
from torch._sources import normalize_source_lines
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
class AST_Rewriter(ast.NodeTransformer):
"""
Take a FunctionType object representing a `forward` method, then
perform an AST rewrite to swap out nodes that are not symbolically
traceable with a callsite to the FX alternative.
To support swapping out an AST node, define a new `visit` method on
that node. For more details, see:
https://docs.python.org/3/library/ast.html#ast.NodeTransformer
"""
# This function checks for new keys added in the globals dict. TorchDynamo
# can insert new keys in the global dict and upset the check. Therefore, put
# a disable here. This function is an optimization pass and not really
# suitable for dynamo tracing anyways.
@torch._dynamo.disable
def rewrite(self, fn: FunctionType):
# Normalize the source lines
sourcelines, _ = inspect.getsourcelines(fn)
sourcelines = normalize_source_lines(sourcelines)
source = "".join(sourcelines)
normalized_str = textwrap.dedent(source)
# Rewrite the original AST
source_ast = ast.parse(normalized_str)
dest_ast = ast.fix_missing_locations(self.visit(source_ast))
# Pull out the compiled function from the newly-created Module
code = compile(dest_ast, "", "exec")
globals_dict = copy.copy(fn.__globals__)
keys_before = set(globals_dict.keys())
exec(code, globals_dict)
new_keys = list(set(globals_dict.keys()) - keys_before)
assert len(new_keys) == 1
fn_compiled = globals_dict[new_keys[0]]
# return the compiled function with the original globals
def change_func_globals(f, globals):
"""Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
# __globals__ is a private member of the function class
# so we have to copy the function, f, all of its member, except f.__globals__
g = FunctionType(
f.__code__,
globals,
name=f.__name__,
argdefs=f.__defaults__,
closure=f.__closure__,
)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined]
return g
# Return the correct FunctionType object
return change_func_globals(fn_compiled, globals=fn.__globals__)
def visit_Assert(self, node):
"""
Swap out the Assert node (Python's `assert`) with a callsite to the
symbolically-traceable torch._assert function
"""
# Create the Call node
n = ast.parse("torch._assert()", mode="eval")
assert isinstance(n, ast.Expression)
call_node = n.body
assert isinstance(call_node, ast.Call)
msg = node.msg if node.msg else ast.Constant(value="", kind=None)
call_node.args = [node.test, msg]
# Ensure that the new node conforms to the Python AST grammar
expr_wrapper = ast.Expr(value=call_node)
# Return the new Call node to signify that we want to use it as
# a replacement for the original _assert node
return ast.copy_location(expr_wrapper, node)
def visit_AnnAssign(self, node):
"""
Swap out Python's AnnAssign with an Assign node where the annotation function is called.
Example:
Original:
y: Tensor_Type(1,2,3, Dyn) = f2(x)
Output:
y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
"""
return ast.Assign(
targets=[node.target],
value=ast.Call(
func=ast.Name(id="annotate", ctx=ast.Load()),
args=[node.value, node.annotation],
keywords=[],
),
)
class RewritingTracer(Tracer):
    """Tracer that passes its root through ``_rewrite`` before tracing it."""

    def trace(
        self,
        root: Union[torch.nn.Module, Callable],
        concrete_args: Optional[dict[str, Any]] = None,
    ) -> Graph:
        """
        Rewrite ``root`` with ``_rewrite`` and hand the result, together with
        ``concrete_args``, to the parent class's ``trace``.
        """
        rewritten_root = _rewrite(root)
        return super().trace(rewritten_root, concrete_args)
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
    """
    Return a version of ``fn`` whose source has been passed through
    ``AST_Rewriter``.

    A plain callable is rewritten directly. A ``torch.nn.Module`` is replaced
    by an instance of a freshly created ``RewrittenModule`` class that carries
    shallow copies of the original instance's ``__dict__`` entries and a
    rewritten ``forward``.
    """
    if not isinstance(fn, torch.nn.Module):
        # Rewrite this single free function
        return AST_Rewriter().rewrite(cast(FunctionType, fn))

    def rewrite_module(m: torch.nn.Module):
        # A new class is created per module so that each one can be given its
        # own rewritten `forward`.
        class RewrittenModule(torch.nn.Module):
            def __init__(self, orig):
                super().__init__()
                for attr_name, attr_value in orig.__dict__.items():
                    # Module-valued entries found directly in the instance
                    # dict are rewritten recursively before being copied.
                    # NOTE(review): only values stored directly in `__dict__`
                    # are inspected here - confirm that this reaches every
                    # submodule that is meant to be rewritten.
                    if isinstance(attr_value, torch.nn.Module):
                        attr_value = rewrite_module(attr_value)
                    self.__dict__[attr_name] = copy.copy(attr_value)

        rewritten_forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
        RewrittenModule.forward = rewritten_forward
        return RewrittenModule(m)

    # Rewrite this module's `forward` and return the new module.
    return rewrite_module(fn)
|