|
|
|
|
|
import ast |
|
import copy |
|
import functools |
|
import inspect |
|
import textwrap |
|
from types import FunctionType |
|
from typing import Any, Callable, cast, Optional, Union |
|
|
|
import torch |
|
from torch._sources import normalize_source_lines |
|
from torch.fx._symbolic_trace import Tracer |
|
from torch.fx.graph import Graph |
|
|
|
|
|
class AST_Rewriter(ast.NodeTransformer): |
|
""" |
|
Take a FunctionType object representing a `forward` method, then |
|
perform an AST rewrite to swap out nodes that are not symbolically |
|
traceable with a callsite to the FX alternative. |
|
|
|
To support swapping out an AST node, define a new `visit` method on |
|
that node. For more details, see: |
|
https://docs.python.org/3/library/ast.html#ast.NodeTransformer |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
@torch._dynamo.disable |
|
def rewrite(self, fn: FunctionType): |
|
|
|
sourcelines, _ = inspect.getsourcelines(fn) |
|
sourcelines = normalize_source_lines(sourcelines) |
|
source = "".join(sourcelines) |
|
normalized_str = textwrap.dedent(source) |
|
|
|
|
|
source_ast = ast.parse(normalized_str) |
|
dest_ast = ast.fix_missing_locations(self.visit(source_ast)) |
|
|
|
|
|
code = compile(dest_ast, "", "exec") |
|
globals_dict = copy.copy(fn.__globals__) |
|
keys_before = set(globals_dict.keys()) |
|
exec(code, globals_dict) |
|
new_keys = list(set(globals_dict.keys()) - keys_before) |
|
assert len(new_keys) == 1 |
|
fn_compiled = globals_dict[new_keys[0]] |
|
|
|
|
|
def change_func_globals(f, globals): |
|
"""Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" |
|
|
|
|
|
g = FunctionType( |
|
f.__code__, |
|
globals, |
|
name=f.__name__, |
|
argdefs=f.__defaults__, |
|
closure=f.__closure__, |
|
) |
|
g = functools.update_wrapper(g, f) |
|
g.__kwdefaults__ = copy.copy(f.__kwdefaults__) |
|
return g |
|
|
|
|
|
return change_func_globals(fn_compiled, globals=fn.__globals__) |
|
|
|
def visit_Assert(self, node): |
|
""" |
|
Swap out the Assert node (Python's `assert`) with a callsite to the |
|
symbolically-traceable torch._assert function |
|
""" |
|
|
|
n = ast.parse("torch._assert()", mode="eval") |
|
assert isinstance(n, ast.Expression) |
|
call_node = n.body |
|
assert isinstance(call_node, ast.Call) |
|
msg = node.msg if node.msg else ast.Constant(value="", kind=None) |
|
call_node.args = [node.test, msg] |
|
|
|
|
|
expr_wrapper = ast.Expr(value=call_node) |
|
|
|
|
|
|
|
return ast.copy_location(expr_wrapper, node) |
|
|
|
def visit_AnnAssign(self, node): |
|
""" |
|
Swap out Python's AnnAssign with an Assign node where the annotation function is called. |
|
Example: |
|
Original: |
|
y: Tensor_Type(1,2,3, Dyn) = f2(x) |
|
Output: |
|
y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) |
|
""" |
|
return ast.Assign( |
|
targets=[node.target], |
|
value=ast.Call( |
|
func=ast.Name(id="annotate", ctx=ast.Load()), |
|
args=[node.value, node.annotation], |
|
keywords=[], |
|
), |
|
) |
|
|
|
|
|
class RewritingTracer(Tracer): |
|
def trace( |
|
self, |
|
root: Union[torch.nn.Module, Callable], |
|
concrete_args: Optional[dict[str, Any]] = None, |
|
) -> Graph: |
|
return super().trace(_rewrite(root), concrete_args) |
|
|
|
|
|
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: |
|
if isinstance(fn, torch.nn.Module): |
|
|
|
|
|
|
|
def rewrite_module(m: torch.nn.Module): |
|
class RewrittenModule(torch.nn.Module): |
|
def __init__(self, orig): |
|
super().__init__() |
|
for k, v in orig.__dict__.items(): |
|
if isinstance(v, torch.nn.Module): |
|
self.__dict__[k] = copy.copy(rewrite_module(v)) |
|
else: |
|
self.__dict__[k] = copy.copy(v) |
|
|
|
RewrittenModule.forward = AST_Rewriter().rewrite( |
|
cast(FunctionType, m.forward) |
|
) |
|
return RewrittenModule(m) |
|
|
|
return rewrite_module(fn) |
|
else: |
|
|
|
return AST_Rewriter().rewrite(cast(FunctionType, fn)) |
|
|