|
|
|
import ast |
|
import builtins |
|
import dis |
|
import enum |
|
import inspect |
|
import re |
|
import typing |
|
import warnings |
|
from textwrap import dedent |
|
|
|
import torch |
|
from torch._C import ( |
|
_GeneratorType, |
|
AnyType, |
|
AwaitType, |
|
BoolType, |
|
ComplexType, |
|
DeviceObjType, |
|
DictType, |
|
EnumType, |
|
FloatType, |
|
FutureType, |
|
InterfaceType, |
|
IntType, |
|
ListType, |
|
NoneType, |
|
NumberType, |
|
OptionalType, |
|
StreamObjType, |
|
StringType, |
|
TensorType, |
|
TupleType, |
|
UnionType, |
|
) |
|
from torch._jit_internal import ( |
|
_Await, |
|
_qualified_name, |
|
Any, |
|
BroadcastingList1, |
|
BroadcastingList2, |
|
BroadcastingList3, |
|
Dict, |
|
Future, |
|
is_await, |
|
is_dict, |
|
is_future, |
|
is_ignored_fn, |
|
is_list, |
|
is_optional, |
|
is_tuple, |
|
is_union, |
|
List, |
|
Optional, |
|
Tuple, |
|
Union, |
|
) |
|
from torch._sources import get_source_lines_and_file |
|
|
|
from ._state import _get_script_class |
|
|
|
|
|
if torch.distributed.rpc.is_available(): |
|
from torch._C import RRefType |
|
from torch._jit_internal import is_rref, RRef |
|
|
|
from torch._ops import OpOverloadPacket |
|
|
|
|
|
class Module:
    """Minimal stand-in for a Python module when evaluating ``# type:`` comments.

    Attribute access is served from the ``members`` mapping. Unknown
    attributes raise ``RuntimeError`` (not ``AttributeError``).
    """

    def __init__(self, name, members):
        # ``name`` is only used to build the error message in ``__getattr__``.
        self.name, self.members = name, members

    def __getattr__(self, name):
        members = self.members
        try:
            member = members[name]
        except KeyError:
            message = f"Module {self.name} has no member called {name}"
            raise RuntimeError(message) from None
        return member
|
|
|
|
|
class EvalEnv:
    """Mapping used as the ``locals`` when evaluating ``# type:`` comments.

    Name lookup order:
      1. the fixed table ``env`` (typing helpers and ``torch``/``Tensor``);
      2. the resolution callback ``rcb``, when one was given;
      3. otherwise Python builtins, with unknown names resolving to ``None``.
    """

    # Class-level table shared by every instance.
    # NOTE(review): ``__init__`` inserts ``RRef`` into this shared dict when
    # RPC is available; the write is idempotent, but it is a mutation of
    # class state.
    env = {
        "torch": Module("torch", {"Tensor": torch.Tensor}),
        "Tensor": torch.Tensor,
        "typing": Module("typing", {"Tuple": Tuple}),
        "Tuple": Tuple,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Future": Future,
        "Await": _Await,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        if torch.distributed.rpc.is_available():
            self.env["RRef"] = RRef

    def __getitem__(self, name):
        table = self.env
        if name in table:
            return table[name]
        resolver = self.rcb
        if resolver is None:
            return getattr(builtins, name, None)
        return resolver(name)
|
|
|
|
|
def get_signature(fn, rcb, loc, is_method):
    """Return ``(param_types, return_type)`` for ``fn``, or ``None``.

    Real (PEP 3107) annotations are tried first; for an ``OpOverloadPacket``
    the wrapped ``op`` is inspected. When ``is_method`` is true, the first
    annotation-derived parameter (``self``) is dropped. If ``fn`` has no
    real annotations, the function's source is searched for a ``# type:``
    comment, which is parsed with ``parse_type_line``.

    A ``TypeError`` while fetching the source (e.g. for objects without
    Python source) yields ``None``. Other errors propagate.
    """
    target = fn.op if isinstance(fn, OpOverloadPacket) else fn
    signature = try_real_annotations(target, loc)

    if signature is not None:
        if is_method:
            # The annotation list includes ``self``; callers of a method do
            # not pass it explicitly, so strip it from the parameter types.
            params, ret = signature
            signature = (params[1:], ret)
        return signature

    # Fall back to the PEP 484 type-comment syntax.
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(source)
    except TypeError:
        return None

    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)
|
|
|
|
|
def is_function_or_method(the_callable):
    """Return True if ``the_callable`` is a Python function or a bound method.

    Builtins, classes and other callable objects are not considered
    functions by this check.
    """
    predicates = (inspect.isfunction, inspect.ismethod)
    return any(predicate(the_callable) for predicate in predicates)
|
|
|
|
|
def is_vararg(the_callable):
    """Return True if ``the_callable`` accepts ``*args``.

    A callable object that is not itself a function/method is inspected
    through its ``__call__`` attribute. Anything that still is not a Python
    function or method (e.g. a builtin) is reported as not variadic.
    """
    target = the_callable
    if not is_function_or_method(target) and callable(target):
        target = target.__call__

    if not is_function_or_method(target):
        return False
    return inspect.getfullargspec(target).varargs is not None
|
|
|
|
|
def get_param_names(fn, n_args):
    """Return the parameter names of ``fn``.

    For an ``OpOverloadPacket`` the wrapped ``op`` is inspected. For a
    callable object whose ``__call__`` is a Python function/method, that
    ``__call__`` is inspected. Functions flagged by ``is_ignored_fn`` are
    unwrapped (``inspect.unwrap``) before inspection.

    If no Python function can be found (e.g. a builtin), the placeholder
    names ``"0"``, ``"1"``, ..., ``str(n_args - 1)`` are returned instead.
    """
    target = fn.op if isinstance(fn, OpOverloadPacket) else fn

    if (
        not is_function_or_method(target)
        and callable(target)
        and is_function_or_method(target.__call__)
    ):
        target = target.__call__

    if not is_function_or_method(target):
        # Not introspectable: fall back to positional placeholder names.
        return list(map(str, range(n_args)))

    if is_ignored_fn(target):
        target = inspect.unwrap(target)
    return inspect.getfullargspec(target).args
|
|
|
|
|
def check_fn(fn, loc):
    """Check that ``fn``'s source is exactly one top-level ``def``.

    If the source cannot be retrieved (``OSError``/``TypeError``), the check
    is skipped and ``None`` is returned.

    Raises:
        torch.jit.frontend.FrontendError: if the source is a single class
            ("Cannot instantiate class ...") or is anything other than a
            single function definition.
    """
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
    except (OSError, TypeError):
        return

    body = ast.parse(source).body
    only_stmt = body[0] if len(body) == 1 else None

    if isinstance(only_stmt, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc,
            f"Cannot instantiate class '{only_stmt.name}' in a script function",
        )
    if not isinstance(only_stmt, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function"
        )
|
|
|
|
|
def _eval_no_call(stmt, glob, loc):
    """Evaluate statement as long as it does not contain any method/function calls.

    ``stmt`` is compiled as a single expression. If any of its bytecode
    instructions has ``CALL`` in its opcode name, a ``RuntimeError`` is
    raised instead of evaluating it. Otherwise it is evaluated with
    ``glob`` as globals and ``loc`` as locals.
    """
    code = compile(stmt, "", mode="eval")
    # NOTE(review): rejecting call opcodes keeps type comments to plain
    # name/attribute/subscript expressions; it is a guard against accidental
    # side effects, not a sandbox for untrusted text.
    has_call = any("CALL" in instr.opname for instr in dis.get_instructions(code))
    if has_call:
        raise RuntimeError(
            f"Type annotation should not contain calls, but '{stmt}' does"
        )
    return eval(code, glob, loc)
|
|
|
|
|
def parse_type_line(type_line, rcb, loc):
    """Parse a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Returns:
        ``(arg_types, return_type)`` where ``arg_types`` is a list of
        TorchScript types, one per parameter.

    Raises:
        RuntimeError: if either side fails to evaluate with a ``NameError``
            or ``SyntaxError``.
    """
    arg_src, ret_src = split_type_line(type_line)
    env = EvalEnv(rcb)

    def evaluate(expr, failure_message):
        try:
            return _eval_no_call(expr, {}, env)
        except (NameError, SyntaxError) as e:
            raise RuntimeError(failure_message) from e

    arg_ann = evaluate(
        arg_src, "Failed to parse the argument list of a type annotation"
    )
    # "(T)" evaluates to T itself rather than a 1-tuple; normalise it.
    if not isinstance(arg_ann, tuple):
        arg_ann = (arg_ann,)

    ret_ann = evaluate(
        ret_src, "Failed to parse the return type of a type annotation"
    )

    param_types = [ann_to_type(a, loc) for a in arg_ann]
    return param_types, ann_to_type(ret_ann, loc)
|
|
|
|
|
def get_type_line(source):
    """Try to find the line containing a comment with the type annotation.

    Two PEP 484 comment styles are supported:

    * a single ``# type: (A, B) -> R`` comment, returned stripped;
    * per-parameter ``# type: A`` comments followed by a
      ``# type: (...) -> R`` line, merged into ``# type: (A, B) -> R``.

    ``# type: ignore`` / ``# type: ignore[rule-code]`` comments that end a
    line are skipped.

    Args:
        source: Dedented source text of a function.

    Returns:
        The type-comment string (always starting with ``# type:`` in the
        merged case), or ``None`` if the source has no type comment.

    Raises:
        RuntimeError: if a comment looks like a misspelled ``# type:``
            prefix, or if per-parameter comments are present without a
            ``# type: (...) -> ...`` return line.
    """
    type_comment = "# type:"
    return_marker = "# type: (...) -> "

    lines = list(enumerate(source.split("\n")))
    type_lines = [line for line in lines if type_comment in line[1]]
    # `type: ignore` comments may be needed in JIT'ed functions for mypy.
    # Accepted forms, which must end the line: `type: ignore` and
    # `type: ignore[rule-code]`. The backslash before the space is kept from
    # the original pattern.
    type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
    type_lines = [line for line in type_lines if not type_pattern.search(line[1])]

    if not type_lines:
        # Catch common typo patterns like extra spaces or a missing space.
        wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        wrong_type_lines = [
            line for line in lines if wrong_type_pattern.search(line[1])
        ]
        if wrong_type_lines:
            raise RuntimeError(
                "The annotation prefix in line "
                + str(wrong_type_lines[0][0])
                + " is probably invalid.\nIt must be '# type:'"
                + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
                + "\nfor examples"
            )
        return None
    if len(type_lines) == 1:
        # Only one type line: it is the whole signature.
        return type_lines[0][1].strip()

    # Split-up argument types (see PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code).
    # Every entry in `type_lines` contains `type_comment`, so each line before
    # the return marker is a parameter annotation.
    return_line = None
    parameter_type_lines = []
    for line_num, line in type_lines:
        if return_marker in line:
            return_line = (line_num, line)
            break
        parameter_type_lines.append(line)
    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n"
            + "\n".join([line[1] for line in type_lines])
            + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
        )

    def get_parameter_type(line):
        # Text after the `# type:` marker is the parameter's annotation.
        item_type = line[line.find(type_comment) + len(type_comment) :]
        return item_type.strip()

    parameter_types = ", ".join(map(get_parameter_type, parameter_type_lines))

    # Keep only the comment itself, dropping indentation or code before it,
    # so the result has the same "# type: ..." shape as the single-line case
    # (split_type_line slices right after a leading "# type:").
    # Replace only the "(...)" placeholder, so an ellipsis inside the return
    # type (e.g. Tuple[int, ...]) is left intact.
    line = return_line[1]
    return_comment = line[line.find(return_marker) :].strip()
    return return_comment.replace("(...)", f"({parameter_types})", 1)
|
|
|
|
|
def split_type_line(type_line):
    """Split the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    Any text before the ``# type:`` marker (indentation, code) is ignored.
    If the marker is absent, the line is assumed to start with it.

    Raises:
        RuntimeError: if the line contains no ``->``.
    """
    type_comment = "# type:"
    try:
        arrow_pos = type_line.index("->")
    except ValueError:
        raise RuntimeError(
            "Syntax error in type annotation (couldn't find `->`)"
        ) from None
    # Look for the marker only before the arrow. When it is not found
    # (find() == -1), fall back to the historical fixed offset.
    marker_pos = type_line.find(type_comment, 0, arrow_pos)
    start_offset = max(marker_pos, 0) + len(type_comment)
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip()
|
|
|
|
|
def try_real_annotations(fn, loc):
    """Try to use the Py3.5+ annotation syntax to get the type.

    Returns:
        ``(arg_types, return_type)`` converted with ``ann_to_type`` (an
        unannotated slot maps to ``inspect.Signature.empty``), or ``None``
        if ``inspect.signature`` raises ``ValueError`` or if neither the
        parameters nor the return value carry any annotation.
    """
    try:
        sig = inspect.signature(fn)
    except ValueError:
        # e.g. some builtins expose no introspectable signature.
        return None

    params = list(sig.parameters.values())
    unannotated = sig.return_annotation is sig.empty and all(
        p.annotation is sig.empty for p in params
    )
    if unannotated:
        return None

    arg_types = [ann_to_type(p.annotation, loc) for p in params]
    return arg_types, ann_to_type(sig.return_annotation, loc)
|
|
|
|
|
|
|
|
|
def get_enum_value_type(e: type[enum.Enum], loc):
    """Return the TorchScript type of the member values of enum class ``e``.

    The Python types of all member values are converted with
    ``try_ann_to_type`` and unified with ``torch._C.unify_type_list``.
    If they cannot be unified (e.g. mixed value types), ``AnyType`` is
    returned. Presumably the compiler rejects such enums later; confirm.

    Raises:
        ValueError: if ``e`` defines no members.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Name the enum class itself; ``e.__class__`` would be the Enum
        # metaclass and says nothing about which enum is empty.
        raise ValueError(f"No enum values defined for: '{e.__qualname__}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
|
|
|
|
|
def is_tensor(ann):
    """Return True if the class ``ann`` should be treated as ``torch.Tensor``.

    Subclasses of ``torch.Tensor`` qualify directly. The legacy
    dtype-specific tensor classes (``torch.FloatTensor`` and friends) also
    qualify, but a warning is emitted because their dtype is not enforced.
    ``ann`` must be a class, since ``issubclass`` is applied to it.
    """
    if issubclass(ann, torch.Tensor):
        return True

    legacy_tensor_types = (
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
        torch.BoolTensor,
    )
    if not issubclass(ann, legacy_tensor_types):
        return False

    warnings.warn(
        "TorchScript will treat type annotations of Tensor "
        "dtype-specific subtypes as if they are normal Tensors. "
        "dtype constraints are not enforced in compilation either."
    )
    return True
|
|
|
|
|
def _fake_rcb(inp):
    """Resolution callback that resolves nothing: every lookup yields ``None``."""
    return None
|
|
|
|
|
def try_ann_to_type(ann, loc, rcb=None):
    """Convert a Python type annotation into a TorchScript type.

    Args:
        ann: The annotation object (e.g. ``int``, ``List[Tensor]``,
            ``Optional[str]``, a class, an enum, ...).
        loc: Source location. Used in error messages and forwarded to class
            compilation / ``torch._C._resolve_type_from_object``.
        rcb: Optional resolution callback forwarded to
            ``torch._C._resolve_type_from_object``. A callback that resolves
            nothing is used when omitted. It is not passed to the recursive
            calls for nested element types.

    Returns:
        The corresponding ``torch._C`` type object. For annotations not
        handled explicitly, it returns the result of
        ``torch._C._resolve_type_from_object``; ``ann_to_type`` treats a
        ``None`` result as "unknown annotation".
    """
    ann_args = typing.get_args(ann)

    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`.
        if len(ann_args) == 1 and ann_args[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann_args])
    if is_list(ann):
        elem_type = try_ann_to_type(ann_args[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann_args[0], loc)
        value = try_ann_to_type(ann_args[1], loc)
        # Raise an error if the key or value type could not be resolved.
        if key is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}"
            )
        if value is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}"
            )
        return DictType(key, value)
    if is_optional(ann):
        # Two-member union with None: pick the non-None member. Compare by
        # identity rather than `issubclass`, because the second member may
        # be a typing generic such as `List[int]`, for which `issubclass`
        # raises TypeError.
        if ann_args[1] is type(None) or ann_args[1] is None:
            contained = ann_args[0]
        else:
            contained = ann_args[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        assert valid_type, msg.format(repr(ann), repr(contained), repr(loc))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is a hack to recognize NumberType.
        if set(ann_args) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # Both a literal `None` member and an unresolvable member would make
        # try_ann_to_type return a falsy value, so handle `None` explicitly
        # and skip the generic path for it (otherwise it would be added twice).
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        for a in ann_args:
            if a is None:
                inner.append(NoneType.get())
                continue
            maybe_type = try_ann_to_type(a, loc)
            # Report the member that failed, not the (falsy) result.
            assert maybe_type, msg.format(repr(ann), repr(a), repr(loc))
            inner.append(maybe_type)
        return UnionType(inner)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann_args[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann_args[0], loc))
    if is_await(ann):
        elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int or ann is torch.SymInt:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Generator:
        return _GeneratorType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        # dtype annotations are represented as int here.
        return IntType.get()
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Defer anything else (possibly e.g. NamedTuples — TODO confirm) to the
    # C++ resolver.
    if rcb is None:
        rcb = _fake_rcb
    return torch._C._resolve_type_from_object(ann, loc, rcb)
|
|
|
|
|
def ann_to_type(ann, loc, rcb=None):
    """Like ``try_ann_to_type`` but raise instead of returning ``None``.

    Raises:
        ValueError: if the annotation cannot be resolved to a TorchScript
            type; the message includes ``loc.highlight()``.
    """
    resolved = try_ann_to_type(ann, loc, rcb)
    if resolved is None:
        raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
    return resolved
|
|
|
|
|
# Public API of this module. The first group re-exports typing helpers and
# TorchScript type classes; the second group is the annotation-parsing
# helpers defined in this file.
__all__ = [
    "Any",
    "List",
    "BroadcastingList1",
    "BroadcastingList2",
    "BroadcastingList3",
    "Tuple",
    "is_tuple",
    "is_list",
    "Dict",
    "is_dict",
    "is_optional",
    "is_union",
    "TensorType",
    "TupleType",
    "FloatType",
    "ComplexType",
    "IntType",
    "ListType",
    "StringType",
    "DictType",
    "AnyType",
    "Module",
    # NOTE(review): the parsing helpers below are also pulled in by a
    # wildcard import of this module; confirm that is intended.
    "get_signature",
    "check_fn",
    "get_param_names",
    "parse_type_line",
    "get_type_line",
    "split_type_line",
    "try_real_annotations",
    "try_ann_to_type",
    "ann_to_type",
]
|
|