|
from __future__ import annotations |
|
|
|
from typing import TYPE_CHECKING |
|
|
|
from torchgen import local |
|
from torchgen.api import cpp |
|
from torchgen.api.types import ( |
|
ArgName, |
|
BaseCType, |
|
Binding, |
|
boolT, |
|
ConstRefCType, |
|
CType, |
|
deviceT, |
|
layoutT, |
|
ListCType, |
|
MutRefCType, |
|
NamedCType, |
|
OptionalCType, |
|
scalarT, |
|
scalarTypeT, |
|
tensorT, |
|
) |
|
from torchgen.model import ( |
|
Argument, |
|
FunctionSchema, |
|
Return, |
|
SelfArgument, |
|
TensorOptionsArguments, |
|
Type, |
|
) |
|
from torchgen.utils import assert_never |
|
|
|
|
|
if TYPE_CHECKING: |
|
from collections.abc import Sequence |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def name(func: FunctionSchema) -> str:
    """Return the native-function identifier for *func*.

    The identifier is built from three parts, in order:

    * the base operator name (``func.name.name`` stringified),
    * ``_out`` when the schema is an out= variant (``func.is_out_fn()``),
    * ``_<overload>`` when the schema has a non-empty overload name.
    """
    base = str(func.name.name)
    out_suffix = "_out" if func.is_out_fn() else ""
    overload = func.name.overload_name
    # An empty overload name contributes nothing (no trailing underscore).
    overload_suffix = f"_{overload}" if overload else ""
    return f"{base}{out_suffix}{overload_suffix}"
|
|
|
|
|
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Map a schema type to the C++ type used in native function signatures.

    Four schema spellings are handled locally:

    * ``Tensor?`` becomes a reference to an optional tensor. It is a mutable
      reference only when *mutable* is set and the local config does not
      request const refs for mutable tensors; otherwise it is a const reference.
    * ``Tensor?[]`` becomes a const reference to a list of optional tensors.
    * ``Scalar`` becomes a const reference to a scalar.
    * ``Scalar?`` becomes a const reference to an optional scalar.

    Every other type is delegated to ``cpp.argumenttype_type``.

    Args:
        t: Schema type of the argument (matched on its string spelling).
        mutable: Whether the argument is written to; callers in this module
            pass ``Argument.is_write``.
        binds: Name the resulting ``NamedCType`` is bound to.
        symint: Forwarded to ``cpp.argumenttype_type`` in the fallback case.
    """
    spelled = str(t)

    if spelled == "Tensor?":
        optional_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        # The config lookup runs only for mutable arguments (short-circuit).
        wants_mut_ref = mutable and not local.use_const_ref_for_mutable_tensors()
        if wants_mut_ref:
            return NamedCType(binds, MutRefCType(optional_tensor))
        return NamedCType(binds, ConstRefCType(optional_tensor))

    if spelled == "Tensor?[]":
        optional_tensor_list = ListCType(OptionalCType(BaseCType(tensorT)))
        return NamedCType(binds, ConstRefCType(optional_tensor_list))

    if spelled == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))

    if spelled == "Scalar?":
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))

    # Anything not special-cased above is spelled exactly as the cpp API spells it.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
|
|
|
|
|
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C++ return type for the schema returns *rs*.

    Native signatures use the same return type as the cpp API, so this is a
    direct delegation to ``cpp.returns_type``.
    """
    ctype = cpp.returns_type(rs, symint=symint)
    return ctype
|
|
|
|
|
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Return the native C++ type of schema argument *a*, bound to *binds*.

    The argument is treated as mutable exactly when ``a.is_write`` is true.
    """
    schema_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(
        schema_type, mutable=is_mutable, binds=binds, symint=symint
    )
|
|
|
|
|
def argument(
    a: Argument | SelfArgument | TensorOptionsArguments,
    *,
    is_out: bool,
    symint: bool,
) -> list[Binding]:
    """Translate one schema argument into its native-signature bindings.

    * ``Argument``: a single binding. Its default is rendered with
      ``cpp.default_expr`` when defaults are enabled and the argument
      declares one; otherwise the default is ``None``.
    * ``SelfArgument``: unwrapped and handled as the argument it wraps.
    * ``TensorOptionsArguments``: expanded into four optional bindings
      (``dtype``, ``layout``, ``device``, ``pin_memory``). Each defaults to
      ``"{}"`` when defaults are enabled. Each binding records the whole
      ``TensorOptionsArguments`` object, not a constituent piece, as its
      ``argument``.

    Args:
        a: The schema argument to translate.
        is_out: When true, no defaults are emitted at all.
        symint: Forwarded to the type/default helpers.

    Returns:
        The list of bindings for *a*, in signature order.
    """
    # Defaults are produced only for non-out variants; out= variants get none.
    # NOTE(review): presumably kept for callers that invoke native functions
    # directly and rely on defaults; confirm before changing.
    emit_defaults = not is_out

    if isinstance(a, Argument):
        arg_default: str | None = None
        if emit_defaults and a.default is not None:
            arg_default = cpp.default_expr(a.default, a.type, symint=symint)
        single = Binding(
            nctype=argument_type(a, binds=a.name, symint=symint),
            name=a.name,
            default=arg_default,
            argument=a,
        )
        return [single]

    if isinstance(a, SelfArgument):
        # Self gets no special treatment here: bind it like its inner argument.
        return argument(a.argument, is_out=is_out, symint=symint)

    if isinstance(a, TensorOptionsArguments):
        opts_default = "{}" if emit_defaults else None
        # (binding name, base C type) for each optional piece, in signature order.
        option_pieces = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        return [
            Binding(
                nctype=NamedCType(opt_name, OptionalCType(BaseCType(opt_base))),
                name=opt_name,
                default=opts_default,
                argument=a,
            )
            for opt_name, opt_base in option_pieces
        ]

    assert_never(a)
|
|
|
|
|
def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
    """Return the flattened native-signature bindings for every argument of *func*.

    Non-out arguments come first, followed by the out arguments. Each one is
    expanded via ``argument``. Note that ``is_out`` is the schema-level
    ``func.is_out_fn()``, so for an out= schema defaults are suppressed on
    every argument, not only on the out arguments themselves.
    """
    # Same value for every argument; compute it once.
    is_out = func.is_out_fn()
    schema_args: list[Argument | TensorOptionsArguments | SelfArgument] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    bindings: list[Binding] = []
    for schema_arg in schema_args:
        bindings.extend(argument(schema_arg, symint=symint, is_out=is_out))
    return bindings
|
|