|
from __future__ import annotations |
|
|
|
import itertools |
|
from typing import TYPE_CHECKING |
|
|
|
from torchgen.api import cpp |
|
from torchgen.api.types import ArgName, Binding, CType, NamedCType |
|
from torchgen.model import ( |
|
Argument, |
|
FunctionSchema, |
|
Return, |
|
SelfArgument, |
|
TensorOptionsArguments, |
|
Type, |
|
) |
|
from torchgen.utils import assert_never, concatMap |
|
|
|
|
|
if TYPE_CHECKING: |
|
from collections.abc import Sequence |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def name(func: FunctionSchema) -> str:
    """Return the dispatcher-level name for ``func``.

    The dispatcher API does not define its own naming scheme here; the name
    is whatever :func:`cpp.name` produces for the same schema.
    """
    cpp_name = cpp.name(func)
    return cpp_name
|
|
|
|
|
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = True,
) -> NamedCType:
    """Translate a schema ``Type`` into a named C++ type for the dispatcher API.

    The dispatcher shares its type translation with the C++ API: every option
    is forwarded unchanged to :func:`cpp.argumenttype_type`.

    Args:
        t: schema type to translate.
        mutable: forwarded as-is (callers in this file pass ``Argument.is_write``).
        binds: name the resulting ``NamedCType`` is bound to.
        remove_non_owning_ref_types: forwarded as-is to the C++ translation.
        symint: forwarded as-is to the C++ translation.

    Returns:
        The ``NamedCType`` produced by the C++ API translation.
    """
    # Keyword arguments are order-independent; listed here grouped by purpose.
    return cpp.argumenttype_type(
        t,
        binds=binds,
        mutable=mutable,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
|
|
|
|
|
def argument_type(
    a: Argument,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = True,
) -> NamedCType:
    """Compute the dispatcher C++ type for a single schema ``Argument``.

    The argument's schema type is translated via :func:`argumenttype_type`,
    with mutability taken from ``a.is_write``.

    Args:
        a: the schema argument.
        binds: name the resulting ``NamedCType`` is bound to.
        remove_non_owning_ref_types: forwarded to :func:`argumenttype_type`.
        symint: forwarded to :func:`argumenttype_type`.
    """
    schema_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(
        schema_type,
        mutable=is_mutable,
        binds=binds,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
|
|
|
|
|
def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
    """Return the C++ return type for the schema returns ``rs``.

    The dispatcher reuses the C++ API's return-type translation unchanged.
    """
    ctype = cpp.returns_type(rs, symint=symint)
    return ctype
|
|
|
|
|
def jit_arguments(func: FunctionSchema) -> list[Argument]:
    """Flatten ``func``'s arguments into a single list of plain ``Argument``s.

    Argument groups are visited in the order positional, keyword-only, then
    out. Within them, a ``SelfArgument`` is unwrapped to its underlying
    ``Argument``. A ``TensorOptionsArguments`` is expanded into its dtype,
    layout, device and pin_memory arguments, in that order.
    """

    def _expand(
        item: Argument | TensorOptionsArguments | SelfArgument,
    ) -> list[Argument]:
        # Guard-clause dispatch; the check order matches the variants above.
        if isinstance(item, Argument):
            return [item]
        if isinstance(item, SelfArgument):
            return [item.argument]
        if isinstance(item, TensorOptionsArguments):
            return [item.dtype, item.layout, item.device, item.pin_memory]
        assert_never(item)

    arg_groups = (
        func.arguments.positional,
        func.arguments.kwarg_only,
        func.arguments.out,
    )
    return list(concatMap(_expand, itertools.chain(*arg_groups)))
|
|
|
|
|
def argument(
    a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
) -> Binding:
    """Build the dispatcher ``Binding`` for a single schema argument.

    Both the binding's name and the name its C++ type binds to are ``a.name``.
    The original ``Argument`` is attached to the binding.
    """
    arg_name = a.name
    named_ctype = argument_type(
        a,
        binds=arg_name,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
    return Binding(nctype=named_ctype, name=arg_name, argument=a)
|
|
|
|
|
def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
    """Return one dispatcher ``Binding`` per flattened argument of ``func``.

    Bindings follow the order produced by :func:`jit_arguments`.
    """
    flat_args = jit_arguments(func)
    return [argument(arg, symint=symint) for arg in flat_args]
|
|