|
|
|
import abc |
|
import contextlib |
|
import ctypes |
|
import importlib |
|
import inspect |
|
import sys |
|
import types |
|
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union |
|
from typing_extensions import Concatenate, ParamSpec |
|
|
|
import torch |
|
import torch.utils._pytree as pytree |
|
from torch import _utils_internal |
|
from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey |
|
from torch._functorch.pyfunctorch import dispatch_functorch, TransformType |
|
from torch.utils._python_dispatch import TorchDispatchMode |
|
|
|
|
|
if TYPE_CHECKING: |
|
from torch._subclasses.functional_tensor import BaseFunctionalizeAPI |
|
|
|
|
|
_T = TypeVar("_T") |
|
_P = ParamSpec("_P") |
|
|
|
|
|
|
|
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") |
|
|
|
|
|
@contextlib.contextmanager |
|
def dl_open_guard(): |
|
""" |
|
Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a |
|
shared library to load custom operators. |
|
""" |
|
if not _SET_GLOBAL_FLAGS: |
|
yield |
|
return |
|
old_flags = sys.getdlopenflags() |
|
sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) |
|
try: |
|
yield |
|
finally: |
|
sys.setdlopenflags(old_flags) |
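# Example usage (illustrative sketch; "libcustom_ops.so" is a hypothetical
# library that registers custom operators in its static initializers):
#
#     with dl_open_guard():
#         ctypes.CDLL("libcustom_ops.so")  # opened with RTLD_GLOBAL set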
|
|
|
|
|
class OperatorBase: |
|
""" |
|
Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator |
|
(which represents Python-only operators that are unrepresentable in TorchScript). |
|
""" |
|
|
|
def __init__(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Cache of dispatch decisions, keyed on the incoming DispatchKey.  The
        # cached value is either a resolved DispatchKey (meaning "go through the
        # regular dispatcher for that key") or a Python handler to call directly.
        # Cleared whenever a new kernel is registered.
        self._dispatch_cache: dict[
            DispatchKey, Union[DispatchKey, Callable[..., Any]]
        ] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Python kernels registered per DispatchKey via py_impl(DispatchKey.X).
        # When present, these take precedence over any C++ kernel for the same
        # key during Python dispatch.
        self.py_kernels: dict[DispatchKey, Callable[..., Any]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Handlers for the Python dispatch key, keyed on a TorchDispatchMode
        # subclass or a torch.Tensor subclass, registered via py_impl(cls).
        self.python_key_table: dict[
            type[Union[TorchDispatchMode, torch.Tensor]], Callable[..., Any]
        ] = {}
|
|
|
|
|
|
|
|
|
        # functorch transform rules, keyed on TransformType (vmap, grad, etc.).
        self.functorch_table: dict[TransformType, Callable[..., Any]] = {}
|
|
|
def __call__(self, *args, **kwargs): |
|
raise NotImplementedError |
|
|
|
def has_kernel_for_dispatch_key(self, k): |
|
return k in self.py_kernels |
|
|
|
def has_kernel_for_any_dispatch_key(self, ks): |
|
for k in self.py_kernels: |
|
if not torch._C._dispatch_is_alias_key(k) and ks.has(k): |
|
return True |
|
return False |
|
|
|
def py_impl( |
|
self, |
|
k: Union[ |
|
type[TorchDispatchMode], |
|
type[torch.Tensor], |
|
TransformType, |
|
DispatchKey, |
|
], |
|
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: |
|
def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]: |
|
if inspect.isclass(k) and ( |
|
issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor) |
|
): |
|
assert k not in self.python_key_table |
|
|
|
self.python_key_table[k] = fn |
|
self._dispatch_cache.clear() |
|
return fn |
|
|
|
if isinstance(k, TransformType): |
|
assert k not in self.functorch_table |
|
self.functorch_table[k] = fn |
|
return fn |
|
|
|
assert isinstance(k, DispatchKey) |
|
assert k != DispatchKey.Python, ( |
|
"Please register a mode for the DispatchKey.Python key instead." |
|
) |
|
|
|
if k in self.py_kernels: |
|
raise RuntimeError( |
|
f"Trying to override a python impl for {k} on operator {self.name()}" |
|
) |
|
self.py_kernels[k] = fn |
|
self._dispatch_cache.clear() |
|
return fn |
|
|
|
return inner |
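    # Example usage (illustrative sketch; "my_hop" and the kernels below are
    # hypothetical).  A DispatchKey registration receives the raw arguments,
    # while a mode registration receives the active mode as its first argument:
    #
    #     from torch._subclasses.fake_tensor import FakeTensorMode
    #
    #     @my_hop.py_impl(DispatchKey.CompositeExplicitAutograd)
    #     def my_hop_dense(x):
    #         return x.sin()
    #
    #     @my_hop.py_impl(FakeTensorMode)
    #     def my_hop_fake(mode, x):
    #         return torch.empty_like(x)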
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def py_functionalize_impl( |
|
self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T] |
|
) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]: |
|
from torch._subclasses.functional_tensor import ( |
|
CppFunctionalizeAPI, |
|
FunctionalTensorMode, |
|
FunctorchFunctionalizeAPI, |
|
PythonFunctionalizeAPI, |
|
) |
|
|
|
|
|
|
|
def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: |
|
return fn(CppFunctionalizeAPI(), *args, **kwargs) |
|
|
|
def functionalize_dispatch_mode_fn( |
|
mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs |
|
) -> _T: |
|
return fn(PythonFunctionalizeAPI(mode), *args, **kwargs) |
|
|
|
def functionalize_functorch_fn( |
|
interpreter, *args: _P.args, **kwargs: _P.kwargs |
|
) -> _T: |
|
return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs) |
|
|
|
self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn) |
|
self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn) |
|
self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn) |
|
|
|
return fn |
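    # Example usage (illustrative sketch; "my_hop" is hypothetical).  The wrapped
    # function receives a BaseFunctionalizeAPI first and typically unwraps the
    # functional tensors, redispatches, then re-wraps the result:
    #
    #     @my_hop.py_functionalize_impl
    #     def my_hop_functionalize(ctx, x):
    #         x_inner = ctx.unwrap_tensors(x)
    #         with ctx.redispatch_to_next():
    #             out = my_hop(x_inner)
    #         return ctx.wrap_tensors(out)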
|
|
|
def name(self): |
|
raise NotImplementedError |
|
|
|
|
|
|
|
def resolve_key(op: OperatorBase, k: DispatchKey): |
|
|
|
if op.has_kernel_for_dispatch_key(k): |
|
return k |
|
|
|
cand = DispatchKey.CompositeExplicitAutogradNonFunctional |
|
if ( |
|
k == DispatchKey.Undefined or is_included_in_alias(k, cand) |
|
) and op.has_kernel_for_dispatch_key(cand): |
|
return cand |
|
|
|
cand = DispatchKey.CompositeExplicitAutograd |
|
if ( |
|
k == DispatchKey.Undefined or is_included_in_alias(k, cand) |
|
) and op.has_kernel_for_dispatch_key(cand): |
|
return cand |
|
has_backend_kernel = op.has_kernel_for_any_dispatch_key( |
|
torch._C._dispatch_get_backend_keyset_from_autograd(k) |
|
) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd) |
|
|
|
cand = DispatchKey.CompositeImplicitAutogradNestedTensor |
|
if ( |
|
(k != DispatchKey.Undefined and is_included_in_alias(k, cand)) |
|
and op.has_kernel_for_dispatch_key(cand) |
|
and not has_backend_kernel |
|
): |
|
return cand |
|
cand = DispatchKey.CompositeImplicitAutograd |
|
if ( |
|
k == DispatchKey.Undefined or is_included_in_alias(k, cand) |
|
) and op.has_kernel_for_dispatch_key(cand): |
|
if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key( |
|
torch._C._dispatch_autogradother_backends |
|
): |
|
raise RuntimeError("ambiguous autogradother kernel") |
|
elif not has_backend_kernel: |
|
return cand |
|
|
|
cand = DispatchKey.Autograd |
|
if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand): |
|
return cand |
|
|
|
cand = DispatchKey.FuncTorchBatchedDecomposition |
|
if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand): |
|
return cand |
|
|
|
if torch._C._dispatch_has_backend_fallback(k): |
|
|
|
|
|
return k |
|
raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}") |
|
|
|
|
|
_higher_order_ops: dict[str, "HigherOrderOperator"] = {} |
|
|
|
_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [ |
|
DispatchKey.PythonDispatcher, |
|
DispatchKey.PythonTLSSnapshot, |
|
DispatchKey.ADInplaceOrView, |
|
DispatchKey.BackendSelect, |
|
DispatchKey.AutocastCPU, |
|
DispatchKey.AutocastCUDA, |
|
] |
|
|
|
|
|
class HigherOrderOperator(OperatorBase, abc.ABC): |
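    """
    A HigherOrderOperator is an operator registered purely from Python (for
    example, control-flow operators that take callables as arguments).  Its
    kernels are registered via py_impl, it must be subclassed rather than
    instantiated directly, and every instance is exposed under the
    torch.ops.higher_order namespace.
    """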
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, name, *, cacheable=False): |
|
super().__init__() |
|
if type(self) is HigherOrderOperator: |
|
raise RuntimeError( |
|
"Direct instantiation of HigherOrderOperator is not allowed. Please subclass it." |
|
) |
|
self._name = name |
|
|
|
|
|
self.__name__ = name |
|
_higher_order_ops[name] = self |
|
self._ns = "higher_order" |
|
self.__module__ = "torch.ops.higher_order" |
|
self._cacheable = cacheable |
|
|
|
self.non_fallthrough_keys = torch._C._dispatch_keyset_full() |
|
|
|
for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS: |
|
self.fallthrough(dispatch_key) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def py_impl( |
|
self, |
|
k: Union[ |
|
type[TorchDispatchMode], |
|
type[torch.Tensor], |
|
TransformType, |
|
DispatchKey, |
|
], |
|
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: |
|
if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k): |
|
self.non_fallthrough_keys = self.non_fallthrough_keys.add(k) |
|
return super().py_impl(k) |
|
|
|
@property |
|
def namespace(self): |
|
return self._ns |
|
|
|
def cacheable(self): |
|
return self._cacheable |
|
|
|
def fallthrough(self, dispatch_key): |
|
self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key) |
|
|
|
|
|
|
|
def dispatch(self, /, dispatch_key, *args, **kwargs): |
|
from torch.utils._python_dispatch import _get_current_dispatch_mode |
|
|
|
if dispatch_key in self._dispatch_cache: |
|
kernel = self._dispatch_cache[dispatch_key] |
|
assert not isinstance(kernel, DispatchKey) |
|
return kernel(*args, **kwargs) |
|
|
|
if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode: |
|
return dispatch_functorch(self, args, kwargs) |
|
|
|
if dispatch_key == DispatchKey.Python: |
|
|
|
|
|
|
|
overloaded_args_list = [] |
|
|
|
def has_python_key(tensor): |
|
return torch._C._dispatch_keys(tensor).has("Python") |
|
|
|
def check_overloaded(arg): |
|
if isinstance(arg, torch.Tensor) and has_python_key(arg): |
|
overloaded_args_list.append(arg) |
|
|
|
for arg in (*args, *kwargs.values()): |
|
check_overloaded(arg) |
|
if isinstance(arg, (list, tuple)): |
|
for a in arg: |
|
check_overloaded(a) |
|
|
|
overloaded_args = tuple(overloaded_args_list) |
|
|
|
|
|
from torch.utils._python_dispatch import _pop_mode_temporarily |
|
|
|
curr_mode = _get_current_dispatch_mode() |
|
if curr_mode is not None: |
|
if type(curr_mode) in self.python_key_table: |
|
handler = self.python_key_table[type(curr_mode)] |
|
with _pop_mode_temporarily() as mode: |
|
|
|
|
|
result = handler(mode, *args, **kwargs) |
|
else: |
|
raise NotImplementedError( |
|
f"There was no rule registered for HOP {self._name} and mode {curr_mode}. " |
|
f"We recommend filing an issue." |
|
) |
|
if result is not NotImplemented: |
|
return result |
|
|
|
|
|
for arg in overloaded_args: |
|
subclass_type = type(arg) |
|
if ( |
|
subclass_type.__torch_dispatch__ |
|
== torch._C._disabled_torch_dispatch_impl |
|
): |
|
continue |
|
|
|
|
|
|
|
|
|
|
|
|
|
                # Special case: a FakeTensor seen without an active FakeTensorMode
                # is routed to the handler registered for FakeTensorMode, using the
                # mode stashed on the tensor itself.
                if subclass_type is torch._subclasses.fake_tensor.FakeTensor:
                    subclass_type = torch._subclasses.fake_tensor.FakeTensorMode
                    handler = self.python_key_table[subclass_type]
                    result = handler(arg.fake_mode, *args, **kwargs)
                    return result
|
|
|
if subclass_type in self.python_key_table: |
|
handler = self.python_key_table[subclass_type] |
|
|
|
|
|
result = handler(*args, **kwargs) |
|
else: |
|
raise NotImplementedError( |
|
f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. " |
|
f"We recommend filing an issue." |
|
) |
|
if result is not NotImplemented: |
|
return result |
|
|
|
|
|
            raise TypeError(
                f"Multiple dispatch failed for {self._name}. No registered rule "
                f"returned something other than NotImplemented. Use HOP.py_impl to register one. "
                f"Tried mode: {curr_mode} and subclasses: "
                f"{[type(a) for a in overloaded_args]}"
            )
|
|
|
functionality_key = torch._C._to_functionality_key(dispatch_key) |
|
if functionality_key == DispatchKey.PreDispatch: |
|
from torch.utils._python_dispatch import _pop_mode_temporarily |
|
|
|
|
|
|
|
if ( |
|
_len_torch_dispatch_stack_pre_dispatch() > 0 |
|
) and not torch._C._dispatch_tls_is_dispatch_key_excluded( |
|
DispatchKey.Python |
|
): |
|
curr_mode = _get_current_dispatch_mode_pre_dispatch() |
|
assert curr_mode is not None, ( |
|
"Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode." |
|
) |
|
assert type(curr_mode) in self.python_key_table, ( |
|
f"Current active mode {curr_mode} not registered" |
|
) |
|
handler = self.python_key_table[type(curr_mode)] |
|
with _pop_mode_temporarily(functionality_key) as mode: |
|
return handler(mode, *args, **kwargs) |
|
|
|
final_key = resolve_key(self, dispatch_key) |
|
|
|
|
|
|
|
if final_key not in self.py_kernels: |
|
raise NotImplementedError( |
|
f"could not find kernel for HigherOrderOperator {self._name} " |
|
f"at dispatch key {final_key} (resolved from {dispatch_key})" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
if dispatch_key != DispatchKey.PreDispatch: |
|
self._dispatch_cache[dispatch_key] = self.py_kernels[final_key] |
|
kernel = self.py_kernels[final_key] |
|
|
|
|
|
assert not isinstance(kernel, DispatchKey) |
|
return kernel(*args, **kwargs) |
|
|
|
@abc.abstractmethod |
|
def __call__(self, /, *args, **kwargs): |
|
def wrapper(): |
|
flat_args = _to_flat_tuple(args, kwargs) |
|
if torch.overrides.has_torch_function(flat_args): |
|
return torch.overrides.handle_torch_function( |
|
self, flat_args, *args, **kwargs |
|
) |
|
|
|
dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys) |
|
return self.dispatch( |
|
dispatch_key_set.highestPriorityTypeId(), *args, **kwargs |
|
) |
|
|
|
return wrapper() |
|
|
|
def __str__(self): |
|
return f"{self.name()}" |
|
|
|
def name(self): |
|
return self._name |
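# Illustrative sketch of defining a new higher-order operator (the names below
# are hypothetical, not an existing PyTorch HOP):
#
#     class MyWrap(HigherOrderOperator):
#         def __init__(self):
#             super().__init__("my_wrap")
#
#         def __call__(self, fn, *args):
#             return super().__call__(fn, *args)
#
#     my_wrap = MyWrap()
#
#     @my_wrap.py_impl(DispatchKey.CompositeExplicitAutograd)
#     def my_wrap_dense(fn, *args):
#         return fn(*args)
#
#     my_wrap(torch.sin, torch.randn(3))  # goes through HigherOrderOperator.dispatch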
|
|
|
|
|
def _to_flat_tuple(args, kwargs): |
|
return pytree.arg_tree_leaves(*args, **kwargs) |
|
|
|
|
|
def _compute_keyset(args, kwargs, non_fallthrough_keys): |
|
tensors = _get_tensors(args, kwargs) |
|
return key_extractor(tensors, non_fallthrough_keys) |
|
|
|
|
|
def _get_tensors(args, kwargs): |
|
flat_all = _to_flat_tuple(args, kwargs) |
|
tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)] |
|
return tuple(tensor_args) |
|
|
|
|
|
|
|
|
|
def key_extractor(tensors, key_mask): |
|
key_set = torch._C._dispatch_tls_local_include_set() |
|
for tensor in tensors: |
|
key_set = key_set | torch._C._dispatch_keys(tensor) |
|
key_set = key_set - torch._C._dispatch_tls_local_exclude_set() |
|
key_set = key_set & key_mask |
|
return key_set |
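# Illustrative sketch of what key_extractor computes: the union of the TLS include
# set and every tensor's key set, minus the TLS exclude set, masked by the
# operator's non-fallthrough keys:
#
#     t = torch.randn(2, requires_grad=True)
#     ks = key_extractor((t,), torch._C._dispatch_keyset_full())
#     ks.highestPriorityTypeId()  # typically an autograd key such as DispatchKey.AutogradCPU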
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ModeStackStateForPreDispatch: |
|
def __init__(self): |
|
        # Two fixed slots for infra modes: index 0 holds the ProxyTorchDispatchMode
        # (if any), index 1 the FunctionalTensorMode (if any); see
        # _set_mode_pre_dispatch / unset_mode_pre_dispatch below.
        self.__infra_modes = [None, None]
|
self._schema_check_mode = None |
|
|
|
def set(self, index, mode): |
|
assert index < len(self.__infra_modes) |
|
self.__infra_modes[index] = mode |
|
|
|
def get(self, index): |
|
assert index < len(self.__infra_modes) |
|
return self.__infra_modes[index] |
|
|
|
def count(self): |
|
return len([i for i in self.__infra_modes if i is not None]) + int( |
|
self._schema_check_mode is not None |
|
) |
|
|
|
|
|
_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch() |
|
|
|
|
|
def unset_mode_pre_dispatch(mode_key, schema_check=False): |
|
current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch() |
|
assert mode_key is None or mode_key in ( |
|
torch._C._TorchDispatchModeKey.PROXY, |
|
torch._C._TorchDispatchModeKey.FUNCTIONAL, |
|
) |
|
if schema_check: |
|
assert mode_key is None |
|
|
|
def _unset_mode(): |
|
if mode_key == torch._C._TorchDispatchModeKey.PROXY: |
|
current_mode = current_mode_stack_pre_dispatch.get(0) |
|
mode_stack_state_for_pre_dispatch().set(0, None) |
|
return current_mode |
|
elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL: |
|
current_mode = current_mode_stack_pre_dispatch.get(1) |
|
mode_stack_state_for_pre_dispatch().set(1, None) |
|
return current_mode |
|
else: |
|
current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode |
|
mode_stack_state_for_pre_dispatch()._schema_check_mode = None |
|
return current_mode |
|
|
|
current_mode = _unset_mode() |
|
|
|
new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch() |
|
|
|
|
|
|
|
|
|
if new_pre_dispatch_len == 0: |
|
torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False) |
|
|
|
return current_mode |
|
|
|
|
|
def _set_mode_pre_dispatch(mode): |
|
from torch._subclasses.functional_tensor import FunctionalTensorMode |
|
from torch._subclasses.schema_check_mode import SchemaCheckMode |
|
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode |
|
|
|
assert isinstance( |
|
mode, |
|
( |
|
FunctionalTensorMode, |
|
ProxyTorchDispatchMode, |
|
SchemaCheckMode, |
|
), |
|
) |
|
|
|
previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch() |
|
if isinstance(mode, SchemaCheckMode): |
|
current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode |
|
if previous_mode_stack_len > 0: |
|
raise AssertionError( |
|
"SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack" |
|
) |
|
mode_stack_state_for_pre_dispatch()._schema_check_mode = mode |
|
elif isinstance(mode, FunctionalTensorMode): |
|
current_mode = mode_stack_state_for_pre_dispatch().get(1) |
|
assert current_mode is None |
|
mode_stack_state_for_pre_dispatch().set(1, mode) |
|
else: |
|
current_mode = mode_stack_state_for_pre_dispatch().get(0) |
|
assert current_mode is None |
|
mode_stack_state_for_pre_dispatch().set(0, mode) |
|
|
|
|
|
|
|
|
|
|
|
if previous_mode_stack_len == 0: |
|
torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True) |
|
|
|
|
|
def _pop_mode_from_pre_dispatch(): |
|
mode_stack = mode_stack_state_for_pre_dispatch() |
|
pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch() |
|
|
|
if pre_dispatch_len == 0: |
|
raise AssertionError("Trying to pop empty mode stack") |
|
|
|
if mode_stack._schema_check_mode is not None: |
|
return unset_mode_pre_dispatch(None, schema_check=True) |
|
if mode_stack.get(1) is not None: |
|
return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL) |
|
if mode_stack.get(0) is not None: |
|
return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY) |
|
|
|
|
|
def _len_torch_dispatch_stack_pre_dispatch(): |
|
return mode_stack_state_for_pre_dispatch().count() |
|
|
|
|
|
def _get_dispatch_mode_pre_dispatch(mode_key): |
|
assert mode_key in ( |
|
torch._C._TorchDispatchModeKey.PROXY, |
|
torch._C._TorchDispatchModeKey.FUNCTIONAL, |
|
) |
|
if mode_key == torch._C._TorchDispatchModeKey.PROXY: |
|
return mode_stack_state_for_pre_dispatch().get(0) |
|
else: |
|
return mode_stack_state_for_pre_dispatch().get(1) |
|
|
|
|
|
def _get_current_dispatch_mode_pre_dispatch(): |
|
if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None: |
|
return mode_stack_state_for_pre_dispatch()._schema_check_mode |
|
else: |
|
stack_len = mode_stack_state_for_pre_dispatch().count() |
|
if stack_len == 2: |
|
return mode_stack_state_for_pre_dispatch().get(1) |
|
if stack_len == 1: |
|
return ( |
|
mode_stack_state_for_pre_dispatch().get(1) |
|
if mode_stack_state_for_pre_dispatch().get(1) is not None |
|
else mode_stack_state_for_pre_dispatch().get(0) |
|
) |
|
return None |
|
|
|
|
|
def mode_stack_state_for_pre_dispatch(): |
|
global _mode_stack_state_for_pre_dispatch |
|
return _mode_stack_state_for_pre_dispatch |
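# Illustrative sketch of the pre-dispatch mode stack round trip (these are
# internal helpers; FunctionalTensorMode is just one of the allowed infra modes):
#
#     from torch._subclasses.functional_tensor import FunctionalTensorMode
#     m = FunctionalTensorMode()
#     _set_mode_pre_dispatch(m)
#     assert _len_torch_dispatch_stack_pre_dispatch() == 1
#     assert _pop_mode_from_pre_dispatch() is m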
|
|
|
|
|
cached_ops: set["OpOverload"] = set() |
|
|
|
|
|
def add_cached_op(op_overload): |
|
global cached_ops |
|
cached_ops.add(op_overload) |
|
|
|
|
|
def reset_cached_ops(): |
|
global cached_ops |
|
cached_ops.clear() |
|
|
|
|
|
def get_cached_ops(): |
|
global cached_ops |
|
return cached_ops |
|
|
|
|
|
|
|
|
|
class OpOverload(OperatorBase): |
|
def __init__(self, overloadpacket, op, op_dk, schema, tags): |
|
super().__init__() |
|
self._op = op |
|
self._op_dk = op_dk |
|
self._schema = schema |
|
self._overloadpacket = overloadpacket |
|
self._tags = tags |
|
self._overloadname = ( |
|
"default" if schema.overload_name == "" else schema.overload_name |
|
) |
|
if tags: |
|
self._nondeterministic_seeded = torch.Tag.nondeterministic_seeded in tags |
|
self._name = self._schema.name |
|
if schema.overload_name: |
|
self._name += "." + schema.overload_name |
|
self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}" |
|
self.__module__ = overloadpacket.__module__ |
|
op.__module__ = overloadpacket.__module__ |
|
self.__qualname__ = self._name |
|
self.__annotations__ = {} |
|
|
|
|
|
self._lazy_handle = None |
|
|
|
|
|
self._defined_in_python = self.__qualname__ in torch.library._defs |
|
|
|
|
|
is_write = None |
|
for a in self._schema.arguments: |
|
if a.alias_info is None: |
|
continue |
|
if is_write is None: |
|
is_write = a.alias_info.is_write |
|
else: |
|
|
|
|
|
                # Conservatively treat the op as mutating if any aliased argument
                # is written to.
                is_write = a.alias_info.is_write or is_write
        # A view op has aliasing arguments, none of which are written to.
        self.is_view = is_write is not None and not is_write
|
|
|
@property |
|
def _namespace(self): |
|
return self._schema.name.split("::")[0] |
|
|
|
@property |
|
def _opname(self): |
|
return self._schema.name.split("::")[1] |
|
|
|
@property |
|
def _handle(self): |
|
if self._lazy_handle is None: |
|
self._lazy_handle = torch._C._dispatch_find_schema_or_throw( |
|
self._schema.name, self._schema.overload_name |
|
) |
|
return self._lazy_handle |
|
|
|
|
|
def __deepcopy__(self, memo=None): |
|
return self |
|
|
|
def __repr__(self): |
|
return "<OpOverload(op='{}.{}', overload='{}')>".format( |
|
*self._schema.name.split("::"), self._overloadname |
|
) |
|
|
|
|
|
|
|
def __call__(self, /, *args, **kwargs): |
|
return self._op(*args, **kwargs) |
|
|
|
|
|
|
|
def redispatch(self, /, keyset, *args, **kwargs): |
|
return self._handle.redispatch_boxed(keyset, *args, **kwargs) |
|
|
|
def __hash__(self): |
|
return hash(self._op) |
|
|
|
|
|
def __str__(self): |
|
return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname) |
|
|
|
def has_kernel_for_dispatch_key(self, k): |
|
return super().has_kernel_for_dispatch_key( |
|
k |
|
) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k) |
|
|
|
def has_kernel_for_any_dispatch_key(self, ks): |
|
return torch._C._dispatch_has_kernel_for_any_dispatch_key( |
|
self.name(), ks |
|
) or super().has_kernel_for_any_dispatch_key(ks) |
|
|
|
@property |
|
def namespace(self): |
|
return self._schema.name.split("::")[0] |
|
|
|
def _can_decompose(self): |
|
dk = DispatchKey.CompositeImplicitAutograd |
|
return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key( |
|
self.name(), dk |
|
) |
|
|
|
def decompose(self, *args, **kwargs): |
|
dk = DispatchKey.CompositeImplicitAutograd |
|
if dk in self.py_kernels: |
|
|
|
|
|
|
|
|
|
return self.py_kernels[dk](*args, **kwargs) |
|
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk): |
|
return self._op_dk(dk, *args, **kwargs) |
|
else: |
|
return NotImplemented |
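    # Illustrative sketch: overloads registered under CompositeImplicitAutograd can
    # be decomposed into other ops, while ops with only backend kernels cannot:
    #
    #     torch.ops.aten.matmul.default.decompose(a, b)  # runs the composite decomposition
    #     torch.ops.aten.add.Tensor.decompose(a, b)      # typically NotImplemented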
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _uncache_dispatch(self, key): |
|
self._dispatch_cache.pop(key, None) |
|
|
|
|
|
def _get_dispatch(self, key): |
|
|
|
assert key not in self._dispatch_cache, f"{self} {key}" |
|
|
|
if key == DispatchKey.Python: |
|
if not isinstance(self, TorchBindOpOverload) and not self.python_key_table: |
|
self._dispatch_cache[key] = key |
|
add_cached_op(self) |
|
return key |
|
|
|
def handler(*args, **kwargs): |
|
from torch.utils._python_dispatch import _get_current_dispatch_mode |
|
|
|
|
|
|
|
                # Grab the mode first and assert before taking its type; calling
                # type() on None would make the assert below vacuously pass.
                curr_mode = _get_current_dispatch_mode()
                assert curr_mode is not None, (
                    "Illegal invocation of dispatch on DispatchKey.Python without a mode."
                )
                curr_mode = type(curr_mode)
|
|
|
if curr_mode not in self.python_key_table: |
|
if isinstance(self, TorchBindOpOverload): |
|
with ( |
|
torch.utils._python_dispatch._pop_mode_temporarily() as mode |
|
): |
|
return torch._library.utils.handle_dispatch_mode( |
|
mode, self, *args, **kwargs |
|
) |
|
else: |
|
return self._op_dk(key, *args, **kwargs) |
|
|
|
with torch.utils._python_dispatch._pop_mode_temporarily() as mode: |
|
return self.python_key_table[curr_mode](mode, *args, **kwargs) |
|
|
|
self._dispatch_cache[key] = handler |
|
add_cached_op(self) |
|
return handler |
|
|
|
functionality_key = torch._C._to_functionality_key(key) |
|
if functionality_key == DispatchKey.PreDispatch: |
|
curr_stack_len = _len_torch_dispatch_stack_pre_dispatch() |
|
|
|
|
|
if ( |
|
curr_stack_len > 0 |
|
and not torch._C._dispatch_tls_is_dispatch_key_excluded( |
|
DispatchKey.Python |
|
) |
|
): |
|
|
|
def handler(*args, **kwargs): |
|
@contextlib.contextmanager |
|
def _temporarily_pop_modes_from_pre_dispatch(): |
|
top_mode = _pop_mode_from_pre_dispatch() |
|
try: |
|
yield top_mode |
|
finally: |
|
_set_mode_pre_dispatch(top_mode) |
|
|
|
with _temporarily_pop_modes_from_pre_dispatch() as curr_mode: |
|
return torch._library.utils.handle_dispatch_mode( |
|
curr_mode, self, *args, **kwargs |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
return handler |
|
|
|
final_key = resolve_key(self, key) |
|
|
|
|
|
cache_result = key != DispatchKey.PreDispatch |
|
|
|
|
|
|
|
|
|
if key == DispatchKey.Functionalize: |
|
import torch._dispatch.python as pydispatch |
|
|
|
if pydispatch.CROSSREF_FUNCTIONALIZE: |
|
handler = pydispatch.make_crossref_functionalize(self, final_key) |
|
if cache_result: |
|
self._dispatch_cache[key] = handler |
|
add_cached_op(self) |
|
return handler |
|
|
|
r = self.py_kernels.get(final_key, final_key) |
|
if cache_result: |
|
self._dispatch_cache[key] = r |
|
add_cached_op(self) |
|
return r |
|
|
|
def name(self): |
|
return self._name |
|
|
|
@property |
|
def overloadpacket(self): |
|
return self._overloadpacket |
|
|
|
@property |
|
def op(self): |
|
return self._op |
|
|
|
@property |
|
def tags(self): |
|
return self._tags |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TorchBindOpOverload(OpOverload): |
|
def _fallthrough_keys(self) -> list[DispatchKey]: |
|
|
|
|
|
_DEFAULT_FALLTHROUGH_KEYS = [ |
|
DispatchKey.Autograd, |
|
DispatchKey.AutogradCPU, |
|
DispatchKey.AutogradCUDA, |
|
DispatchKey.ADInplaceOrView, |
|
DispatchKey.BackendSelect, |
|
DispatchKey.PythonTLSSnapshot, |
|
DispatchKey.PythonDispatcher, |
|
] |
|
|
|
def _may_use_fallthrough_instead_of_fallback(key: DispatchKey): |
|
if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key): |
|
return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( |
|
self.name(), key |
|
) |
|
|
|
return ( |
|
key not in self.py_kernels |
|
or self.py_kernels[key] is torch.library.fallthrough_kernel |
|
) |
|
|
|
return [ |
|
key |
|
for key in _DEFAULT_FALLTHROUGH_KEYS |
|
if _may_use_fallthrough_instead_of_fallback(key) |
|
] |
|
|
|
@contextlib.contextmanager |
|
def _register_as_effectful_op_temporarily(self): |
|
from torch._higher_order_ops.effects import ( |
|
_EffectType, |
|
_register_effectful_op, |
|
SIDE_EFFECTS, |
|
) |
|
|
|
try: |
|
if self not in SIDE_EFFECTS: |
|
_register_effectful_op(self, _EffectType.ORDERED) |
|
yield |
|
finally: |
|
if self in SIDE_EFFECTS: |
|
del SIDE_EFFECTS[self] |
|
|
|
|
|
|
|
def __call__(self, /, *args, **kwargs): |
|
if _must_dispatch_in_python(args, kwargs): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with self._register_as_effectful_op_temporarily(): |
|
return self._dispatch_in_python(args, kwargs, self._fallthrough_keys()) |
|
return self._op(*args, **kwargs) |
|
|
|
def _dispatch_in_python(self, args, kwargs, fallthrough_keys): |
|
non_fallthrough_keys = torch._C._dispatch_keyset_full() |
|
for key in fallthrough_keys: |
|
non_fallthrough_keys = non_fallthrough_keys.remove(key) |
|
|
|
dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys) |
|
dispatch_key = dispatch_key_set.highestPriorityTypeId() |
|
|
|
handler = ( |
|
self._get_dispatch(dispatch_key) |
|
if dispatch_key not in self._dispatch_cache |
|
else self._dispatch_cache[dispatch_key] |
|
) |
|
|
|
if isinstance(handler, DispatchKey): |
|
|
|
|
|
if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( |
|
self.name(), dispatch_key |
|
): |
|
return self._dispatch_in_python( |
|
args, kwargs, fallthrough_keys + [dispatch_key] |
|
) |
|
|
|
            raise RuntimeError(
                f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler},"
                f" but no Python implementation was found."
                f" Please file an issue if you encounter this error."
                f" This error can happen when you export or compile the model."
                f" It can still happen even if a C++ implementation for {dispatch_key}"
                f" has been registered, because FakeScriptObject lives purely in Python and cannot"
                f" work with a C++ implementation."
            )
|
|
|
assert isinstance(handler, Callable) |
|
return handler(*args, **kwargs) |
|
|
|
|
|
def _must_dispatch_in_python(args, kwargs): |
|
return pytree.tree_any( |
|
lambda obj: isinstance( |
|
obj, torch._library.fake_class_registry.FakeScriptObject |
|
), |
|
(args, kwargs), |
|
) |
|
|
|
|
|
def _has_script_object_arg(schema: torch.FunctionSchema) -> bool: |
|
return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments) |
|
|
|
|
|
|
|
|
|
class OpOverloadPacket: |
|
def __init__(self, qualified_op_name, op_name, op, overload_names): |
|
|
|
|
|
self._qualified_op_name = qualified_op_name |
|
self.__name__ = op_name |
|
self._op = op |
|
self._overload_names = overload_names |
|
self._dir = [] |
|
self._has_torchbind_op_overload = any( |
|
_has_script_object_arg(schema) for schema in self._schemas.values() |
|
) |
|
|
|
|
|
def __deepcopy__(self, memo=None): |
|
return self |
|
|
|
def __repr__(self): |
|
return "<OpOverloadPacket(op='{}.{}')>".format( |
|
*self._qualified_op_name.split("::") |
|
) |
|
|
|
def __hash__(self): |
|
return hash(self._op) |
|
|
|
def __str__(self): |
|
return "{}.{}".format(*self._qualified_op_name.split("::")) |
|
|
|
@property |
|
def op(self): |
|
return self._op |
|
|
|
@property |
|
def _schemas(self): |
|
return { |
|
overload_name: torch._C._get_schema(self._qualified_op_name, overload_name) |
|
for overload_name in self._overload_names |
|
} |
|
|
|
def __getattr__(self, key): |
|
|
|
if key == "__file__": |
|
return "torch.ops" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
if key.startswith("__"): |
|
return getattr(self._op, key) |
|
except AttributeError: |
|
|
|
|
|
|
|
|
|
raise AttributeError( |
|
f"'{str(self)}' can't have an overload name beginning with '__' and the " |
|
f"underlying op {str(self._op)} has no attribute {key} either." |
|
) from None |
|
|
|
try: |
|
|
|
use_key = "" if key == "default" else key |
|
|
|
op_dk_tags = torch._C._get_operation_overload( |
|
self._qualified_op_name, use_key |
|
) |
|
if op_dk_tags is None: |
|
raise AttributeError( |
|
f"The underlying op of '{str(self)}' has no overload name '{key}'" |
|
) |
|
|
|
op_, op_dk_, tags = op_dk_tags |
|
schema = torch._C._get_schema(self._qualified_op_name, use_key) |
|
overload = ( |
|
OpOverload(self, op_, op_dk_, schema, tags) |
|
if not _has_script_object_arg(schema) |
|
else TorchBindOpOverload(self, op_, op_dk_, schema, tags) |
|
) |
|
|
|
setattr(self, key, overload) |
|
self._dir.append(key) |
|
return overload |
|
except RuntimeError: |
|
raise AttributeError( |
|
f"The underlying op of '{str(self)}' has no overload name '{key}'" |
|
) from None |
|
|
|
def __iter__(self): |
|
return iter(self._dir) |
|
|
|
|
|
|
|
def __call__(self, /, *args, **kwargs): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs): |
|
return _call_overload_packet_from_python(self, args, kwargs) |
|
return self._op(*args, **(kwargs or {})) |
|
|
|
|
|
def overloads(self): |
|
return [n if n else "default" for n in self._overload_names] |
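# Illustrative sketch: torch.ops.aten.add is an OpOverloadPacket; attribute access
# by overload name yields OpOverload objects:
#
#     packet = torch.ops.aten.add
#     packet.overloads()        # e.g. ['Tensor', 'Scalar', 'out', ...]
#     packet.Tensor             # the OpOverload for aten::add.Tensor
#     packet(torch.ones(2), 1)  # calling the packet resolves the overload in C++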
|
|
|
|
|
|
|
|
|
def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs): |
|
|
|
torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet( |
|
op, *args, **kwargs |
|
) |
|
|
|
if torch_function_called: |
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
|
|
exceptions = {} |
|
found_op = None |
|
for overload_name in op.overloads(): |
|
op_overload = getattr(op, overload_name) |
|
try: |
|
_ = torch._C._check_schema_allow_fake_script_object( |
|
op_overload._schema, *args, **kwargs |
|
) |
|
found_op = op_overload |
|
break |
|
except RuntimeError as e: |
|
exceptions[overload_name] = e |
|
|
|
if found_op: |
|
return found_op(*args, **kwargs) |
|
|
|
    err_msg = (
        f"Failed to match any TorchBindOverload of {op}; the following exceptions were raised:\n"
    )
|
for key, msg in exceptions.items(): |
|
err_msg += f"Overload name {key}:\n {msg}\n" |
|
raise RuntimeError(err_msg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _OpNamespace(types.ModuleType): |
|
""" |
|
An op namespace to dynamically bind Operators into Python. |
|
|
|
Say a user has created a custom Operator called "my_namespace::my_op". To |
|
call this op, the user will write torch.ops.my_namespace.my_op(...). |
|
At startup, this operation will not yet be bound into Python. Instead, the |
|
following sequence of magic tricks will occur: |
|
1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method |
|
on the `torch.ops` object, which will create a new `_OpNamespace` |
|
object called `my_namespace` and set it as an attribute on the `ops` |
|
object. |
|
2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on |
|
the `my_namespace` object, which will retrieve the operation via |
|
       `torch._C._jit_get_operation`, a function bound from C++, and then in a similar
|
fashion bind this new object onto the `my_namespace` object. |
|
3. `torch.ops.my_namespace.my_op(...)` then calls this new operation |
|
and subsequent accesses will incur no further lookup (the namespace and |
|
operation will already exist). |
|
""" |
|
|
|
def __init__(self, name): |
|
super().__init__("torch.ops." + name) |
|
self.name = name |
|
self._dir = [] |
|
|
|
def __iter__(self): |
|
return iter(self._dir) |
|
|
|
def __getattr__(self, op_name): |
|
|
|
if op_name == "__file__": |
|
return "torch.ops" |
|
elif op_name in ["__origin__", "__self__"]: |
|
raise AttributeError( |
|
f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'" |
|
) |
|
|
|
|
|
|
|
namespace_name = self.name |
|
qualified_op_name = f"{namespace_name}::{op_name}" |
|
module_name = self.__module__ + "." + namespace_name |
|
|
|
try: |
|
op, overload_names = _get_packet(qualified_op_name, module_name) |
|
if op is None: |
|
raise AttributeError( |
|
f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" |
|
) |
|
except RuntimeError as e: |
|
|
|
|
|
raise AttributeError( |
|
f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" |
|
) from e |
|
|
|
op.__module__ = module_name |
|
opoverloadpacket = OpOverloadPacket( |
|
qualified_op_name, op_name, op, overload_names |
|
) |
|
opoverloadpacket.__module__ = self.__module__ + "." + namespace_name |
|
|
|
|
|
setattr(self, op_name, opoverloadpacket) |
|
self._dir.append(op_name) |
|
return opoverloadpacket |
|
|
|
|
|
def _get_packet(qualname, op_module): |
|
op, overload_names = torch._C._jit_get_operation(qualname) |
|
if op is not None: |
|
|
|
|
|
torch.jit._builtins._register_builtin(op, qualname) |
|
op.__module__ = op_module |
|
return op, overload_names |
|
|
|
|
|
def _refresh_packet(packet): |
|
op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__) |
|
assert op is not None |
|
packet._op = op |
|
packet._overload_names = overload_names |
|
|
|
|
|
class _PyOpNamespace(_OpNamespace): |
|
def __init__(self, name, ops): |
|
super().__init__(name) |
|
self._ops = ops |
|
|
|
def __getattr__(self, name): |
|
|
|
op = self._ops.get(name, None) |
|
if op is None: |
|
raise AttributeError( |
|
f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'" |
|
) |
|
setattr(self, name, op) |
|
return op |
|
|
|
|
|
class _Ops(types.ModuleType): |
|
__file__ = "_ops.py" |
|
|
|
def __init__(self): |
|
super().__init__("torch.ops") |
|
self.loaded_libraries = set() |
|
self._higher_order_op_namespace = _PyOpNamespace( |
|
"torch.ops.higher_order", _higher_order_ops |
|
) |
|
self._dir = [] |
|
|
|
def __getattr__(self, name): |
|
|
|
if name == "higher_order": |
|
return self._higher_order_op_namespace |
|
|
|
|
|
namespace = _OpNamespace(name) |
|
setattr(self, name, namespace) |
|
self._dir.append(name) |
|
return namespace |
|
|
|
def __iter__(self): |
|
return iter(self._dir) |
|
|
|
def import_module(self, module): |
|
""" |
|
Imports a Python module that has torch.library registrations. |
|
|
|
Generally, to extend PyTorch with custom operators, a user will |
|
create a Python module whose import triggers registration of |
|
the custom operators via a torch.ops.load_library call or a call |
|
to one or more torch.library.* APIs. |
|
|
|
It is unexpected for Python modules to have side effects, so some |
|
linters and formatters will complain. Use this API to import Python |
|
modules that contain these torch.library side effects. |
|
|
|
Args: |
|
module (str): The name of the Python module to import |
|
|
|
""" |
|
importlib.import_module(module) |
|
|
|
def load_library(self, path): |
|
""" |
|
Loads a shared library from the given path into the current process. |
|
|
|
The library being loaded may run global initialization code to register |
|
custom operators with the PyTorch JIT runtime. This allows dynamically |
|
loading custom operators. For this, you should compile your operator |
|
and the static registration code into a shared library object, and then |
|
call ``torch.ops.load_library('path/to/libcustom.so')`` to load the |
|
shared object. |
|
|
|
After the library is loaded, it is added to the |
|
``torch.ops.loaded_libraries`` attribute, a set that may be inspected |
|
for the paths of all libraries loaded using this function. |
|
|
|
Args: |
|
path (str): A path to a shared library to load. |
|
""" |
|
if torch._running_with_deploy(): |
|
return |
|
|
|
path = _utils_internal.resolve_library_path(path) |
|
with dl_open_guard(): |
|
|
|
|
|
|
|
ctypes.CDLL(path) |
|
self.loaded_libraries.add(path) |
|
|
|
|
|
|
|
ops: _Ops = _Ops() |
|
|