|
from __future__ import annotations |
|
|
|
import contextlib |
|
import functools |
|
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union |
|
|
|
import torchgen.local as local |
|
from torchgen.model import ( |
|
BackendIndex, |
|
DispatchKey, |
|
NativeFunction, |
|
NativeFunctionsGroup, |
|
NativeFunctionsViewGroup, |
|
) |
|
from torchgen.utils import context, S, T |
|
|
|
|
|
if TYPE_CHECKING: |
|
from collections.abc import Iterator |
|
|
|
|
|
|
|
|
|
# Constrained type variable for the "native function" argument accepted by the
# decorators in this module (with_native_function and friends).  Each
# constraint is one of the shapes native_function_manager knows how to handle.
F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)

# Constrained type variable for the extra (second) argument that
# with_native_function_and forwards untouched to the wrapped callable.
F2 = TypeVar(
    "F2",
    NativeFunction,
    NativeFunctionsGroup,
    Optional[NativeFunction],
    bool,
    str,
)

# Constrained type variable for the indexable argument taken by
# method_with_nested_native_function; element 0 is handed to
# native_function_manager.
F3 = TypeVar("F3", tuple[NativeFunction, Any], list[NativeFunction])
|
|
|
|
|
@contextlib.contextmanager
def native_function_manager(
    g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
) -> Iterator[None]:
    """Run the enclosed ``with`` body under context derived from ``g``.

    One representative function is chosen from ``g``:

    * the ``out`` member when ``g`` is a ``NativeFunctionsGroup``;
    * the ``view`` member when ``g`` is a ``NativeFunctionsViewGroup``;
    * ``g`` itself in every other case.

    While the body runs, two context managers are active, entered in this
    order:

    1. ``context(...)``, given a zero-argument callable that renders the
       representative's ``loc`` and ``func`` into a message (presumably used
       to annotate errors raised inside the body -- confirm against
       ``torchgen.utils.context``).
    2. ``local.parametrize(...)``, configured from the representative's
       ``use_const_ref_for_mutable_tensors`` and ``part_of_structured_group``
       attributes.

    Yields:
        ``None``; the manager exists only for the surrounding context.
    """
    # Pick the single function that stands in for the whole group.
    if isinstance(g, NativeFunctionsGroup):
        f = g.out
    elif isinstance(g, NativeFunctionsViewGroup):
        f = g.view
    else:
        f = g

    def location_message() -> str:
        # Evaluated lazily by ``context``; nothing is formatted unless the
        # callable is actually invoked.
        return f"in native_functions.yaml line {f.loc}:\n {f.func}"

    # The second manager expression is evaluated only after the first has
    # been entered, exactly as with two nested ``with`` statements.
    with context(location_message), local.parametrize(
        use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=f.part_of_structured_group,
    ):
        yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Decorate ``func`` so each call runs inside ``native_function_manager``.

    Args:
        func: A one-argument callable taking a native function (or group).

    Returns:
        A callable with the same signature that enters
        ``native_function_manager(f)`` before delegating to ``func(f)``.
        Metadata of ``func`` is copied onto it via ``functools``.
    """

    def call_in_context(f: F) -> T:
        # The manager is keyed on the very argument being forwarded.
        with native_function_manager(f):
            return func(f)

    # Equivalent to applying ``@functools.wraps(func)`` to the inner function.
    return functools.update_wrapper(call_in_context, func)
|
|
|
|
|
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Two-argument variant of the native-function decorator.

    Args:
        func: A callable taking a native function (or group) followed by one
            extra argument.

    Returns:
        A callable with the same signature that enters
        ``native_function_manager(f)`` and then calls ``func(f, f2)``.
        Only the first argument drives the manager; ``f2`` is passed through
        unchanged.
    """

    def call_in_context(f: F, f2: F2) -> T:
        # ``f2`` plays no part in selecting the context.
        with native_function_manager(f):
            return func(f, f2)

    # Equivalent to applying ``@functools.wraps(func)`` to the inner function.
    return functools.update_wrapper(call_in_context, func)
|
|
|
|
|
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method flavour of the native-function decorator.

    Args:
        func: A callable taking a receiver (``slf``) followed by a native
            function (or group).

    Returns:
        A callable with the same signature that enters
        ``native_function_manager(f)`` and then calls ``func(slf, f)``.
        The receiver is forwarded untouched.
    """

    def call_in_context(slf: S, f: F) -> T:
        # The context is derived from ``f``, never from the receiver.
        with native_function_manager(f):
            return func(slf, f)

    # Equivalent to applying ``@functools.wraps(func)`` to the inner function.
    return functools.update_wrapper(call_in_context, func)
|
|
|
|
|
def method_with_nested_native_function(
    func: Callable[[S, F3], T],
) -> Callable[[S, F3], T]:
    """Method decorator for arguments that carry a native function at index 0.

    Args:
        func: A callable taking a receiver (``slf``) followed by an indexable
            value whose first element is a native function.

    Returns:
        A callable with the same signature that enters
        ``native_function_manager(f[0])`` and then calls ``func(slf, f)``
        with the whole, unmodified ``f``.
    """

    def call_in_context(slf: S, f: F3) -> T:
        # Only the leading element selects the context; ``func`` still
        # receives the complete container.
        with native_function_manager(f[0]):
            return func(slf, f)

    # Equivalent to applying ``@functools.wraps(func)`` to the inner function.
    return functools.update_wrapper(call_in_context, func)
|
|
|
|
|
|
|
|
|
def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T],
) -> Callable[[F, BackendIndex], T]:
    """Decorator for callables taking a native function and a ``BackendIndex``.

    Args:
        func: A callable taking a native function (or group) followed by a
            ``BackendIndex``.

    Returns:
        A callable with the same signature that enters
        ``native_function_manager(f)`` and then calls
        ``func(f, backend_index)``.
    """

    def call_in_context(f: F, backend_index: BackendIndex) -> T:
        # ``backend_index`` is forwarded as-is; the context comes from ``f``.
        with native_function_manager(f):
            return func(f, backend_index)

    # Equivalent to applying ``@functools.wraps(func)`` to the inner function.
    return functools.update_wrapper(call_in_context, func)
|
|
|
|
|
|
|
def with_native_function_and_indices(
    func: Callable[[F, dict[DispatchKey, BackendIndex]], T],
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
    """Decorator for callables taking a native function and a backend-index map.

    Args:
        func: A callable taking a native function (or group) followed by a
            ``dict`` mapping ``DispatchKey`` to ``BackendIndex``.

    Returns:
        A callable with the same signature that enters
        ``native_function_manager(f)`` and then calls
        ``func(f, backend_indices)``.
    """

    def call_in_context(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
        # The mapping is forwarded as-is; the context comes from ``f``.
        with native_function_manager(f):
            return func(f, backend_indices)

    # Equivalent to applying ``@functools.wraps(func)`` to the inner function.
    return functools.update_wrapper(call_in_context, func)
|
|