|
from __future__ import annotations |
|
|
|
from torchgen.api import dispatcher |
|
from torchgen.api.types import ( |
|
BaseCppType, |
|
BaseCType, |
|
Binding, |
|
boolT, |
|
ConstRefCType, |
|
CType, |
|
longT, |
|
NamedCType, |
|
tensorT, |
|
) |
|
from torchgen.model import ( |
|
Argument, |
|
BaseTy, |
|
BaseType, |
|
FunctionSchema, |
|
NativeFunction, |
|
NativeFunctionsViewGroup, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lambda_binding(
    binding_name: str, ctype: CType, *, arg_name: str, arg_type: BaseTy
) -> Binding:
    """Build one of the fixed bindings used by the view lambdas.

    ``binding_name`` is used both as the Binding name and as the name of its
    NamedCType, whose C++ type is ``ctype``. The wrapped schema Argument is
    created with name ``arg_name``, base type ``arg_type``, and no default or
    annotation. The Binding itself has no default either.
    """
    return Binding(
        name=binding_name,
        nctype=NamedCType(name=binding_name, type=ctype),
        argument=Argument(
            name=arg_name,
            type=BaseType(arg_type),
            default=None,
            annotation=None,
        ),
        default=None,
    )


# `base`: a const-reference tensor.
base_binding = _lambda_binding(
    "base",
    ConstRefCType(BaseCType(tensorT)),
    arg_name="base",
    arg_type=BaseTy.Tensor,
)

# `mutated_view`: a const-reference tensor.
# NOTE(review): the wrapped Argument is named "base", not "mutated_view".
# This looks like a placeholder -- confirm nothing consumes `.argument` here.
mutated_view_binding = _lambda_binding(
    "mutated_view",
    ConstRefCType(BaseCType(tensorT)),
    arg_name="base",
    arg_type=BaseTy.Tensor,
)

# `mutated_view_idx`: a `longT` value.
# NOTE(review): the wrapped Argument is a Tensor named "base" even though the
# C++ type is `longT`; presumably a placeholder as well -- confirm.
mutated_view_idx_binding = _lambda_binding(
    "mutated_view_idx",
    BaseCType(longT),
    arg_name="base",
    arg_type=BaseTy.Tensor,
)

# `reapply_views`: a `boolT` value backed by a bool Argument of the same name.
reapply_views_binding = _lambda_binding(
    "reapply_views",
    BaseCType(boolT),
    arg_name="reapply_views",
    arg_type=BaseTy.bool,
)

# C++ type `InverseReturnMode` in the `at::functionalization` namespace.
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")

# `inverse_return_mode`: typed as InverseReturnModeT on the C++ side.
# NOTE(review): the wrapped Argument is declared as a bool although the C++
# type is InverseReturnModeT -- presumably the Argument is unused; confirm.
inverse_return_mode_binding = _lambda_binding(
    "inverse_return_mode",
    BaseCType(InverseReturnModeT),
    arg_name="inverse_return_mode",
    arg_type=BaseTy.bool,
)
|
|
|
|
|
|
|
|
|
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    """Return the C++ name of the function to call for view group ``g``.

    Reverse case (``is_reverse`` is true): the result is whatever
    ``reverse_name(g.view, include_namespace)`` returns.

    Forward case: the result is ``at::_ops::<op>::call``, where ``<op>`` is
    the unambiguous name of ``g.view`` when ``reapply_views`` is truthy and
    of ``g.view_copy`` otherwise.

    Asserts:
        * ``reapply_views`` may be left as ``None`` only in the reverse case.
        * In the forward case ``include_namespace`` must be true and
          ``g.view_copy`` must not be ``None``.
    """
    # Only the reverse path is allowed to omit `reapply_views`.
    assert is_reverse or reapply_views is not None

    if is_reverse:
        return reverse_name(g.view, include_namespace)

    # The forward result is always fully qualified, so callers must ask for
    # the namespace.
    assert include_namespace
    assert g.view_copy is not None

    if reapply_views:
        target = g.view
    else:
        target = g.view_copy
    op_name = target.func.name.unambiguous_name()
    return f"at::_ops::{op_name}::call"
|
|
|
|
|
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    """Return the name of the inverse function for native function ``f``.

    The base of the name is ``f.func.name.unambiguous_name()`` with an
    ``_inverse`` suffix. When ``include_namespace`` is true the result is
    qualified with ``at::functionalization::FunctionalInverses::``; otherwise
    the bare ``<op>_inverse`` identifier is returned.
    """
    op_name = f.func.name.unambiguous_name()
    if not include_namespace:
        return f"{op_name}_inverse"
    return f"at::functionalization::FunctionalInverses::{op_name}_inverse"
|
|
|
|
|
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    """Return the bindings captured for ``func``.

    The result starts with a single mode binding --
    ``inverse_return_mode_binding`` for the reverse case,
    ``reapply_views_binding`` for the forward case -- followed by one
    dispatcher binding for every argument of ``func`` after the first.

    The first flat argument must be a plain Tensor (asserted) and is not
    included in the result. The remaining arguments are converted via
    ``dispatcher.argument(..., remove_non_owning_ref_types=True)``.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    if is_reverse:
        mode_binding = inverse_return_mode_binding
    else:
        mode_binding = reapply_views_binding

    return [
        mode_binding,
        *(
            dispatcher.argument(arg, remove_non_owning_ref_types=True)
            for arg in flat_args[1:]
        ),
    ]
|
|
|
|
|
def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ return type for ``func``: always a single tensor.

    Asserts that ``func`` has at least one return and that every return type
    reports ``is_tensor_like()``. The result is ``BaseCType(tensorT)``
    regardless of how many returns the schema declares.
    """
    rets = func.returns
    assert len(rets) >= 1
    assert all(ret.type.is_tensor_like() for ret in rets)
    return BaseCType(tensorT)
|
|
|
|
|
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    """Return the outer argument bindings as a fresh list.

    Forward: ``[base_binding, mutated_view_idx_binding]``.
    Reverse: the same, with ``mutated_view_binding`` inserted between them.
    """
    bindings = [base_binding]
    if is_reverse:
        bindings.append(mutated_view_binding)
    bindings.append(mutated_view_idx_binding)
    return bindings
|
|
|
|
|
def inner_call_index(func: FunctionSchema) -> Binding | None:
    """Return the index binding for ``func``, or ``None`` if it has none.

    ``mutated_view_idx_binding`` is returned when ``func`` declares more than
    one return, or exactly one return whose type reports ``is_list_like()``.
    In every other case (no returns, or a single non-list return) the result
    is ``None``.
    """
    rets = func.returns
    num_returns = len(rets)
    if num_returns == 0:
        return None
    if num_returns == 1 and not rets[0].type.is_list_like():
        return None
    return mutated_view_idx_binding
|
|
|
|
|
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    """Return the bindings passed to the inner call for ``func``.

    The first flat argument of ``func`` must be a plain Tensor (asserted); it
    is replaced by fixed bindings, and every remaining argument is converted
    with ``dispatcher.argument``.

    Forward: ``[base_binding, *rest]``.
    Reverse: ``[base_binding, mutated_view_binding,
    inverse_return_mode_binding, *rest]``, with the binding from
    ``inner_call_index(func)`` placed just before ``rest`` when it is not
    ``None``.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)
    rest = [dispatcher.argument(arg) for arg in flat_args[1:]]

    if not is_reverse:
        return [base_binding] + rest

    prefix = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    # The index binding is only present when inner_call_index() supplies one.
    idx_binding = inner_call_index(func)
    if idx_binding is not None:
        prefix.append(idx_binding)
    return prefix + rest
|
|