|
|
|
import ast |
|
import functools |
|
import inspect |
|
from textwrap import dedent |
|
from typing import Any, NamedTuple, Optional |
|
|
|
from torch._C import ErrorReport |
|
from torch._C._jit_tree_views import SourceRangeFactory |
|
|
|
|
|
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> tuple[list[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.

    Args:
        obj: the object whose source should be looked up.
        error_msg: optional extra text appended (on its own line) to the
            error message when the source cannot be retrieved.

    Returns: (sourcelines, file_lineno, filename)

    Raises:
        OSError: if ``inspect`` raises ``OSError`` while looking up the
            source; the original exception is chained as the cause.
    """
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        # Build the message as a list so the optional caller-supplied text
        # lands on its own line after the standard explanation.
        parts = [
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        ]
        if error_msg:
            parts.append(error_msg)
        raise OSError("\n".join(parts)) from e

    return sourcelines, file_lineno, filename
|
|
|
|
|
def normalize_source_lines(sourcelines: list[str]) -> list[str]:
    """
    Re-indent the lines around a function definition so they parse cleanly.

    This helper function accepts a list of source lines. It finds the
    indentation level of the function definition (`def`), then it indents
    all lines in the function body to a point at or greater than that
    level. This allows for comments and continued string literals that
    are at a lower indentation than the rest of the code.

    Args:
        sourcelines: function source code, separated into lines by
                        the newline character
    Returns:
        A list of source lines that have been correctly aligned. If no
        line starting with the `def` keyword is found, the input list is
        returned unchanged.
    """
    # Find the first line whose first token is the `def` keyword.
    idx = None
    for i, line in enumerate(sourcelines):
        stripped = line.lstrip()
        # `def` must be a whole keyword: a bare prefix test would also match
        # identifiers such as `default=...` inside a multi-line decorator
        # and anchor the re-indentation on the wrong line.
        following = stripped[3:4]
        if stripped.startswith("def") and not (
            following.isalnum() or following == "_"
        ):
            idx = i
            break

    # This will happen when the function is a lambda; there is nothing to
    # align against, so hand the lines back untouched.
    if idx is None:
        return sourcelines

    # Leading whitespace of the `def` line is the reference indentation.
    fn_def = sourcelines[idx]
    whitespace = fn_def[: fn_def.index("def")]

    def realign(text: str) -> str:
        # Strip the reference indentation if present, then put it back, so
        # every line ends up indented at least as far as the `def`.
        return whitespace + text.removeprefix(whitespace)

    aligned_prefix = [realign(s) for s in sourcelines[:idx]]
    aligned_suffix = [realign(s) for s in sourcelines[idx + 1 :]]

    # The `def` line itself is kept verbatim between the two halves.
    return aligned_prefix + [fn_def] + aligned_suffix
|
|
|
|
|
|
|
|
|
class SourceContext(SourceRangeFactory):
    """
    A ``SourceRangeFactory`` that also remembers per-function metadata.

    ``source``, ``filename``, ``file_lineno`` and ``leading_whitespace_len``
    are forwarded unchanged to ``SourceRangeFactory``. In addition the
    instance exposes ``uses_true_division``, ``filename`` and ``funcname``
    as plain attributes.
    """

    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        # Metadata kept on the Python side, alongside whatever the base
        # class stores from the arguments above.
        self.funcname = funcname
        self.filename = filename
        self.uses_true_division = uses_true_division
|
|
|
|
|
@functools.cache
def make_source_context(*args):
    """
    Build a ``SourceContext`` from positional arguments, memoized.

    The arguments are passed straight through to ``SourceContext``. Because
    of ``functools.cache``, repeated calls with equal (hashable) arguments
    return the same instance instead of constructing a new one.
    """
    context = SourceContext(*args)
    return context
|
|
|
|
|
def fake_range():
    """
    Return a raw range made from an empty, file-less ``SourceContext``.

    The context is built with an empty source string, no filename, and zero
    for both the line number and leading whitespace; the range is then
    requested with the arguments ``(0, 1)``.
    """
    empty_context = SourceContext("", None, 0, 0)
    return empty_context.make_raw_range(0, 1)
|
|
|
|
|
class ParsedDef(NamedTuple):
    """Bundle of the values produced when parsing a function definition."""

    # Parsed module node for the function's source.
    ast: ast.Module
    # Source context used to create source ranges for the function.
    ctx: SourceContext
    # Source text of the function as a single string.
    source: str
    # File the source came from, or None when it is not known.
    filename: Optional[str]
    # Line number reported for the start of the source in that file.
    file_lineno: int
|
|
|
|
|
def parse_def(fn):
    """
    Fetch, normalize and parse the source of ``fn`` into a ``ParsedDef``.

    Args:
        fn: the function object to parse; its ``__name__`` is recorded in
            the resulting source context.

    Returns:
        A ``ParsedDef`` holding the parsed AST, the source context, the
        normalized (not dedented) source text, the filename and the line
        number.

    Raises:
        OSError: propagated from ``get_source_lines_and_file`` when the
            source cannot be retrieved.
        RuntimeError: if the source does not parse to exactly one
            top-level ``def``.
    """
    raw_lines, file_lineno, filename = get_source_lines_and_file(
        fn, ErrorReport.call_stack()
    )
    source = "".join(normalize_source_lines(raw_lines))
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)

    body = py_ast.body
    is_single_function = len(body) == 1 and isinstance(body[0], ast.FunctionDef)
    if not is_single_function:
        raise RuntimeError(
            f"Expected a single top-level function: {filename}:{file_lineno}"
        )

    # Number of characters dedent() removed from the first line, i.e. how
    # far the original source was indented.
    first_line_original = source.partition("\n")[0]
    first_line_dedented = dedent_src.partition("\n")[0]
    leading_whitespace_len = len(first_line_original) - len(first_line_dedented)

    # The literal True fills SourceContext's uses_true_division slot.
    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
    )
    return ParsedDef(py_ast, ctx, source, filename, file_lineno)
|
|