|
from __future__ import annotations |
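
# Data models for the Python binding codegen. PythonSignature and its helper
# classes describe how a native function's schema maps onto PythonArgParser
# declarations and onto the C++ dispatch lambdas that consume the parsed
# arguments.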
|
|
|
from dataclasses import dataclass |
|
from typing import TYPE_CHECKING |
|
|
|
from torchgen.api import cpp |
|
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup |
|
from torchgen.gen import pythonify_default |
|
from torchgen.model import ( |
|
Argument, |
|
BaseTy, |
|
BaseType, |
|
FunctionSchema, |
|
ListType, |
|
NativeFunction, |
|
OptionalType, |
|
Return, |
|
Type, |
|
Variant, |
|
) |
|
|
|
|
|
if TYPE_CHECKING: |
|
    from collections.abc import Sequence


@dataclass(frozen=True) |
|
class PythonReturns: |
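    """The schema-declared returns of a native function, as seen from Python."""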
|
returns: tuple[Return, ...] |
|
|
|
|
|
@dataclass(frozen=True) |
|
class PythonArgument: |
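    """A single argument in a Python binding signature.

    `default` holds the pythonified default expression shown in the schema
    string; `default_init` optionally holds a C++ default-initializer that the
    argument parser applies when the Python value is absent (used for the
    scattered tensor options arguments).
    """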
|
name: str |
|
type: Type |
|
    default: str | None
    default_init: str | None

    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
|
        type_str = (
            argument_type_str(self.type, symint=symint)
            .replace("const ", "")
            .replace(" &", "")
        )

        name = self.name
        # Render "self" as "input" in function (non-method) bindings.
        if name == "self" and type_str in ["Tensor", "Number"] and not method:
            name = "input"

        # Attach the pythonified default, mapping C++ null-ish spellings to None.
        if self.default is not None:
            default = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
            }.get(self.default, self.default)
            return f"{type_str} {name}={default}"
        else:
            return f"{type_str} {name}"
|
|
|
def argument_str_pyi( |
|
self, *, method: bool = False, deprecated: bool = False |
|
) -> str: |
|
        type_str = argument_type_str_pyi(self.type)

        name = self.name
        # Render "self" as "input" in function (non-method) bindings.
        if name == "self" and type_str == "Tensor" and not method and not deprecated:
            name = "input"

        # "from" is a Python keyword.
        if name == "from":
            name += "_"

        # The out argument may be passed as None, so type it as Optional.
        if name == "out" and type_str == "Tensor" and not deprecated:
            type_str = "Optional[" + type_str + "]"

        # Deprecated pyi signatures drop the default on their out argument.
        treat_as_no_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )

        if self.default is not None and not treat_as_no_default:
            if (
                isinstance(self.type, ListType)
                and self.type.elem == BaseType(BaseTy.int)
                and self.default.startswith("{")
                and self.default.endswith("}")
            ):
                # Convert a C++ braced int-list default, e.g. "{0, 1}", into a
                # Python tuple literal "(0, 1)".
                default = (
                    "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")"
                )
            else:
                default = {
                    "nullptr": "None",
                    "::std::nullopt": "None",
                    "std::nullopt": "None",
                    "{}": "None",
                    "c10::MemoryFormat::Contiguous": "contiguous_format",
                    "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
                }.get(self.default, self.default)
            return f"{name}: {type_str} = {default}"
        else:
            return f"{name}: {type_str}"
|
|
|
|
|
@dataclass(frozen=True) |
|
class PythonOutArgument(PythonArgument):
    # The out argument(s) of a signature. A single Tensor out is kept as-is;
    # multiple Tensor outputs are packed into one synthetic "out" argument of
    # fixed-size TensorList type, with the originals preserved in `outputs`.
    outputs: tuple[PythonArgument, ...]
|
|
|
@staticmethod |
|
def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None: |
|
if not outputs: |
|
return None |
|
|
|
size = len(outputs) |
|
if size == 1: |
|
return PythonOutArgument( |
|
name=outputs[0].name, |
|
type=outputs[0].type, |
|
default="None", |
|
default_init=None, |
|
outputs=outputs, |
|
) |
|
elif size > 1: |
|
if any(not a.type.is_tensor_like() for a in outputs): |
|
raise RuntimeError(f"Unsupported output type: {outputs}") |
|
return PythonOutArgument( |
|
name="out", |
|
|
|
type=ListType(BaseType(BaseTy.Tensor), size), |
|
default="None", |
|
default_init=None, |
|
outputs=outputs, |
|
) |
|
        raise AssertionError("Unexpected PythonOutArgument size")
|
|
|
|
|
@dataclass(frozen=True) |
|
class PythonSignature: |
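    """The Python-level signature of an overload, as exposed by the binding."""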
|
|
|
    name: str

    # Positional arguments.
    input_args: tuple[PythonArgument, ...]

    # Keyword-only arguments, excluding the output argument(s) and the
    # scattered tensor options (dtype/layout/device/...).
    input_kwargs: tuple[PythonArgument, ...]

    output_args: PythonOutArgument | None

    # Return types, which are only used by pyi.
    returns: PythonReturns

    # The scattered tensor options arguments (dtype, layout, device,
    # pin_memory, requires_grad) that factory-like functions take in place of
    # a single C++ TensorOptions parameter.
    tensor_options_args: tuple[PythonArgument, ...]

    # Whether this signature is for a method (with an implicit self) rather
    # than a free function.
    method: bool
|
|
|
@property |
|
def deprecated(self) -> bool: |
|
return False |
|
|
|
def arguments( |
|
self, *, skip_outputs: bool = False, skip_tensor_options: bool = False |
|
) -> tuple[PythonArgument | PythonOutArgument, ...]: |
|
result: list[PythonArgument | PythonOutArgument] = [] |
|
result.extend(self.input_args) |
|
result.extend(self.input_kwargs) |
|
if self.output_args is not None and not skip_outputs: |
|
result.append(self.output_args) |
|
if not skip_tensor_options: |
|
result.extend(self.tensor_options_args) |
|
return tuple(result) |
|
|
|
def arguments_count(self) -> int: |
|
return len(self.arguments()) |
|
|
|
def output_idx(self) -> int: |
|
        return len(self.input_args) + len(self.input_kwargs)

    # Render the PythonArgParser schema string, e.g.
    # "add(Tensor input, Tensor other, *, Scalar alpha=1, Tensor out=None)".
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
|
args = self.arguments(skip_outputs=skip_outputs) |
|
schema_formals: list[str] = [ |
|
a.argument_str(method=self.method, symint=symint) for a in args |
|
] |
|
positional_argc = len(self.input_args) |
|
if len(schema_formals) > positional_argc: |
|
schema_formals.insert(positional_argc, "*") |
|
|
|
return f"{self.name}({', '.join(schema_formals)})" |
|
|
|
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: |
|
args = self.arguments(skip_outputs=skip_outputs) |
|
schema_formals: list[str] = [ |
|
a.argument_str_pyi(method=self.method) for a in args |
|
] |
|
positional_argc = len(self.input_args) |
|
if len(schema_formals) > positional_argc: |
|
schema_formals.insert(positional_argc, "*") |
|
|
|
|
|
returns_str = returns_str_pyi(self) |
|
|
|
if self.method: |
|
schema_formals.insert(0, "self") |
|
return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..." |
|
|
|
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: |
|
|
|
args = self.arguments(skip_outputs=skip_outputs) |
|
schema_formals: list[str] = [ |
|
a.argument_str_pyi(method=self.method) for a in args |
|
] |
|
|
|
num_args = self.arguments_count() |
|
num_positionalargs = len(self.input_args) |
|
|
|
have_vararg_version = False |
|
if num_args > 0: |
|
vararg_type = args[0].type |
|
if ( |
|
isinstance(vararg_type, ListType) |
|
and str(vararg_type.elem) in ["int", "SymInt"] |
|
and num_positionalargs == 1 |
|
): |
|
have_vararg_version = True |
|
|
|
if not have_vararg_version: |
|
return None |
|
|
|
|
|
|
|
        # The vararg variant renders the single list positional argument as
        # "*name: elem_type" instead of "name: Sequence[elem_type]".
        assert isinstance(vararg_type, ListType)
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )
|
|
|
returns_str = returns_str_pyi(self) |
|
|
|
if self.method: |
|
schema_formals.insert(0, "self") |
|
return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..." |
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True) |
|
class PythonSignatureDeprecated(PythonSignature): |
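    """A deprecated Python signature that forwards to an existing native function.

    `deprecated_schema` is the signature as exposed to Python; dispatch still
    targets the original native function.
    """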
|
|
|
    deprecated_schema: FunctionSchema

    # Expressions (in the deprecated schema's order) used to call the original
    # native function.
    deprecated_args_exprs: tuple[str, ...]
|
|
|
@property |
|
def deprecated(self) -> bool: |
|
return True |
|
|
|
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: |
|
return ( |
|
PythonSignature.signature_str( |
|
self, skip_outputs=skip_outputs, symint=symint |
|
) |
|
+ "|deprecated" |
|
) |
|
|
|
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: |
|
args = self.arguments(skip_outputs=skip_outputs) |
|
schema_formals: list[str] = [ |
|
a.argument_str_pyi(method=self.method, deprecated=True) for a in args |
|
] |
|
positional_argc = len(self.input_args) |
|
if len(schema_formals) > positional_argc: |
|
schema_formals.insert(positional_argc, "*") |
|
|
|
returns_str = returns_str_pyi(self) |
|
return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..." |
|
|
|
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: |
|
|
|
        # Codegen does not emit vararg variants for deprecated signatures.
        return None


@dataclass(frozen=True)
|
class PythonSignatureNativeFunctionPair: |
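    """Couples a PythonSignature with the NativeFunction it was derived from."""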
|
signature: PythonSignature |
|
    function: NativeFunction


@dataclass(frozen=True)
|
class PythonSignatureGroup:
    # The signature used for Python argument parsing. The outplace variant is
    # preferred when it exists, since it is the superset with the out argument.
    signature: PythonSignature

    # The regular ATen declaration (e.g. conv2d).
    base: NativeFunction

    # The out variant (e.g. conv2d_out), if any.
    outplace: NativeFunction | None
|
|
|
@classmethod |
|
def from_pairs( |
|
cls, |
|
functional: PythonSignatureNativeFunctionPair, |
|
out: PythonSignatureNativeFunctionPair | None, |
|
) -> PythonSignatureGroup: |
|
if out is None: |
|
return PythonSignatureGroup( |
|
signature=functional.signature, |
|
base=functional.function, |
|
outplace=None, |
|
) |
|
|
|
|
|
|
|
        # Prefer the out variant's signature: it is the superset that can parse
        # inputs for both the base and the outplace overloads.
        signature_kwargs = out.signature.__dict__.copy()

        # Out overloads in C++ don't have TensorOptions arguments, so take
        # these from the functional variant.
        signature_kwargs["tensor_options_args"] = (
            functional.signature.tensor_options_args
        )
|
|
|
return PythonSignatureGroup( |
|
signature=type(out.signature)(**signature_kwargs), |
|
base=functional.function, |
|
outplace=out.function, |
|
        )


@dataclass(frozen=True)
|
class DispatchLambdaArgument: |
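    """A formal parameter of the generated C++ dispatch lambda."""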
|
name: str |
|
type_str: str |
|
    is_out_arg: bool


@dataclass(frozen=True)
|
class PythonArgParserOutputExpr: |
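    """One unpacked output of the PythonArgParser result `_r`."""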
|
|
|
    # The argument name.
    name: str

    # The expression that unpacks the parsed value from `_r`.
    expr: str

    # The zero-based index into the parser outputs.
    index: int

    # The argument this expression unpacks.
    argument: PythonArgument
|
|
|
@property |
|
def is_none_expr(self) -> str: |
|
return f"_r.isNone({self.index})" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True) |
|
class DispatchLambdaArgumentExprs:
    """Inputs to the generated dispatch lambda call.

    `exprs` are the argument expressions passed at the lambda call site;
    `inits` are statements emitted beforehand to initialize any intermediates
    those expressions reference.
    """

    exprs: Sequence[str]
    inits: Sequence[str]


def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    return CppSignatureGroup.from_native_function(f, method=method).signature
|
|
|
|
|
def has_tensor_options(f: NativeFunction) -> bool: |
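    """Whether the function takes a (gathered) TensorOptions argument in C++."""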
|
    return f.func.arguments.tensor_options is not None


def argument_type_str(
|
t: Type, *, simple_type: bool = False, symint: bool = True |
|
) -> str: |
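    """Render a schema type as it appears in PythonArgParser schema strings."""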
|
if isinstance(t, BaseType): |
|
if t.name == BaseTy.int: |
|
return "int64_t" |
|
elif t.name == BaseTy.float: |
|
return "double" |
|
elif t.name == BaseTy.str: |
|
return "c10::string_view" |
|
elif t.name in [ |
|
BaseTy.Tensor, |
|
BaseTy.bool, |
|
BaseTy.QScheme, |
|
BaseTy.Scalar, |
|
BaseTy.ScalarType, |
|
BaseTy.Generator, |
|
BaseTy.Storage, |
|
BaseTy.Layout, |
|
BaseTy.Device, |
|
BaseTy.DeviceIndex, |
|
BaseTy.MemoryFormat, |
|
BaseTy.Dimname, |
|
BaseTy.Stream, |
|
BaseTy.SymInt, |
|
        ]:
            # These Python schema type names line up with their FunctionSchema names.
            return t.name.name

    elif isinstance(t, OptionalType):
|
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) |
|
return f"{elem}?" |
|
elif isinstance(t, ListType): |
|
size = t.size if not simple_type else None |
|
if str(t.elem) == "bool": |
|
assert t.size is not None |
|
return f"::std::array<bool,{t.size}>" |
|
elif str(t.elem) == "int": |
|
return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" |
|
elif str(t.elem) == "SymInt": |
|
if symint: |
|
return ( |
|
f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef" |
|
) |
|
else: |
|
return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" |
|
elif str(t.elem) == "Tensor": |
|
return f"TensorList[{size}]" if size is not None else "TensorList" |
|
elif str(t.elem) == "Scalar": |
|
return f"ScalarList[{size}]" if size is not None else "ScalarList" |
|
elif str(t.elem) == "Tensor?": |
|
if simple_type: |
|
return "c10::List<::std::optional<Tensor>>" |
|
else: |
|
return "const c10::List<::std::optional<Tensor>> &" |
|
elif str(t.elem) == "Dimname": |
|
return f"DimnameList[{size}]" if size is not None else "DimnameList" |
|
elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) |
|
return f"ArrayRef<{elem}>" |
|
|
|
raise RuntimeError(f"unrecognized type {repr(t)}") |
|
|
|
|
|
def argument_type_size(t: Type) -> int | None: |
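    """The declared size of a sized list type, if any (bool lists excluded)."""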
|
l = t.is_list_like() |
|
if l is not None and str(l.elem) != "bool": |
|
return l.size |
|
else: |
|
return None |
|
|
|
|
|
def argument(a: Argument) -> PythonArgument: |
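    """Convert a schema Argument to a PythonArgument, pythonifying its default."""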
|
return PythonArgument( |
|
name=a.name, |
|
type=a.type, |
|
|
|
default=( |
|
str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False))) |
|
if a.default is not None |
|
else None |
|
), |
|
default_init=None, |
|
) |
|
|
|
|
|
|
|
def signature( |
|
f: NativeFunction, *, method: bool = False, pyi: bool = False |
|
) -> PythonSignature: |
|
return signature_from_schema( |
|
f.func, category_override=f.category_override, method=method, pyi=pyi |
|
) |
|
|
|
|
|
def signature_from_schema( |
|
func: FunctionSchema, |
|
*, |
|
category_override: str | None, |
|
method: bool = False, |
|
pyi: bool = False, |
|
) -> PythonSignature: |
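    """Build the PythonSignature for a function schema.

    Arguments are regrouped into positional args, keyword-only args, and out
    args; for factory-like functions the scattered tensor options arguments
    (dtype, layout, device, pin_memory, requires_grad) are synthesized here.
    """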
|
args: list[Argument] = [] |
|
args.extend(func.arguments.pre_self_positional) |
|
|
|
if not method and func.arguments.self_arg is not None: |
|
args.append(func.arguments.self_arg.argument) |
|
args.extend(func.arguments.post_self_positional) |
|
    args.extend(func.arguments.pre_tensor_options_kwarg_only)
    # The gathered tensor options arguments are skipped here; the scattered
    # Python-level dtype/layout/device/... arguments are synthesized below.
    args.extend(func.arguments.post_tensor_options_kwarg_only)
|
args.extend(func.arguments.out) |
|
|
|
input_arg_set = {a.name for a in func.arguments.flat_positional} |
|
kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only} |
|
out_arg_set = {a.name for a in func.arguments.out} |
|
|
|
input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args))) |
|
input_kwargs = tuple( |
|
map(argument, filter(lambda a: a.name in kwarg_only_set, args)) |
|
) |
|
    outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))

    # Decide whether to synthesize the scattered tensor options arguments:
    # they are added for factory functions and for new_*/*_like functions.
    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in func.arguments.flat_non_out
    )
|
if any(a.name == "requires_grad" for a in func.schema_order_arguments()): |
|
raise ValueError( |
|
"argument named requires_grad is reserved, should not explicitly add it in the schema" |
|
) |
|
|
|
|
|
|
|
has_tensor_return = any(r.type.is_tensor_like() for r in func.returns) |
|
|
|
name: str = cpp.name(func) |
|
is_factory_function = category_override == "factory" or ( |
|
has_tensor_return and not has_tensor_input_arg |
|
) |
|
is_like_or_new_function = ( |
|
category_override in ("new", "like") |
|
or name.startswith("new_") |
|
or name.endswith("_like") |
|
) |
|
is_dummy_function = category_override == "dummy" |
|
|
|
tensor_options_args: list[PythonArgument] = [] |
|
if (is_factory_function or is_like_or_new_function) and not is_dummy_function: |
|
|
|
def topt_default_init(name: str) -> str | None: |
|
topt_args = func.arguments.tensor_options |
|
if topt_args is None: |
|
return None |
|
a = getattr(topt_args, name) |
|
if a.default is None or a.default == "None": |
|
return None |
|
return cpp.default_expr(a.default, a.type, symint=False) |
|
|
|
tensor_options_args.append( |
|
PythonArgument( |
|
name="dtype", |
|
type=OptionalType(BaseType(BaseTy.ScalarType)), |
|
default="None", |
|
default_init=( |
|
None if is_like_or_new_function else topt_default_init("dtype") |
|
), |
|
) |
|
) |
|
tensor_options_args.append( |
|
PythonArgument( |
|
name="layout", |
|
type=OptionalType(BaseType(BaseTy.Layout)), |
|
default="None", |
|
default_init=( |
|
None if is_like_or_new_function else topt_default_init("layout") |
|
), |
|
) |
|
) |
|
tensor_options_args.append( |
|
PythonArgument( |
|
name="device", |
|
type=OptionalType(BaseType(BaseTy.Device)), |
|
default="None", |
|
default_init=( |
|
None |
|
if is_like_or_new_function |
|
else ( |
|
topt_default_init("device") |
|
or "torch::tensors::get_default_device()" |
|
) |
|
), |
|
) |
|
) |
|
tensor_options_args.append( |
|
PythonArgument( |
|
name="pin_memory", |
|
type=OptionalType(BaseType(BaseTy.bool)), |
|
default="False", |
|
default_init=None, |
|
) |
|
) |
|
tensor_options_args.append( |
|
PythonArgument( |
|
name="requires_grad", |
|
type=OptionalType(BaseType(BaseTy.bool)), |
|
default="False", |
|
default_init=None, |
|
) |
|
) |
|
|
|
returns = PythonReturns(returns=func.returns) |
|
|
|
return PythonSignature( |
|
name=str(func.name.name), |
|
input_args=input_args, |
|
input_kwargs=input_kwargs, |
|
output_args=PythonOutArgument.from_outputs(outputs), |
|
tensor_options_args=tuple(tensor_options_args), |
|
returns=returns, |
|
method=method, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]: |
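    """Field names for a structseq return type.

    Returns [] for single or fully-unnamed returns; raises if named and
    unnamed fields are mixed.
    """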
|
if len(returns) <= 1 or all(r.name is None for r in returns): |
|
return [] |
|
else: |
|
        if any(r.name is None for r in returns):
            # Mixing named and unnamed fields is not supported: a structseq
            # requires every field to be named.
            raise ValueError("Unnamed field is not supported by codegen")
|
|
|
return [str(r.name) for r in returns] |
|
|
|
|
|
def argument_type_str_pyi(t: Type) -> str: |
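    """Render a schema type as a Python stub (pyi) annotation."""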
|
add_optional = False |
|
if isinstance(t, OptionalType): |
|
t = t.elem |
|
add_optional = True |
|
|
|
if isinstance(t, BaseType): |
|
        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
            ret = "_int"
        elif t.name == BaseTy.SymInt:
            ret = "Union[_int, SymInt]"
|
elif t.name == BaseTy.float: |
|
ret = "_float" |
|
elif t.name == BaseTy.str: |
|
ret = "str" |
|
elif t.name == BaseTy.Scalar: |
|
ret = "Union[Number, _complex]" |
|
elif t.name == BaseTy.ScalarType: |
|
ret = "_dtype" |
|
elif t.name == BaseTy.bool: |
|
ret = "_bool" |
|
elif t.name == BaseTy.QScheme: |
|
ret = "_qscheme" |
|
elif t.name == BaseTy.Layout: |
|
ret = "_layout" |
|
elif t.name == BaseTy.Device: |
|
ret = "Optional[DeviceLikeType]" |
|
elif t.name == BaseTy.MemoryFormat: |
|
ret = "memory_format" |
|
elif t.name == BaseTy.Dimname: |
|
ret = "Union[str, ellipsis, None]" |
|
elif t.name == BaseTy.Storage: |
|
ret = "Union[Storage, UntypedStorage]" |
|
        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
            ret = t.name.name

    elif isinstance(t, ListType):
|
if str(t.elem) == "int": |
|
ret = "Union[_int, _size]" if t.size is not None else "_size" |
|
        elif t.is_tensor_like():
            # Tensor-like lists are additionally wrapped in Optional in the
            # generated stubs.
            add_optional = True
|
ret = ( |
|
"Union[Tensor, tuple[Tensor, ...], list[Tensor]]" |
|
if t.size is not None |
|
else "Union[tuple[Tensor, ...], list[Tensor]]" |
|
) |
|
elif str(t.elem) == "float": |
|
ret = "Sequence[_float]" |
|
elif str(t.elem) == "SymInt" and t.size is not None: |
|
elem = argument_type_str_pyi(t.elem) |
|
ret = f"Union[{elem}, Sequence[{elem}]]" |
|
else: |
|
elem = argument_type_str_pyi(t.elem) |
|
ret = f"Sequence[{elem}]" |
|
|
|
else: |
|
raise RuntimeError(f"unrecognized type {repr(t)}") |
|
|
|
if add_optional: |
|
ret = "Optional[" + ret + "]" |
|
|
|
return ret |
|
|
|
|
|
def return_type_str_pyi(t: Type) -> str: |
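    """Render a return type as a pyi annotation; Device and Dimname map
    differently than they do in argument position."""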
|
|
|
|
|
|
|
if isinstance(t, OptionalType): |
|
inner = return_type_str_pyi(t.elem) |
|
return f"Optional[{inner}]" |
|
|
|
if isinstance(t, BaseType): |
|
if t.name == BaseTy.Device: |
|
return "_device" |
|
elif t.name == BaseTy.Dimname: |
|
return "Optional[str]" |
|
else: |
|
return argument_type_str_pyi(t) |
|
|
|
if isinstance(t, ListType): |
|
inner = return_type_str_pyi(t.elem) |
|
return f"tuple[{inner}, ...]" |
|
|
|
return argument_type_str_pyi(t) |
|
|
|
|
|
def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None: |
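    """For named multi-returns, produce the structseq class name and its pyi
    definition; return None otherwise."""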
|
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] |
|
structseq_name = signature.name |
|
field_names = structseq_fieldnames(signature.returns.returns) |
|
    if field_names:
        # Named returns are exposed as a torch.return_types structseq; model it
        # in the stubs as a tuple subclass with named property accessors.
        seq_type = f"tuple[{', '.join(python_returns)}]"
|
structseq_def_lines = [ |
|
f"class {structseq_name}({seq_type}):", |
|
] |
|
        for name, typ in zip(field_names, python_returns):
            structseq_def_lines.extend(
                [
                    "    @property",
                    f"    def {name}(self) -> {typ}: ...",
                ]
            )
        structseq_def_lines.extend(
            [
                f"    def __new__(cls, sequence: {seq_type}): ...",
                f"    n_fields: _int = {len(field_names)}",
                f"    n_sequence_fields: _int = {len(field_names)}",
                "    n_unnamed_fields: _int = 0",
                "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
                "",
            ]
        )
        structseq_def = "\n".join(structseq_def_lines)

        return structseq_name, structseq_def
    return None
|
|
|
|
|
def returns_str_pyi(signature: PythonSignature) -> str: |
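    """Render the pyi return annotation for the signature."""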
|
field_names = structseq_fieldnames(signature.returns.returns) |
|
if field_names: |
|
return f"torch.return_types.{signature.name}" |
|
|
|
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] |
|
if len(python_returns) > 1: |
|
return "tuple[" + ", ".join(python_returns) + "]" |
|
if len(python_returns) == 1: |
|
return python_returns[0] |
|
return "None" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dispatch_lambda_args( |
|
ps: PythonSignature, f: NativeFunction, symint: bool = True |
|
) -> tuple[DispatchLambdaArgument, ...]: |
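    """Compute the formal parameters of the generated C++ dispatch lambda.

    Deprecated signatures use their deprecated schema; otherwise the native
    function's own schema drives the C++ argument computation.
    """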
|
if isinstance(ps, PythonSignatureDeprecated): |
|
schema = ps.deprecated_schema |
|
else: |
|
schema = f.func |
|
|
|
|
|
cpp_args = cpp.arguments( |
|
arguments=schema.arguments, |
|
faithful=False, |
|
symint=symint, |
|
method=False, |
|
cpp_no_default_args=f.cpp_no_default_args, |
|
) |
|
out_args: set[str] = {a.name for a in schema.arguments.out} |
|
|
|
|
|
def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument: |
|
type_str = cpp_arg.type |
|
is_out_arg = cpp_arg.name in out_args |
|
if ps.method and cpp_arg.name == "self": |
|
|
|
type_str = "const at::Tensor &" |
|
else: |
|
|
|
|
|
|
|
|
|
ensure_temp_safe = len(out_args) <= 1 or not is_out_arg |
|
if ensure_temp_safe: |
|
type_str = { |
|
"at::Tensor &": "at::Tensor", |
|
}.get(type_str, type_str) |
|
return DispatchLambdaArgument( |
|
name=cpp_arg.name, |
|
type_str=type_str, |
|
is_out_arg=is_out_arg, |
|
) |
|
|
|
    return tuple(map(dispatch_lambda_arg, cpp_args))


# Return types the dispatch lambda codegen knows how to handle; anything else
# is rejected by dispatch_lambda_return_str.
SUPPORTED_RETURN_TYPES = {
|
"at::Tensor", |
|
"::std::tuple<at::Tensor,at::Tensor>", |
|
"::std::tuple<at::Tensor,at::Tensor,at::Tensor>", |
|
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>", |
|
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>", |
|
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>", |
|
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>", |
|
"::std::tuple<at::Tensor,at::Tensor,double,int64_t>", |
|
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>", |
|
"::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>", |
|
"::std::tuple<double,int64_t>", |
|
"::std::tuple<at::Tensor,::std::vector<at::Tensor>>", |
|
"::std::vector<at::Tensor>", |
|
|
|
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>", |
|
"at::Scalar", |
|
"bool", |
|
"int64_t", |
|
"void*", |
|
"void", |
|
"at::QScheme", |
|
"double", |
|
"at::IntArrayRef", |
|
"at::ScalarType", |
|
"at::Stream", |
|
} |
|
|
|
|
|
def dispatch_lambda_return_str(f: NativeFunction) -> str:
    # The schema's annotated returns (e.g. "Tensor(a!)") would map to C++
    # reference types, which the lambda cannot safely return (the references
    # would dangle); strip the annotations so cpp.returns_type() yields value
    # types instead.
    returns_without_annotation = tuple(
        Return(r.name, r.type, None) for r in f.func.returns
    )
    return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
|
if return_str not in SUPPORTED_RETURN_TYPES: |
|
raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}") |
|
return return_str |
|
|
|
|
|
def cpp_dispatch_target(f: NativeFunction) -> str: |
|
symint = f.func.has_symint() |
|
name = cpp.name(f.func, symint_overload=symint) |
|
if Variant.method in f.variants: |
|
return f"self.{name}" |
|
if Variant.function in f.variants: |
|
if has_tensor_options(f) or f.func.name.name.base.endswith("_like"): |
|
namespace = "torch" |
|
else: |
|
namespace = "at" |
|
return f"{namespace}::{name}" |
|
raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}") |
|
|
|
|
|
def cpp_dispatch_exprs( |
|
f: NativeFunction, |
|
*, |
|
python_signature: PythonSignature | None = None, |
|
) -> tuple[str, ...]: |
|
cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() |
|
|
|
    exprs: tuple[str, ...] = ()
    if not isinstance(python_signature, PythonSignatureDeprecated):
        # By default the Python arguments line up with the C++ signature.
        exprs = tuple(a.name for a in cpp_args)
    else:
        # Deprecated signatures carry hand-written argument expressions; drop
        # "out" unless the target is actually an out function.
        exprs = tuple(
            filter(
                lambda n: n != "out" or f.func.is_out_fn(),
                python_signature.deprecated_args_exprs,
            )
        )
|
|
|
if Variant.method in f.variants: |
|
exprs = tuple(filter("self".__ne__, exprs)) |
|
|
|
    return exprs


# Python/C++ argument binding: map each schema type to the PythonArgParser
# accessor used to unpack it from the parsed result `_r`.
def arg_parser_unpack_method(
|
t: Type, default: str | None, default_init: str | None, *, symint: bool = True |
|
) -> str: |
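    """Name of the PythonArgParser accessor used to unpack this schema type."""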
|
has_default_init = default_init is not None |
|
if has_default_init and str(t) not in ( |
|
"ScalarType?", |
|
"ScalarType", |
|
"Device", |
|
"Device?", |
|
"Layout", |
|
"Layout?", |
|
"bool", |
|
"bool?", |
|
): |
|
raise RuntimeError(f"type '{t}' does not supported unpacking with default") |
|
|
|
if isinstance(t, BaseType): |
|
if t.name in [ |
|
BaseTy.Tensor, |
|
BaseTy.Stream, |
|
BaseTy.Storage, |
|
BaseTy.Scalar, |
|
BaseTy.Dimname, |
|
        ]:
            # These unpack method names line up with the schema names, lowercased.
            return t.name.name.lower()
|
elif t.name == BaseTy.ScalarType: |
|
return "scalartypeWithDefault" if has_default_init else "scalartype" |
|
elif t.name == BaseTy.Device: |
|
return "deviceWithDefault" if has_default_init else "device" |
|
elif t.name == BaseTy.DeviceIndex: |
|
return "toInt64" |
|
elif t.name == BaseTy.int: |
|
return "toInt64" |
|
elif t.name == BaseTy.SymInt: |
|
return "toSymInt" if symint else "toInt64" |
|
elif t.name == BaseTy.bool: |
|
return "toBoolWithDefault" if has_default_init else "toBool" |
|
elif t.name == BaseTy.float: |
|
return "toDouble" |
|
elif t.name == BaseTy.str: |
|
return "stringView" |
|
elif t.name == BaseTy.Layout: |
|
return "layoutWithDefault" if has_default_init else "layout" |
|
elif t.name == BaseTy.MemoryFormat: |
|
return "memoryformat" |
|
|
|
elif isinstance(t, OptionalType): |
|
if str(t.elem) == "Tensor": |
|
return "optionalTensor" |
|
elif str(t.elem) == "Generator": |
|
return "generator" |
|
elif str(t.elem) == "Dimname[]": |
|
return "toDimnameListOptional" |
|
        elif not has_default_init and default in (
            None,
            "None",
            "::std::nullopt",
            "std::nullopt",
        ):
            # With a None default, append "Optional" so the parser returns a
            # std::optional of the element type.
            return (
                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
            )
        else:
            # Otherwise the default fills in a missing value, so the element
            # type's plain unpack method applies.
            return arg_parser_unpack_method(
                t.elem, default, default_init, symint=symint
            )
|
|
|
elif isinstance(t, ListType): |
|
if str(t.elem) == "Tensor": |
|
|
|
return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist" |
|
elif str(t.elem) == "Tensor?": |
|
return "list_of_optional_tensors" |
|
elif str(t.elem) == "Dimname": |
|
|
|
return "dimnamelist" |
|
elif str(t.elem) == "int": |
|
|
|
return "intlist" |
|
elif str(t.elem) == "float": |
|
return "doublelist" |
|
elif str(t.elem) == "SymInt": |
|
|
|
return "symintlist" if symint else "intlist" |
|
elif str(t.elem) == "Scalar": |
|
return "scalarlist" |
|
raise RuntimeError(f"type '{t}' is not supported by PythonArgParser") |
|
|
|
|
|
|
|
|
|
def arg_parser_output_expr( |
|
arg_index: int, a: PythonArgument, *, symint: bool = True |
|
) -> PythonArgParserOutputExpr: |
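    """Build the `_r` accessor expression for argument `a` at `arg_index`."""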
|
has_default = a.default_init is not None |
|
unpack_method = arg_parser_unpack_method( |
|
t=a.type, default=a.default, default_init=a.default_init, symint=symint |
|
) |
|
default = f", {a.default_init}" if has_default else "" |
|
expr = f"_r.{unpack_method}({arg_index}{default})" |
|
|
|
return PythonArgParserOutputExpr( |
|
name=a.name, |
|
expr=expr, |
|
index=arg_index, |
|
argument=a, |
|
) |
|
|
|
|
|
|
|
def arg_parser_output_exprs( |
|
ps: PythonSignature, f: NativeFunction, *, symint: bool = True |
|
) -> dict[str, PythonArgParserOutputExpr]: |
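    """Map each argument name of the signature to its parser output expression."""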
|
return { |
|
e.name: e |
|
for i, a in enumerate(ps.arguments()) |
|
for e in (arg_parser_output_expr(i, a, symint=symint),) |
|
} |
|
|
|
|
|
|
|
TENSOR_OPTIONS_FIELDS = { |
|
"dtype": "ScalarType?", |
|
"device": "Device?", |
|
"layout": "Layout?", |
|
"pin_memory": "bool?", |
|
"requires_grad": "bool?", |
|
} |
|
|
|
|
|
|
|
def dispatch_lambda_exprs( |
|
ps: PythonSignature, f: NativeFunction, *, symint: bool = True |
|
) -> DispatchLambdaArgumentExprs: |
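    """Bind parsed PythonArgParser outputs to the dispatch lambda's arguments.

    Returns the argument expressions for the lambda call plus any init
    statements (e.g. unpacking scattered out tensors or assembling
    TensorOptions) that must run first.
    """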
|
|
|
|
|
|
|
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) |
|
lambda_args = dispatch_lambda_args(ps, f, symint=symint) |
|
inits: list[str] = [] |
|
lambda_args_exprs: dict[str, str] = {} |
|
|
|
has_toptions = has_tensor_options(f) |
|
|
|
|
|
for a in ps.arguments(skip_tensor_options=True): |
|
name = a.name |
|
arg_parser_expr = arg_parser_outputs[a.name].expr |
|
|
|
if has_toptions and name == "self": |
|
|
|
inits.extend( |
|
[ |
|
f"auto self = {arg_parser_expr};", |
|
] |
|
) |
|
lambda_args_exprs[name] = name |
|
elif ( |
|
isinstance(a, PythonOutArgument) |
|
and len(a.outputs) > 1 |
|
and f.func.is_out_fn() |
|
): |
|
inits.extend( |
|
[ |
|
f"auto out = {arg_parser_expr};", |
|
] |
|
) |
|
for i, out_arg in enumerate(a.outputs): |
|
lambda_args_exprs[out_arg.name] = f"out[{i}]" |
|
elif str(a.type) == "Dimname[]?": |
|
|
|
|
|
|
|
|
|
|
|
inits.extend( |
|
[ |
|
f"auto __{name} = {arg_parser_expr};", |
|
f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", |
|
] |
|
) |
|
lambda_args_exprs[name] = name |
|
        else:
            # Default case: use the unpacked parser output directly.
            lambda_args_exprs[name] = arg_parser_expr
|
|
|
|
|
    # A method's "self" is passed directly to the binding, not parsed by _r.
    if ps.method:
        lambda_args_exprs["self"] = "self"
|
|
|
|
|
tensor_options_args_names = [a.name for a in ps.tensor_options_args] |
|
if has_toptions: |
|
if f.func.is_out_fn(): |
|
raise RuntimeError(f"{f.func}: tensor options with output arg") |
|
for a in ps.tensor_options_args: |
|
if a.name not in TENSOR_OPTIONS_FIELDS: |
|
raise RuntimeError( |
|
f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments" |
|
) |
|
if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name): |
|
raise RuntimeError( |
|
f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'" |
|
) |
|
if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS): |
|
raise RuntimeError( |
|
f"{f.func}: incomplete tensor options args: {tensor_options_args_names}" |
|
) |
|
|
|
inits.append( |
|
f"""\ |
|
const auto options = TensorOptions() |
|
.dtype({arg_parser_outputs["dtype"].expr}) |
|
.device({arg_parser_outputs["device"].expr}) |
|
.layout({arg_parser_outputs["layout"].expr}) |
|
.requires_grad({arg_parser_outputs["requires_grad"].expr}) |
|
.pinned_memory({arg_parser_outputs["pin_memory"].expr}); |
|
torch::utils::maybe_initialize_device(options); |
|
""" |
|
) |
|
lambda_args_exprs["options"] = "options" |
|
|
|
|
|
|
|
    # An out variant may still carry scattered dtype/layout/device arguments;
    # they are only used to check consistency with the given out tensor.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # dtype in tensor_options_args only makes sense with an out arg.
            if not f.func.is_out_fn():
|
raise RuntimeError( |
|
f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}" |
|
) |
|
if not all(a in tensor_options_args_names for a in ("layout", "device")): |
|
raise RuntimeError( |
|
f"{f.func}: incomplete tensor options for output check" |
|
) |
|
|
|
inits.append( |
|
f"""\ |
|
check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr}, |
|
{arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr}, |
|
{arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr}); |
|
""" |
|
) |
|
|
|
if "requires_grad" not in tensor_options_args_names: |
|
raise RuntimeError( |
|
            f'{f.func}: expected "requires_grad" in tensor_options_args, but found [{tensor_options_args_names}]'
|
) |
|
|
|
return DispatchLambdaArgumentExprs( |
|
exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args), |
|
inits=inits, |
|
) |
|
|