|
from __future__ import annotations |
|
|
|
from typing import NoReturn, TYPE_CHECKING |
|
|
|
from torchgen.api.types import ( |
|
ArrayRefCType, |
|
BaseCType, |
|
Binding, |
|
boolT, |
|
ConstRefCType, |
|
deviceT, |
|
Expr, |
|
intArrayRefT, |
|
iOptTensorListRefT, |
|
layoutT, |
|
ListCType, |
|
longT, |
|
memoryFormatT, |
|
MutRefCType, |
|
NamedCType, |
|
opmath_t, |
|
OptionalCType, |
|
optionalIntArrayRefT, |
|
optionalScalarRefT, |
|
optionalSymIntArrayRefT, |
|
optionalTensorRefT, |
|
scalar_t, |
|
scalarT, |
|
scalarTypeT, |
|
SpecialArgName, |
|
symIntArrayRefT, |
|
SymIntT, |
|
tensorOptionsT, |
|
tensorT, |
|
VectorCType, |
|
) |
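
# translate() below implements a small, search-based translation step: given
# a set of bindings that are in scope (each a C++ expression with a known
# NamedCType), it synthesizes C++ expressions for a list of goal NamedCTypes.
# Most goals are satisfied directly by a binding of the same type; the
# interesting rules pack and unpack composite types, for example gathering
# dtype/layout/device/pin_memory into a TensorOptions, extracting them back
# out of one, or converting between IntArrayRef and SymIntArrayRef.
#
# Illustrative example: given a binding `const at::TensorOptions& options`,
# the goal `std::optional<at::ScalarType> dtype` is solved as
# "c10::optTypeMetaToScalarType(options.dtype_opt())".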
|
|
|
|
|
if TYPE_CHECKING: |
|
    from collections.abc import Sequence

|
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))) |
|
|
|
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT))) |
|
|
|
longVec_ctype = VectorCType(BaseCType(longT)) |
|
longSymVec_ctype = VectorCType(BaseCType(SymIntT)) |
|
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT))) |
|
optionalScalar_ctype = OptionalCType(BaseCType(scalarT)) |
|
optionalTensor_ctype = OptionalCType(BaseCType(tensorT)) |
|
|
|
|
|
class UnsatError(RuntimeError): |
|
    pass

|
def translate( |
|
bindings: Sequence[Expr | Binding], |
|
goals: Sequence[NamedCType | Binding], |
|
*, |
|
method: bool = False, |
|
allow_expensive_conversions: bool = False, |
|
) -> list[Expr]: |
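    """Synthesize a C++ expression for every goal type from the given bindings.

    ``bindings`` are the expressions that are in scope (each carrying a
    NamedCType); ``goals`` are the types we need expressions for.  Each
    binding may be an Expr or a Binding; each goal may be a NamedCType or a
    Binding.  If ``method`` is True, ``self`` is taken from ``*this``.
    Conversions that copy data (e.g. IntArrayRef -> std::vector<int64_t>)
    are only used when ``allow_expensive_conversions`` is True.  Returns one
    Expr per goal, in order; raises UnsatError if a goal cannot be satisfied.
    """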
|
binding_exprs: list[Expr] = [] |
|
for b in bindings: |
|
if isinstance(b, Binding): |
|
binding_exprs.append( |
|
Expr( |
|
expr=b.name, |
|
type=b.nctype, |
|
) |
|
) |
|
else: |
|
binding_exprs.append(b) |
|
|
|
goal_ctypes: list[NamedCType] = [] |
|
for g in goals: |
|
if isinstance(g, Binding): |
|
goal_ctypes.append(g.nctype) |
|
else: |
|
goal_ctypes.append(g) |
|
|
|
|
|
ctx: dict[NamedCType, str] = {} |
|
for b in binding_exprs: |
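        # Register each binding in the context.  While we are at it, do some
        # simple forward inference: for bindings whose type we know how to
        # decompose (e.g. const std::optional<Tensor>&, const Scalar&), also
        # record derived expressions so that solve() can use them directly.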
|
        ctx[b.type] = b.expr
|
        t = b.type
        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[NamedCType(t.name, ConstRefCType(BaseCType(tensorT)))] = (
                f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
            )
|
|
|
if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))): |
|
ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = ( |
|
f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())" |
|
) |
|
|
|
if t.type == ConstRefCType(BaseCType(scalarT)): |
|
ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()" |
|
|
|
if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))): |
|
ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = ( |
|
f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())" |
|
) |
|
|
|
if t.type == BaseCType(scalar_t): |
|
ctx[NamedCType(t.name, BaseCType(opmath_t))] = ( |
|
f"static_cast<opmath_t>({b.expr})" |
|
) |
|
|
|
|
|
if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))): |
|
ctx[NamedCType(t.name, BaseCType(iOptTensorListRefT))] = ( |
|
f"at::IOptTensorListRef({b.expr})" |
|
) |
|
|
|
|
|
if method: |
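        # In method context, `self` comes from `*this`; the const_cast lets
        # the same expression satisfy both the const& and mutable& goals.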
|
ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = ( |
|
"const_cast<Tensor&>(*this)" |
|
) |
|
ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = ( |
|
"const_cast<Tensor&>(*this)" |
|
) |
|
|
|
|
|
|
|
def unsat(goal: NamedCType) -> NoReturn: |
|
ctx_desc = "\n".join( |
|
f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items() |
|
) |
|
raise UnsatError( |
|
f""" |
|
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}". |
|
When I failed, the following bindings were available in the context: |
|
|
|
{ctx_desc} |
|
|
|
This probably means there is a missing rule in the rules of torchgen.api.translate. |
|
Check this module for more information. |
|
""" |
|
        )
|
def solve(goal: NamedCType, *, direct: bool) -> str: |
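        """Synthesize a C++ expression for ``goal``.

        When ``direct`` is True, only the context lookup and simple
        reference-type adjustments are attempted; otherwise the rules below
        may recursively solve sub-goals via ``direct_solve``.  Raises
        UnsatError when no rule applies.
        """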
|
def direct_solve(goal: NamedCType) -> str: |
|
return solve(goal, direct=True) |
|
|
|
if goal in ctx: |
|
|
|
return ctx[goal] |
|
|
|
|
|
if isinstance(goal.type, ConstRefCType): |
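            # A const& goal can be satisfied by a mutable& binding, so first
            # retry the goal with a mutable reference type.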
|
try: |
|
|
|
|
|
|
|
return solve( |
|
NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct |
|
) |
|
except UnsatError: |
|
pass |
|
|
|
|
|
if isinstance(goal.type, MutRefCType): |
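            # A mutable& goal can be satisfied by an owned value, so retry
            # the goal with the underlying value type.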
|
try: |
|
return solve(NamedCType(goal.name, goal.type.elem), direct=direct) |
|
except UnsatError: |
|
pass |
|
|
|
|
|
|
|
|
|
if goal.type == ArrayRefCType(BaseCType(longT)): |
|
return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct) |
|
|
|
if direct: |
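            # In direct mode only the trivial conversions above are allowed;
            # anything else counts as a failure.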
|
unsat(goal) |
|
|
|
|
|
if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))): |
|
memory_format = direct_solve( |
|
NamedCType( |
|
SpecialArgName.possibly_redundant_memory_format, |
|
OptionalCType(BaseCType(memoryFormatT)), |
|
) |
|
) |
|
|
|
|
|
if options_ctype in goal_ctypes: |
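                # The target API takes a TensorOptions itself, so just
                # forward the explicit memory_format and leave options alone.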
|
return memory_format |
|
try: |
|
options = direct_solve(options_ctype) |
|
return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})" |
|
except UnsatError: |
|
return memory_format |
|
elif goal == NamedCType("options", BaseCType(tensorOptionsT)): |
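            # Gather: reassemble a TensorOptions value from its scattered
            # dtype/layout/device/pin_memory components.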
|
dtype = direct_solve( |
|
NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))) |
|
) |
|
pin_memory = direct_solve( |
|
NamedCType("pin_memory", OptionalCType(BaseCType(boolT))) |
|
) |
|
device = direct_solve( |
|
NamedCType("device", OptionalCType(BaseCType(deviceT))) |
|
) |
|
layout = direct_solve( |
|
NamedCType("layout", OptionalCType(BaseCType(layoutT))) |
|
) |
|
return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})" |
|
|
|
elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))): |
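            # Recover dtype (and, in the branches below, layout/device/
            # pin_memory) from a TensorOptions binding, or fall back to the
            # properties of the out tensor.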
|
try: |
|
options = direct_solve(options_ctype) |
|
return f"c10::optTypeMetaToScalarType({options}.dtype_opt())" |
|
except UnsatError: |
|
out_tensor = direct_solve(out_tensor_ctype) |
|
return f"{out_tensor}.scalar_type()" |
|
|
|
elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))): |
|
try: |
|
options = direct_solve(options_ctype) |
|
return f"{options}.layout_opt()" |
|
except UnsatError: |
|
out_tensor = direct_solve(out_tensor_ctype) |
|
return f"{out_tensor}.layout()" |
|
|
|
elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))): |
|
try: |
|
options = direct_solve(options_ctype) |
|
return f"{options}.device_opt()" |
|
except UnsatError: |
|
out_tensor = direct_solve(out_tensor_ctype) |
|
return f"{out_tensor}.device()" |
|
|
|
elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))): |
|
try: |
|
options = direct_solve(options_ctype) |
|
return f"{options}.pinned_memory_opt()" |
|
except UnsatError: |
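                # The out tensor cannot tell us pin_memory; solving for it
                # just confirms we are in the out= case, then pass nullopt.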
|
|
|
|
|
out_tensor = direct_solve(out_tensor_ctype) |
|
return "::std::nullopt" |
|
|
|
|
|
elif goal.type == BaseCType(intArrayRefT): |
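            # A borrowed IntArrayRef can be backed by an owned vector<int64_t>
            # binding, or converted back from a SymIntArrayRef binding
            # (C10_AS_INTARRAYREF_SLOW).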
|
try: |
|
return direct_solve(NamedCType(goal.name, longVec_ctype)) |
|
except UnsatError: |
|
|
|
symIntArrayRef_type = direct_solve( |
|
NamedCType(goal.name, BaseCType(symIntArrayRefT)) |
|
) |
|
return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})" |
|
elif goal.type == BaseCType(symIntArrayRefT): |
|
try: |
|
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT))) |
|
return f"c10::fromIntArrayRefSlow({r})" |
|
except UnsatError: |
|
return direct_solve(NamedCType(goal.name, longSymVec_ctype)) |
|
elif goal.type == BaseCType(SymIntT): |
|
return direct_solve(NamedCType(goal.name, BaseCType(longT))) |
|
elif goal.type == OptionalCType(BaseCType(SymIntT)): |
|
argname = direct_solve( |
|
NamedCType(goal.name, OptionalCType(BaseCType(longT))) |
|
) |
|
return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt" |
|
elif goal.type == BaseCType(longT): |
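            # A concrete int64_t can be obtained from a SymInt binding by
            # guarding on its value.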
|
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT))) |
|
return f"{symInt_type}.guard_int(__FILE__, __LINE__)" |
|
elif goal.type == OptionalCType(BaseCType(longT)): |
|
argname = direct_solve( |
|
NamedCType(goal.name, OptionalCType(BaseCType(SymIntT))) |
|
) |
|
return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt" |
|
elif goal.type == BaseCType(optionalIntArrayRefT): |
|
try: |
|
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype)) |
|
except UnsatError: |
|
argname = direct_solve( |
|
NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT)) |
|
) |
|
return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt" |
|
elif goal.type == BaseCType(optionalSymIntArrayRefT): |
|
|
|
|
|
argname = direct_solve( |
|
NamedCType(goal.name, BaseCType(optionalIntArrayRefT)) |
|
) |
|
return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt" |
|
elif goal.type == BaseCType(optionalScalarRefT): |
|
return direct_solve(NamedCType(goal.name, optionalScalar_ctype)) |
|
elif goal.type == BaseCType(optionalTensorRefT): |
|
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))
|
if allow_expensive_conversions: |
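            # The conversions below go from reference types to owning value
            # types and therefore copy data; they are only attempted when the
            # caller explicitly opts in.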
|
if goal.type == VectorCType(BaseCType(longT)): |
|
intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT)) |
|
argname = direct_solve(intArrayRef_ctype) |
|
return f"{argname}.vec()" |
|
if goal.type == VectorCType(BaseCType(SymIntT)): |
|
symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT)) |
|
argname = direct_solve(symIntArrayRef_ctype) |
|
return f"{argname}.vec()" |
|
elif goal.type == OptionalCType(VectorCType(BaseCType(longT))): |
|
optionalIntArrayRef_ctype = NamedCType( |
|
goal.name, BaseCType(optionalIntArrayRefT) |
|
) |
|
argname = direct_solve(optionalIntArrayRef_ctype) |
|
return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt" |
|
elif goal.type == OptionalCType(BaseCType(scalarT)): |
|
optionalScalarRef_ctype = NamedCType( |
|
goal.name, BaseCType(optionalScalarRefT) |
|
) |
|
argname = direct_solve(optionalScalarRef_ctype) |
|
return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" |
|
            elif goal.type == OptionalCType(BaseCType(tensorT)):
|
optionalTensorRef_ctype = NamedCType( |
|
goal.name, BaseCType(optionalTensorRefT) |
|
) |
|
argname = direct_solve(optionalTensorRef_ctype) |
|
                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
|
if goal.type == MutRefCType(BaseCType(tensorT)): |
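            # Last resort: satisfy a mutable Tensor& goal from a const
            # Tensor& binding by casting away constness.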
|
const_ref_tensor_ctype = NamedCType( |
|
goal.name, ConstRefCType(BaseCType(tensorT)) |
|
) |
|
argname = direct_solve(const_ref_tensor_ctype) |
|
return f"const_cast<Tensor&>({argname})" |
|
|
|
unsat(goal) |
|
|
|
return [Expr(solve(g, direct=False), g) for g in goal_ctypes] |
|
|