|
from __future__ import annotations |
|
|
|
from dataclasses import dataclass |
|
|
|
import torchgen.api.types as api_types |
|
from torchgen.api import cpp, structured |
|
from torchgen.api.types import ( |
|
ArgName, |
|
BaseCppType, |
|
BaseCType, |
|
Binding, |
|
ConstRefCType, |
|
CType, |
|
NamedCType, |
|
scalarT, |
|
) |
|
from torchgen.model import ( |
|
Argument, |
|
BaseTy, |
|
BaseType, |
|
DispatchKey, |
|
FunctionSchema, |
|
NativeFunctionsGroup, |
|
Type, |
|
) |
|
|
|
|
|
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Build the ufunc kernel name for an out-variant schema.

    The name has the form ``ufunc_<name>_<dispatch_key>``, where ``<name>``
    is ``func.name.name`` and ``<dispatch_key>`` is formatted as-is.

    Raises:
        AssertionError: if ``func.is_out_fn()`` is false.
    """
    is_out_variant = func.is_out_fn()
    assert is_out_variant, "ufunc.kernel_name should only be invoked on out schemas"
    base_name = func.name.name
    return f"ufunc_{base_name}_{dispatch_key}"
|
|
|
|
|
def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for the out variant of group ``g``.

    Delegates to ``schema_kernel_name`` with the schema ``g.out.func``.
    """
    out_schema = g.out.func
    return schema_kernel_name(out_schema, dispatch_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
    """Translate a schema type into the C++ type used by a dispatch stub.

    Resolution order:
      1. If ``cpp.valuetype_type`` (with ``symint=False``) recognizes ``t``,
         its result is returned unchanged.
      2. ``Scalar`` becomes a const reference to ``scalarT``.
      3. ``Tensor`` yields ``None`` (no named type is produced for it).

    Raises:
        AssertionError: for any other type.
    """
    value_nctype = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_nctype is not None:
        return value_nctype

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    if t == BaseType(BaseTy.Tensor):
        return None
    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
|
def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    """Map a scalar C++ type to its corresponding opmath type.

    Only ``api_types.scalar_t`` is supported; it maps to
    ``api_types.opmath_t``.

    Raises:
        NotImplementedError: for any other ``scalar_t``.
    """
    is_generic_scalar = scalar_t == api_types.scalar_t
    if not is_generic_scalar:
        raise NotImplementedError
    return api_types.opmath_t
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    """Translate a schema type into the C++ type of a ufunctor constructor argument.

    Types recognized by ``cpp.valuetype_type`` (with ``symint=False``) are
    returned as-is. Both ``Scalar`` and ``Tensor`` are typed as the opmath
    type derived from ``scalar_t`` (see ``opmath_type``).

    Raises:
        AssertionError: if ``t`` is none of the above.
    """
    value_nctype = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_nctype is not None:
        return value_nctype

    # Scalar and Tensor constructor arguments share the same representation.
    is_scalar = t == BaseType(BaseTy.Scalar)
    if is_scalar or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    """Translate a schema type into the C++ type of a ufunctor apply argument.

    Only ``Tensor`` is accepted; it is typed as ``scalar_t``.

    Raises:
        AssertionError: if ``t`` is not ``Tensor``.
    """
    is_tensor = t == BaseType(BaseTy.Tensor)
    if not is_tensor:
        raise AssertionError(f"unrecognized type {repr(t)}")
    return NamedCType(binds, BaseCType(scalar_t))
|
|
|
|
|
|
|
|
|
|
|
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    """Translate a schema type into the C++ type of a ufunc argument.

    Types recognized by ``cpp.valuetype_type`` (with ``symint=False``) are
    returned as-is. Both ``Scalar`` and ``Tensor`` are typed as ``compute_t``.

    Raises:
        AssertionError: if ``t`` is none of the above.
    """
    value_nctype = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_nctype is not None:
        return value_nctype

    # Scalar and Tensor arguments are both expressed in the compute type.
    is_scalar = t == BaseType(BaseTy.Scalar)
    if is_scalar or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
|
def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Create the ufunctor constructor ``Binding`` for argument ``a``.

    The binding keeps the argument's name, has no default, and takes its
    C++ type from ``ufunctor_ctor_type``.
    """
    ctor_nctype = ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(nctype=ctor_nctype, name=a.name, default=None, argument=a)
|
|
|
|
|
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Create the ufunctor apply ``Binding`` for argument ``a``.

    The binding keeps the argument's name, has no default, and takes its
    C++ type from ``ufunctor_apply_type``.
    """
    apply_nctype = ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(nctype=apply_nctype, name=a.name, default=None, argument=a)
|
|
|
|
|
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    """Create the ufunc ``Binding`` for argument ``a``.

    The binding keeps the argument's name, has no default, and takes its
    C++ type from ``ufunc_type``.
    """
    ufunc_nctype = ufunc_type(a.type, binds=a.name, compute_t=compute_t)
    return Binding(nctype=ufunc_nctype, name=a.name, default=None, argument=a)
|
|
|
|
|
@dataclass(frozen=True)
class UfunctorBindings:
    """Immutable pair of binding lists describing a ufunctor's arguments."""

    # Bindings for the ufunctor's constructor.
    ctor: list[Binding]
    # Bindings for the ufunctor's apply step.
    apply: list[Binding]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ufunctor_arguments(
    g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
) -> UfunctorBindings:
    """Split the functional schema's non-out arguments into ctor/apply bindings.

    Walks ``g.functional.func.arguments.flat_non_out`` in order:

    * non-tensor-like arguments become constructor bindings;
    * tensor-like arguments become apply bindings, except the one at position
      ``scalar_tensor_idx`` (counted among tensor-like arguments only,
      starting at 0), which becomes a constructor binding instead.

    Raises:
        AssertionError: if ``scalar_tensor_idx`` is not ``None`` and was not
            consumed by a tensor-like argument during the walk.
    """
    ctor_bindings = []
    apply_bindings = []
    # Countdown to the tensor that is routed to the constructor; None once
    # that tensor has been seen (or if no such tensor was requested).
    remaining = scalar_tensor_idx

    for arg in g.functional.func.arguments.flat_non_out:
        if not arg.type.is_tensor_like():
            ctor_bindings.append(ufunctor_ctor_argument(arg, scalar_t=scalar_t))
            continue

        if remaining == 0:
            # This is the designated tensor: it goes to the constructor.
            ctor_bindings.append(ufunctor_ctor_argument(arg, scalar_t=scalar_t))
            remaining = None
            continue

        if remaining is not None:
            remaining -= 1
        apply_bindings.append(ufunctor_apply_argument(arg, scalar_t=scalar_t))

    assert remaining is None
    return UfunctorBindings(ctor=ctor_bindings, apply=apply_bindings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
    """Return ufunc bindings for every non-out argument of the functional schema.

    Each argument of ``g.functional.func.arguments.flat_non_out`` is converted
    with ``ufunc_argument`` using ``compute_t``; order is preserved.
    """
    non_out_args = g.functional.func.arguments.flat_non_out
    return [ufunc_argument(arg, compute_t=compute_t) for arg in non_out_args]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings for the non-tensor-like, non-out arguments of ``g.out``.

    Iterates ``g.out.func.arguments.flat_non_out`` in order, skips every
    tensor-like argument, and collects all bindings that
    ``structured.argument`` yields for each remaining one.
    """
    bindings = []
    for arg in g.out.func.arguments.flat_non_out:
        if arg.type.is_tensor_like():
            continue
        bindings.extend(structured.argument(arg))
    return bindings
|
|