|
|
|
import contextlib |
|
import functools |
|
import inspect |
|
import re |
|
import sys |
|
import traceback |
|
import weakref |
|
from collections.abc import Sequence |
|
from typing import ( |
|
Any, |
|
Callable, |
|
Literal, |
|
Optional, |
|
overload, |
|
TYPE_CHECKING, |
|
TypeVar, |
|
Union, |
|
) |
|
from typing_extensions import deprecated, ParamSpec |
|
|
|
import torch |
|
import torch._library as _library |
|
from torch._library.custom_ops import ( |
|
_cast, |
|
_maybe_get_opdef, |
|
custom_op, |
|
CustomOpDef, |
|
device_types_t, |
|
) |
|
from torch._library.infer_schema import infer_schema |
|
from torch._library.triton import triton_op, wrap_triton |
|
from torch._ops import OpOverload |
|
from torch.types import _dtype |
|
|
|
|
|
__all__ = [ |
|
"Library", |
|
"impl", |
|
"define", |
|
"fallthrough_kernel", |
|
"impl_abstract", |
|
"register_autocast", |
|
"register_fake", |
|
"register_torch_dispatch", |
|
"register_vmap", |
|
"get_ctx", |
|
"custom_op", |
|
"triton_op", |
|
"wrap_triton", |
|
"infer_schema", |
|
] |
|
|
|
_T = TypeVar("_T") |
|
_P = ParamSpec("_P") |
|
|
|
|
|
|
|
|
|
|
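# Global registries of Python-side registrations: ``_impls`` holds
# "ns/opname/dispatch_key" keys for kernels registered from Python (used to
# detect duplicate registrations) and ``_defs`` holds the qualified names of
# operators defined from Python.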
|
_impls: set[str] = set() |
|
_defs: set[str] = set() |
|
|
|
|
|
_reserved_namespaces = ["prim"] |
|
|
|
|
|
def fallthrough_kernel(): |
|
""" |
|
A dummy function to pass to ``Library.impl`` in order to register a fallthrough. |
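
    Example (a sketch; the operator and dispatch key are illustrative)::
        >>> my_lib = Library("aten", "IMPL")
        >>> my_lib.impl("sin", fallthrough_kernel, "AutocastCPU")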
|
""" |
|
raise NotImplementedError("fallthrough_kernel() should never be called.") |
|
|
|
|
|
class Library: |
|
""" |
|
A class to create libraries that can be used to register new operators or |
|
override operators in existing libraries from Python. |
|
    A user can optionally pass in a dispatch key name if they only want to
    register kernels corresponding to one specific dispatch key.
|
|
|
To create a library to override operators in an existing library (with name ns), set the kind to "IMPL". |
|
To create a new library (with name ns) to register new operators, set the kind to "DEF". |
|
To create a fragment of a possibly existing library to register operators (and bypass |
|
the limitation that there is only one library for a given namespace), set the kind to |
|
"FRAGMENT". |
|
|
|
Args: |
|
ns: library name |
|
        kind: "DEF", "IMPL", or "FRAGMENT"
|
dispatch_key: PyTorch dispatch key (default: "") |
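
    Example::
        >>> my_lib = Library("aten", "IMPL")  # override kernels for existing aten operators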
|
""" |
|
|
|
def __init__(self, ns, kind, dispatch_key=""): |
|
if kind not in ("IMPL", "DEF", "FRAGMENT"): |
|
raise ValueError("Unsupported kind: ", kind) |
|
|
|
if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"): |
|
            raise ValueError(
                f"{ns} is a reserved namespace. "
                "Please try creating a library with another name."
            )
|
if torch._running_with_deploy(): |
|
_library.utils.warn_deploy() |
|
return |
|
|
|
frame = traceback.extract_stack(limit=3)[0] |
|
filename, lineno = frame.filename, frame.lineno |
|
self.m: Optional[Any] = torch._C._dispatch_library( |
|
kind, ns, dispatch_key, filename, lineno |
|
) |
|
self.ns = ns |
|
self._op_defs: set[str] = set() |
|
self._op_impls: set[str] = set() |
|
self._registration_handles: list[torch._library.utils.RegistrationHandle] = [] |
|
self.kind = kind |
|
self.dispatch_key = dispatch_key |
|
|
|
|
|
|
|
|
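        # When this Library object is garbage-collected, remove its
        # registrations from the global registries and destroy its
        # registration handles (see _del_library below).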
|
weakref.finalize( |
|
self, |
|
_del_library, |
|
_impls, |
|
self._op_impls, |
|
_defs, |
|
self._op_defs, |
|
self._registration_handles, |
|
) |
|
|
|
def __repr__(self): |
|
return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>" |
|
|
|
def define(self, schema, alias_analysis="", *, tags=()): |
|
r"""Defines a new operator and its semantics in the ns namespace. |
|
|
|
Args: |
|
schema: function schema to define a new operator. |
|
alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be |
|
inferred from the schema (default behavior) or not ("CONSERVATIVE"). |
|
tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this |
|
operator. Tagging an operator changes the operator's behavior |
|
under various PyTorch subsystems; please read the docs for the |
|
torch.Tag carefully before applying it. |
|
|
|
Returns: |
|
name of the operator as inferred from the schema. |
|
|
|
Example:: |
|
>>> my_lib = Library("mylib", "DEF") |
|
>>> my_lib.define("sum(Tensor self) -> Tensor") |
|
""" |
|
if torch._running_with_deploy(): |
|
_library.utils.warn_deploy() |
|
return |
|
|
|
|
|
|
|
if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]: |
|
raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}") |
|
assert self.m is not None |
|
if isinstance(tags, torch.Tag): |
|
tags = (tags,) |
|
|
|
name = schema.split("(")[0] |
|
packet_name = name.split(".")[0] if "." in name else name |
|
has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr( |
|
getattr(torch.ops, self.ns), packet_name |
|
) |
|
|
|
        result = self.m.define(schema, alias_analysis, tuple(tags))
        qualname = self.ns + "::" + name
|
|
|
|
|
|
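        # If an OpOverloadPacket for this name already existed on torch.ops,
        # refresh it so that the newly defined overload is visible.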
|
if has_preexisting_packet: |
|
ns = getattr(torch.ops, self.ns) |
|
packet = getattr(ns, packet_name) |
|
torch._ops._refresh_packet(packet) |
|
|
|
self._op_defs.add(qualname) |
|
_defs.add(qualname) |
|
return result |
|
|
|
def _register_fake(self, op_name, fn, _stacklevel=1): |
|
r"""Registers the fake impl for an operator defined in the library.""" |
|
if torch._running_with_deploy(): |
|
_library.utils.warn_deploy() |
|
return |
|
|
|
source = torch._library.utils.get_source(_stacklevel + 1) |
|
frame = sys._getframe(_stacklevel) |
|
caller_module = inspect.getmodule(frame) |
|
|
|
|
|
caller_module_name = None if caller_module is None else caller_module.__name__ |
|
|
|
|
|
|
|
if caller_module_name is not None and caller_module_name.startswith( |
|
"torchvision." |
|
): |
|
caller_module_name = None |
|
|
|
qualname = f"{self.ns}::{op_name}" |
|
entry = torch._library.simple_registry.singleton.find(qualname) |
|
if caller_module_name is not None: |
|
func_to_register = _check_pystubs_once(fn, qualname, caller_module_name) |
|
else: |
|
func_to_register = fn |
|
|
|
handle = entry.fake_impl.register(func_to_register, source) |
|
self._registration_handles.append(handle) |
|
|
|
def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn): |
|
r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class. |
|
|
|
This allows for open registration to specify the behavior between the operator |
|
and the torch_dispatch_class without needing to modify the torch_dispatch_class |
|
or the operator directly. |
|
|
|
The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a |
|
TorchDispatchMode. |
|
|
|
If it is a Tensor subclass, we expect fn to have the following signature: |
|
(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any |
|
|
|
If it is a TorchDispatchMode, we expect fn to have the following signature: |
|
(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any |
|
""" |
|
if torch._running_with_deploy(): |
|
_library.utils.warn_deploy() |
|
return |
|
|
|
qualname = f"{self.ns}::{op_name}" |
|
entry = torch._library.simple_registry.singleton.find(qualname) |
|
handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn) |
|
self._registration_handles.append(handle) |
|
|
|
def _impl_with_aoti_compile(self, op_name, dispatch_key=""): |
|
r"""Register the operator to use the AOTI-compiled implementation. |
|
|
|
Args: |
|
op_name: operator name (along with the overload) or OpOverload object. |
|
dispatch_key: dispatch key that the input function should be registered for. By default, it uses |
|
the dispatch key that the library was created with. |
|
|
|
Example:: |
|
>>> my_lib = Library("aten", "IMPL") |
|
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") |
|
""" |
|
if torch._running_with_deploy(): |
|
_library.utils.warn_deploy() |
|
return |
|
|
|
if dispatch_key == "": |
|
dispatch_key = self.dispatch_key |
|
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense) |
|
|
|
if isinstance(op_name, str): |
|
name = op_name |
|
elif isinstance(op_name, OpOverload): |
|
name = op_name._schema.name |
|
overload_name = op_name._schema.overload_name |
|
if overload_name != "": |
|
name = name + "." + overload_name |
|
else: |
|
raise RuntimeError( |
|
"_impl_with_aoti_compile should be passed either a name or an OpOverload object " |
|
"as the first argument" |
|
) |
|
|
|
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key |
|
if key in _impls: |
|
|
|
|
|
            raise RuntimeError(
                f"This is not allowed since there's already a kernel registered from Python "
                f"overriding {name.split('::')[-1]}'s behavior for the {dispatch_key} "
                f"dispatch key and the {self.ns} namespace."
            )
|
|
|
assert self.m is not None |
|
impl_fn: Callable = self.m.impl_with_aoti_compile |
|
impl_fn(self.ns, name.split("::")[-1], dispatch_key) |
|
|
|
_impls.add(key) |
|
self._op_impls.add(key) |
|
|
|
def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False): |
|
r"""Registers the function implementation for an operator defined in the library. |
|
|
|
Args: |
|
op_name: operator name (along with the overload) or OpOverload object. |
|
fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel` |
|
to register a fallthrough. |
|
dispatch_key: dispatch key that the input function should be registered for. By default, it uses |
|
the dispatch key that the library was created with. |
|
with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument |
|
to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. |
|
|
|
Example:: |
|
>>> my_lib = Library("aten", "IMPL") |
|
>>> def div_cpu(self, other): |
|
>>> return self * (1 / other) |
|
>>> my_lib.impl("div.Tensor", div_cpu, "CPU") |
|
""" |
|
if torch._running_with_deploy(): |
|
_library.utils.warn_deploy() |
|
return |
|
|
|
if not callable(fn): |
|
raise TypeError( |
|
f"Input function is required to be a callable but found type {type(fn)}" |
|
) |
|
if dispatch_key == "": |
|
dispatch_key = self.dispatch_key |
|
|
|
if isinstance(op_name, str): |
|
name = op_name |
|
elif isinstance(op_name, OpOverload): |
|
name = op_name._schema.name |
|
overload_name = op_name._schema.overload_name |
|
if overload_name != "": |
|
name = name + "." + overload_name |
|
else: |
|
raise RuntimeError( |
|
"impl should be passed either a name or an OpOverload object as the first argument" |
|
) |
|
|
|
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key |
|
if key in _impls: |
|
|
|
|
|
            raise RuntimeError(
                f"This is not allowed since there's already a kernel registered from Python "
                f"overriding {name.split('::')[-1]}'s behavior for the {dispatch_key} "
                f"dispatch key and the {self.ns} namespace."
            )
|
|
|
if dispatch_key == "Meta": |
|
dispatcher_op_name = name |
|
if "::" not in dispatcher_op_name: |
|
dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}" |
|
|
|
|
|
|
|
|
|
|
|
if torch._C._dispatch_has_kernel_for_dispatch_key( |
|
dispatcher_op_name, "CompositeImplicitAutograd" |
|
): |
|
raise RuntimeError( |
|
f"We should not register a meta kernel directly to the operator '{name}'," |
|
" because it has a CompositeImplicitAutograd kernel in core." |
|
" Instead we should let the operator decompose, and ensure that we have meta kernels" |
|
" for the base ops that it decomposes into." |
|
) |
|
|
|
assert self.m is not None |
|
self.m.impl( |
|
name, |
|
dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", |
|
fn, |
|
with_keyset, |
|
) |
|
|
|
_impls.add(key) |
|
self._op_impls.add(key) |
|
|
|
def fallback(self, fn, dispatch_key="", *, with_keyset=False): |
|
r"""Registers the function implementation as the fallback for the given key. |
|
|
|
This function only works for a library with global namespace ("_"). |
|
|
|
Args: |
|
fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel` |
|
to register a fallthrough. |
|
dispatch_key: dispatch key that the input function should be registered for. By default, it uses |
|
the dispatch key that the library was created with. |
|
with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument |
|
to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. |
|
|
|
Example:: |
|
>>> my_lib = Library("_", "IMPL") |
|
>>> def fallback_kernel(op, *args, **kwargs): |
|
>>> # Handle all autocast ops generically |
|
>>> # ... |
|
>>> my_lib.fallback(fallback_kernel, "Autocast") |
|
""" |
|
if torch._running_with_deploy(): |
|
_library.utils.warn_deploy() |
|
return |
|
|
|
if dispatch_key == "": |
|
dispatch_key = self.dispatch_key |
|
|
|
if self.ns != "_": |
|
raise RuntimeError( |
|
f"""Fallback can only be registered using libary fragment on the global namespace "_" but it is {self.ns}""" |
|
) |
|
|
|
assert dispatch_key != "" |
|
assert self.m is not None |
|
|
|
self.m.fallback(dispatch_key, fn, with_keyset) |
|
|
|
def _destroy(self): |
|
if self.m is not None: |
|
self.m.reset() |
|
self.m = None |
|
for handle in self._registration_handles: |
|
handle.destroy() |
|
self._registration_handles.clear() |
|
global _impls |
|
_impls -= self._op_impls |
|
for name in self._op_defs: |
|
|
|
|
|
|
|
|
|
|
|
|
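            # Remove the cached OpOverloadPacket from torch.ops so a later
            # redefinition of this op is resolved afresh.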
|
ns, name_with_overload = name.split("::") |
|
name = name_with_overload.split(".")[0] |
|
if not hasattr(torch.ops, ns): |
|
continue |
|
namespace = getattr(torch.ops, ns) |
|
if not hasattr(namespace, name): |
|
continue |
|
delattr(namespace, name) |
|
namespace._dir.remove(name) |
|
|
|
|
|
def _del_library( |
|
captured_impls, |
|
op_impls, |
|
captured_defs, |
|
op_defs, |
|
registration_handles, |
|
): |
|
captured_impls -= op_impls |
|
captured_defs -= op_defs |
|
for handle in registration_handles: |
|
handle.destroy() |
|
|
|
|
|
@contextlib.contextmanager |
|
def _scoped_library(*args, **kwargs): |
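    """Create a ``Library`` that is destroyed when the ``with`` block exits.

    A minimal sketch (the namespace "mylib" is illustrative)::

        with _scoped_library("mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")
    """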
|
try: |
|
lib = Library(*args, **kwargs) |
|
yield lib |
|
finally: |
|
lib._destroy() |
|
|
|
|
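# Libraries created on the user's behalf (e.g. by define/impl when lib=None)
# are kept alive here for the lifetime of the process.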
|
_keep_alive: list[Library] = [] |
|
|
|
|
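# Matches a schema whose operator name has been omitted, e.g. "(Tensor x) -> Tensor".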
|
NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*") |
|
|
|
|
|
@functools.singledispatch |
|
def define(qualname, schema, *, lib=None, tags=()): |
|
r"""Defines a new operator. |
|
|
|
    In PyTorch, defining an op (short for "operator") is a two-step process:
|
- we need to define the op (by providing an operator name and schema) |
|
- we need to implement behavior for how the operator interacts with |
|
various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. |
|
|
|
    This entrypoint defines the custom operator (the first step);
|
you must then perform the second step by calling various |
|
``impl_*`` APIs, like :func:`torch.library.impl` or |
|
:func:`torch.library.register_fake`. |
|
|
|
Args: |
|
qualname (str): The qualified name for the operator. Should be |
|
a string that looks like "namespace::name", e.g. "aten::sin". |
|
Operators in PyTorch need a namespace to |
|
avoid name collisions; a given operator may only be created once. |
|
If you are writing a Python library, we recommend the namespace to |
|
be the name of your top-level module. |
|
schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor" |
|
for an op that accepts one Tensor and returns one Tensor. It does |
|
not contain the operator name (that is passed in ``qualname``). |
|
lib (Optional[Library]): If provided, the lifetime of this operator |
|
will be tied to the lifetime of the Library object. |
|
tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this |
|
operator. Tagging an operator changes the operator's behavior |
|
under various PyTorch subsystems; please read the docs for the |
|
torch.Tag carefully before applying it. |
|
|
|
Example:: |
|
>>> import torch |
|
>>> import numpy as np |
|
>>> |
|
>>> # Define the operator |
|
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") |
|
>>> |
|
>>> # Add implementations for the operator |
|
>>> @torch.library.impl("mylib::sin", "cpu") |
|
>>> def f(x): |
|
>>> return torch.from_numpy(np.sin(x.numpy())) |
|
>>> |
|
>>> # Call the new operator from torch.ops. |
|
>>> x = torch.randn(3) |
|
>>> y = torch.ops.mylib.sin(x) |
|
>>> assert torch.allclose(y, x.sin()) |
|
|
|
""" |
|
if not isinstance(qualname, str): |
|
raise ValueError( |
|
f"define(qualname, schema): expected qualname " |
|
f"to be instance of str, got {type(qualname)}" |
|
) |
|
namespace, name = torch._library.utils.parse_namespace(qualname) |
|
if lib is None: |
|
lib = Library(namespace, "FRAGMENT") |
|
_keep_alive.append(lib) |
|
if not NAMELESS_SCHEMA.fullmatch(schema): |
|
raise ValueError( |
|
f"define(qualname, schema, ...): expected schema " |
|
f'to look like e.g. "(Tensor x) -> Tensor" but ' |
|
f'got "{schema}"' |
|
) |
|
lib.define(name + schema, alias_analysis="", tags=tags) |
|
|
|
|
|
@define.register |
|
def _(lib: Library, schema, alias_analysis=""): |
|
"""The old torch.library.define. |
|
    We're keeping this around for BC reasons.
|
""" |
|
|
|
def wrap(f): |
|
name = lib.define(schema, alias_analysis) |
|
lib.impl(name, f) |
|
return f |
|
|
|
return wrap |
|
|
|
|
|
@overload |
|
def impl( |
|
qualname: str, |
|
types: Union[str, Sequence[str]], |
|
func: Literal[None] = None, |
|
*, |
|
lib: Optional[Library] = None, |
|
) -> Callable[[Callable[..., object]], None]: ... |
|
|
|
|
|
@overload |
|
def impl( |
|
qualname: str, |
|
types: Union[str, Sequence[str]], |
|
func: Callable[..., object], |
|
*, |
|
lib: Optional[Library] = None, |
|
) -> None: ... |
|
|
|
|
|
|
|
@overload |
|
def impl( |
|
lib: Library, |
|
name: str, |
|
dispatch_key: str = "", |
|
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... |
|
|
|
|
|
@functools.singledispatch |
|
def impl( |
|
qualname: str, |
|
types: Union[str, Sequence[str]], |
|
func: Optional[Callable[_P, _T]] = None, |
|
*, |
|
lib: Optional[Library] = None, |
|
) -> object: |
|
"""Register an implementation for a device type for this operator. |
|
|
|
You may pass "default" for ``types`` to register this implementation as the |
|
default implementation for ALL device types. |
|
Please only use this if the implementation truly supports all device types; |
|
for example, this is true if it is a composition of built-in PyTorch operators. |
|
|
|
This API may be used as a decorator. You can use nested decorators |
|
with this API provided they return a function and are placed inside |
|
this API (see Example 2). |
|
|
|
Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". |
|
|
|
Args: |
|
qualname (str): Should be a string that looks like "namespace::operator_name". |
|
types (str | Sequence[str]): The device types to register an impl to. |
|
lib (Optional[Library]): If provided, the lifetime of this registration |
|
will be tied to the lifetime of the Library object. |
|
|
|
Examples: |
|
>>> import torch |
|
>>> import numpy as np |
|
>>> # Example 1: Register function. |
|
>>> # Define the operator |
|
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor") |
|
>>> |
|
>>> # Add implementations for the cpu device |
|
>>> @torch.library.impl("mylib::mysin", "cpu") |
|
>>> def f(x): |
|
>>> return torch.from_numpy(np.sin(x.numpy())) |
|
>>> |
|
>>> x = torch.randn(3) |
|
>>> y = torch.ops.mylib.mysin(x) |
|
>>> assert torch.allclose(y, x.sin()) |
|
>>> |
|
>>> # Example 2: Register function with decorator. |
|
>>> def custom_decorator(func): |
|
>>> def wrapper(*args, **kwargs): |
|
>>> return func(*args, **kwargs) + 1 |
|
>>> return wrapper |
|
>>> |
|
>>> # Define the operator |
|
>>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor") |
|
>>> |
|
>>> # Add implementations for the operator |
|
>>> @torch.library.impl("mylib::sin_plus_one", "cpu") |
|
>>> @custom_decorator |
|
>>> def f(x): |
|
>>> return torch.from_numpy(np.sin(x.numpy())) |
|
>>> |
|
>>> # Call the new operator from torch.ops. |
|
>>> x = torch.randn(3) |
|
>>> |
|
>>> y1 = torch.ops.mylib.sin_plus_one(x) |
|
>>> y2 = torch.sin(x) + 1 |
|
>>> assert torch.allclose(y1, y2) |
|
""" |
|
return _impl(qualname, types, func, lib=lib, disable_dynamo=False) |
|
|
|
|
|
if not TYPE_CHECKING: |
|
|
|
@impl.register |
|
def _( |
|
lib: Library, name: str, dispatch_key: str = "" |
|
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: |
|
"""Legacy torch.library.impl API. Kept around for BC""" |
|
|
|
def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]: |
|
lib.impl(name, f, dispatch_key) |
|
return f |
|
|
|
return wrap |
|
|
|
|
|
@overload |
|
def _impl( |
|
qualname: str, |
|
types: Union[str, Sequence[str]], |
|
func: Literal[None] = None, |
|
*, |
|
lib: Optional[Library] = None, |
|
disable_dynamo: bool = False, |
|
) -> Callable[[Callable[..., object]], None]: ... |
|
|
|
|
|
@overload |
|
def _impl( |
|
qualname: str, |
|
types: Union[str, Sequence[str]], |
|
func: Callable[..., object], |
|
*, |
|
lib: Optional[Library] = None, |
|
disable_dynamo: bool = False, |
|
) -> None: ... |
|
|
|
|
|
def _impl( |
|
qualname: str, |
|
types: Union[str, Sequence[str]], |
|
func: Optional[Callable[..., object]] = None, |
|
*, |
|
lib: Optional[Library] = None, |
|
disable_dynamo: bool = False, |
|
) -> Optional[Callable[[Callable[..., object]], None]]: |
|
|
|
if isinstance(types, str): |
|
types = (types,) |
|
    keys: set[str] = set()
|
for typ in types: |
|
is_dispatch_key = torch._C._parse_dispatch_key(typ) |
|
if is_dispatch_key: |
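            # ``types`` may also contain raw dispatch key names (e.g.
            # "CompositeExplicitAutograd"); pass those through unchanged.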
|
|
|
|
|
|
|
|
|
|
|
|
|
keys.add(typ) |
|
else: |
|
keys.add(_device_type_to_key(typ)) |
|
|
|
def register_(func: Callable[..., object]) -> None: |
|
namespace, _ = torch._library.utils.parse_namespace(qualname) |
|
|
|
if lib is None: |
|
use_lib = Library(namespace, "FRAGMENT") |
|
_keep_alive.append(use_lib) |
|
else: |
|
use_lib = lib |
|
if disable_dynamo: |
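            # Wrap the kernel so dynamo does not trace into it; this is the
            # path taken by torch.library.register_kernel.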
|
|
|
@torch._disable_dynamo |
|
def func_no_dynamo(*args, **kwargs): |
|
return func(*args, **kwargs) |
|
|
|
for key in keys: |
|
use_lib.impl(qualname, func_no_dynamo, key) |
|
else: |
|
for key in keys: |
|
use_lib.impl(qualname, func, key) |
|
|
|
if func is None: |
|
return register_ |
|
else: |
|
register_(func) |
|
return None |
|
|
|
|
|
def _device_type_to_key(device_type: str) -> str: |
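    """Map a device type string to the name of a dispatch key, e.g. "cuda" to
    "CUDA" (via torch._C._dispatch_key_for_device)."""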
|
if device_type == "default": |
|
|
|
|
|
|
|
|
|
return "CompositeExplicitAutograd" |
|
return torch._C._dispatch_key_for_device(device_type) |
|
|
|
|
|
@deprecated( |
|
"`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that " |
|
"instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.", |
|
category=FutureWarning, |
|
) |
|
def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): |
|
r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4. |
|
Please use that instead. |
|
""" |
|
if func is not None: |
|
_stacklevel = _stacklevel + 1 |
|
return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel) |
|
|
|
|
|
_op_identifier = Union[ |
|
str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef" |
|
] |
|
|
|
|
|
def register_kernel( |
|
op: _op_identifier, |
|
device_types: device_types_t, |
|
func: Optional[Callable] = None, |
|
/, |
|
*, |
|
lib: Optional[Library] = None, |
|
): |
|
"""Register an implementation for a device type for this operator. |
|
|
|
Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". |
|
This API may be used as a decorator. |
|
|
|
Args: |
|
op (str | OpOverload): The operator to register an impl to. |
|
device_types (None | str | Sequence[str]): The device_types to register an impl to. |
|
If None, we will register to all device types -- please only use |
|
this option if your implementation is truly device-type-agnostic. |
|
func (Callable): The function to register as the implementation for |
|
the given device types. |
|
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.
|
|
|
Examples:: |
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
|
>>> import torch |
|
>>> from torch import Tensor |
|
>>> from torch.library import custom_op |
|
>>> import numpy as np |
|
>>> |
|
>>> # Create a custom op that works on cpu |
|
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") |
|
>>> def numpy_sin(x: Tensor) -> Tensor: |
|
>>> x_np = x.numpy() |
|
>>> y_np = np.sin(x_np) |
|
>>> return torch.from_numpy(y_np) |
|
>>> |
|
>>> # Add implementations for the cuda device |
|
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda") |
|
>>> def _(x): |
|
>>> x_np = x.cpu().numpy() |
|
>>> y_np = np.sin(x_np) |
|
>>> return torch.from_numpy(y_np).to(device=x.device) |
|
>>> |
|
>>> x_cpu = torch.randn(3) |
|
>>> x_cuda = x_cpu.cuda() |
|
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) |
|
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) |
|
|
|
""" |
|
|
|
if not isinstance( |
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
|
): |
|
raise ValueError( |
|
f"register_kernel({op}): got unexpected type for op: {type(op)}" |
|
) |
|
if isinstance(op, torch._ops.OpOverload): |
|
op = op._name |
|
opdef = _maybe_get_opdef(op) |
|
if opdef is not None: |
|
return opdef.register_kernel(device_types, func) |
|
assert isinstance(op, str) |
|
if device_types is None: |
|
device_types = "CompositeExplicitAutograd" |
|
|
|
return _impl(op, device_types, func, lib=lib, disable_dynamo=True) |
|
|
|
|
|
def register_autocast( |
|
op: _op_identifier, |
|
device_type: str, |
|
cast_inputs: _dtype, |
|
/, |
|
*, |
|
lib: Optional[Library] = None, |
|
): |
|
r"""Register an autocast dispatch rule for this custom op. |
|
|
|
    Valid values for ``device_type`` are "cpu" and "cuda".
|
|
|
Args: |
|
op (str | OpOverload): The operator to register an autocast dispatch rule to. |
|
        device_type (str): Device type to use. 'cuda' or 'cpu'.
|
The type is the same as the `type` attribute of a :class:`torch.device`. |
|
Thus, you may obtain the device type of a tensor using `Tensor.device.type`. |
|
        cast_inputs (:class:`torch.dtype`): When the custom op runs in an autocast-enabled region,
            casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
            are not affected), then executes the custom op with autocast disabled.
|
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.
|
|
|
Examples:: |
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
|
>>> import torch |
|
>>> from torch import Tensor |
|
>>> from torch.library import custom_op |
|
>>> |
|
>>> # Create a custom op that works on cuda |
|
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=()) |
|
>>> def my_sin(x: Tensor) -> Tensor: |
|
>>> return torch.sin(x) |
|
>>> |
|
>>> # Register autocast dispatch rule for the cuda device |
|
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16) |
|
>>> |
|
>>> x = torch.randn(3, dtype=torch.float32, device="cuda") |
|
>>> with torch.autocast("cuda", dtype=torch.float16): |
|
>>> y = torch.ops.mylib.my_sin(x) |
|
>>> assert y.dtype == torch.float16 |
|
|
|
""" |
|
if not isinstance( |
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
|
): |
|
raise ValueError( |
|
f"register_autocast({op}): got unexpected type for op: {type(op)}" |
|
) |
|
if device_type not in ["cpu", "cuda"]: |
|
raise ValueError(f"Unknown device type: {device_type}") |
|
|
|
if isinstance(op, torch._ops.OpOverload): |
|
op = op._name |
|
opdef = _maybe_get_opdef(op) |
|
if opdef is not None: |
|
return opdef.register_autocast(device_type, cast_inputs) |
|
|
|
assert isinstance(op, str) |
|
qualname = op |
|
_op = torch._library.utils.lookup_op(qualname) |
|
|
|
namespace, opname = torch._library.utils.parse_namespace(qualname) |
|
if lib is None: |
|
lib = Library(namespace, "FRAGMENT") |
|
_keep_alive.append(lib) |
|
|
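    # Under autocast, cast eligible inputs to ``cast_inputs``, temporarily
    # exclude the autocast keys, and re-dispatch to the op.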
|
def kernel(_, *args, **kwargs): |
|
assert len(kwargs) == 0, "Custom ops do not support kwargs yet." |
|
autocast_keyset = torch._C.DispatchKeySet( |
|
torch._C.DispatchKey.AutocastCPU |
|
) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA) |
|
with torch._C._ExcludeDispatchKeyGuard(autocast_keyset): |
|
return _op(*_cast(args, device_type, cast_inputs)) |
|
|
|
if device_type == "cuda": |
|
return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True) |
|
else: |
|
|
|
return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True) |
|
|
|
|
|
def register_fake( |
|
op: _op_identifier, |
|
func: Optional[Callable] = None, |
|
/, |
|
*, |
|
lib: Optional[Library] = None, |
|
_stacklevel: int = 1, |
|
): |
|
r"""Register a FakeTensor implementation ("fake impl") for this operator. |
|
|
|
    Also sometimes known as a "meta kernel" or "abstract impl".
|
|
|
An "FakeTensor implementation" specifies the behavior of this operator on |
|
Tensors that carry no data ("FakeTensor"). Given some input Tensors with |
|
certain properties (sizes/strides/storage_offset/device), it specifies |
|
what the properties of the output Tensors are. |
|
|
|
The FakeTensor implementation has the same signature as the operator. |
|
It is run for both FakeTensors and meta tensors. To write a FakeTensor |
|
implementation, assume that all Tensor inputs to the operator are |
|
regular CPU/CUDA/Meta tensors, but they do not have storage, and |
|
you are trying to return regular CPU/CUDA/Meta tensor(s) as output. |
|
The FakeTensor implementation must consist of only PyTorch operations |
|
(and may not directly access the storage or data of any input or |
|
intermediate Tensors). |
|
|
|
This API may be used as a decorator (see examples). |
|
|
|
For a detailed guide on custom ops, please see |
|
https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html |
|
|
|
Examples: |
|
>>> import torch |
|
>>> import numpy as np |
|
>>> from torch import Tensor |
|
>>> |
|
>>> # Example 1: an operator without data-dependent output shape |
|
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=()) |
|
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: |
|
>>> raise NotImplementedError("Implementation goes here") |
|
>>> |
|
>>> @torch.library.register_fake("mylib::custom_linear") |
|
>>> def _(x, weight, bias): |
|
>>> assert x.dim() == 2 |
|
>>> assert weight.dim() == 2 |
|
>>> assert bias.dim() == 1 |
|
>>> assert x.shape[1] == weight.shape[1] |
|
>>> assert weight.shape[0] == bias.shape[0] |
|
>>> assert x.device == weight.device |
|
>>> |
|
>>> return (x @ weight.t()) + bias |
|
>>> |
|
>>> with torch._subclasses.fake_tensor.FakeTensorMode(): |
|
>>> x = torch.randn(2, 3) |
|
>>> w = torch.randn(3, 3) |
|
>>> b = torch.randn(3) |
|
>>> y = torch.ops.mylib.custom_linear(x, w, b) |
|
>>> |
|
>>> assert y.shape == (2, 3) |
|
>>> |
|
>>> # Example 2: an operator with data-dependent output shape |
|
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=()) |
|
>>> def custom_nonzero(x: Tensor) -> Tensor: |
|
>>> x_np = x.numpy(force=True) |
|
>>> res = np.stack(np.nonzero(x_np), axis=1) |
|
>>> return torch.tensor(res, device=x.device) |
|
>>> |
|
>>> @torch.library.register_fake("mylib::custom_nonzero") |
|
>>> def _(x): |
|
>>> # Number of nonzero-elements is data-dependent. |
|
        >>>     # Since we cannot peek at the data in a fake impl,
|
>>> # we use the ctx object to construct a new symint that |
|
>>> # represents the data-dependent size. |
|
>>> ctx = torch.library.get_ctx() |
|
>>> nnz = ctx.new_dynamic_size() |
|
>>> shape = [nnz, x.dim()] |
|
>>> result = x.new_empty(shape, dtype=torch.int64) |
|
>>> return result |
|
>>> |
|
>>> from torch.fx.experimental.proxy_tensor import make_fx |
|
>>> |
|
>>> x = torch.tensor([0, 1, 2, 3, 4, 0]) |
|
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x) |
|
>>> trace.print_readable() |
|
>>> |
|
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x)) |
|
|
|
""" |
|
if not isinstance( |
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
|
): |
|
raise ValueError(f"register_fake({op}): got unexpected type for op: {type(op)}") |
|
if isinstance(op, torch._ops.OpOverload): |
|
op = op._name |
|
opdef = _maybe_get_opdef(op) |
|
if opdef is not None: |
|
if func is None: |
|
return opdef.register_fake |
|
else: |
|
return opdef.register_fake(func) |
|
assert isinstance(op, str) |
|
|
|
stacklevel = _stacklevel |
|
|
|
def register(func): |
|
namespace, op_name = torch._library.utils.parse_namespace(op) |
|
if lib is None: |
|
use_lib = Library(namespace, "FRAGMENT") |
|
_keep_alive.append(use_lib) |
|
else: |
|
use_lib = lib |
|
use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1) |
|
return func |
|
|
|
if func is None: |
|
return register |
|
else: |
|
stacklevel += 1 |
|
return register(func) |
|
|
|
|
|
def register_autograd( |
|
op: _op_identifier, |
|
backward: Callable, |
|
/, |
|
*, |
|
setup_context: Optional[Callable] = None, |
|
lib=None, |
|
) -> None: |
|
r"""Register a backward formula for this custom op. |
|
|
|
In order for an operator to work with autograd, you need to register |
|
a backward formula: |
|
1. You must tell us how to compute gradients during the backward pass |
|
by providing us a "backward" function. |
|
2. If you need any values from the forward to compute gradients, you can |
|
use `setup_context` to save values for backward. |
|
|
|
``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``: |
|
- ``grads`` is one or more gradients. The number of gradients matches |
|
the number of outputs of the operator. |
|
The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by |
|
    :class:`torch.autograd.Function`. The semantics of ``backward`` are the
    same as :meth:`torch.autograd.Function.backward`.
|
|
|
``setup_context(ctx, inputs, output)`` runs during the forward pass. |
|
Please save quantities needed for backward onto the ``ctx`` object via |
|
either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` |
|
or assigning them as attributes of ``ctx``. If your custom op has |
|
kwarg-only arguments, we expect the signature of ``setup_context`` |
|
to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``. |
|
|
|
    Both ``setup_context`` and ``backward`` must be traceable. That is,
|
they may not directly access :meth:`torch.Tensor.data_ptr` and they must |
|
not depend on or mutate global state. If you need a non-traceable backward, |
|
you can make it a separate custom_op that you call inside ``backward_fn``. |
|
|
|
If you need different autograd behavior on different devices, then we |
|
recommend creating two different custom operators, one for each device |
|
that needs different behavior, and switching between them at runtime. |
|
|
|
Examples: |
|
>>> import torch |
|
>>> import numpy as np |
|
>>> from torch import Tensor |
|
>>> |
|
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) |
|
>>> def numpy_sin(x: Tensor) -> Tensor: |
|
>>> x_np = x.cpu().numpy() |
|
>>> y_np = np.sin(x_np) |
|
>>> return torch.from_numpy(y_np).to(device=x.device) |
|
>>> |
|
>>> def setup_context(ctx, inputs, output) -> Tensor: |
|
>>> x, = inputs |
|
>>> ctx.save_for_backward(x) |
|
>>> |
|
>>> def backward(ctx, grad): |
|
>>> x, = ctx.saved_tensors |
|
>>> return grad * x.cos() |
|
>>> |
|
>>> torch.library.register_autograd( |
|
... "mylib::numpy_sin", backward, setup_context=setup_context |
|
... ) |
|
>>> |
|
>>> x = torch.randn(3, requires_grad=True) |
|
>>> y = numpy_sin(x) |
|
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) |
|
>>> assert torch.allclose(grad_x, x.cos()) |
|
>>> |
|
>>> # Example with a keyword-only arg |
|
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) |
|
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: |
|
>>> x_np = x.cpu().numpy() |
|
>>> y_np = x_np * val |
|
>>> return torch.from_numpy(y_np).to(device=x.device) |
|
>>> |
|
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: |
|
>>> ctx.val = keyword_only_inputs["val"] |
|
>>> |
|
>>> def backward(ctx, grad): |
|
>>> return grad * ctx.val |
|
>>> |
|
>>> torch.library.register_autograd( |
|
... "mylib::numpy_mul", backward, setup_context=setup_context |
|
... ) |
|
>>> |
|
>>> x = torch.randn(3, requires_grad=True) |
|
>>> y = numpy_mul(x, val=3.14) |
|
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) |
|
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) |
|
|
|
""" |
|
if not isinstance( |
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
|
): |
|
raise ValueError( |
|
f"register_autograd({op}): got unexpected type for op: {type(op)}" |
|
) |
|
if isinstance(op, torch._ops.OpOverload): |
|
op = op._name |
|
opdef = _maybe_get_opdef(op) |
|
if opdef is not None: |
|
opdef.register_autograd(backward, setup_context=setup_context) |
|
return |
|
|
|
assert isinstance(op, str) |
|
qualname = op |
|
op = torch._library.utils.lookup_op(qualname) |
|
schema = op._schema |
|
if not _library.utils.is_functional_schema(schema): |
|
raise RuntimeError( |
|
f"Cannot register autograd formula for non-functional operator " |
|
f"{op} with schema {schema}. Please create " |
|
f"a functional operator and register an autograd formula for that." |
|
) |
|
if _library.utils.has_kwarg_only_tensors(schema): |
|
raise NotImplementedError( |
|
f"register_autograd with kwarg-only Tensor args. In the original " |
|
f"definition of the op, please make your tensors not kwarg-only. " |
|
f"Got: {schema}" |
|
) |
|
|
|
info = _library.autograd.Info(backward, setup_context) |
|
autograd_kernel = _library.autograd.make_autograd_impl(op, info) |
|
namespace, opname = torch._library.utils.parse_namespace(qualname) |
|
if lib is None: |
|
lib = Library(namespace, "FRAGMENT") |
|
_keep_alive.append(lib) |
|
lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True) |
|
|
|
|
|
def register_torch_dispatch( |
|
op: _op_identifier, |
|
torch_dispatch_class: Any, |
|
func: Optional[Callable] = None, |
|
/, |
|
*, |
|
lib: Optional[Library] = None, |
|
): |
|
r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. |
|
|
|
This allows for open registration to specify the behavior between the operator |
|
and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` |
|
or the operator directly. |
|
|
|
The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a |
|
TorchDispatchMode. |
|
|
|
If it is a Tensor subclass, we expect ``func`` to have the following signature: |
|
``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` |
|
|
|
If it is a TorchDispatchMode, we expect ``func`` to have the following signature: |
|
``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` |
|
|
|
``args`` and ``kwargs`` will have been normalized the same way they are |
|
in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`). |
|
|
|
Examples: |
|
|
|
>>> import torch |
|
>>> |
|
>>> @torch.library.custom_op("mylib::foo", mutates_args={}) |
|
>>> def foo(x: torch.Tensor) -> torch.Tensor: |
|
>>> return x.clone() |
|
>>> |
|
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode): |
|
>>> def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
|
>>> return func(*args, **kwargs) |
|
>>> |
|
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode) |
|
>>> def _(mode, func, types, args, kwargs): |
|
>>> x, = args |
|
>>> return x + 1 |
|
>>> |
|
>>> x = torch.randn(3) |
|
>>> y = foo(x) |
|
>>> assert torch.allclose(y, x) |
|
>>> |
|
>>> with MyMode(): |
|
>>> y = foo(x) |
|
>>> assert torch.allclose(y, x + 1) |
|
|
|
""" |
|
if not isinstance( |
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
|
): |
|
raise ValueError( |
|
f"register_torch_dispatch({op}): got unexpected type for op: {type(op)}" |
|
) |
|
if isinstance(op, torch._ops.OpOverload): |
|
op = op._name |
|
opdef = _maybe_get_opdef(op) |
|
if opdef is not None: |
|
return opdef.register_torch_dispatch(torch_dispatch_class, func) |
|
assert isinstance(op, str) |
|
|
|
def register(func): |
|
namespace, op_name = torch._library.utils.parse_namespace(op) |
|
if lib is None: |
|
use_lib = Library(namespace, "FRAGMENT") |
|
_keep_alive.append(use_lib) |
|
else: |
|
use_lib = lib |
|
use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func) |
|
return func |
|
|
|
if func is None: |
|
return register |
|
else: |
|
return register(func) |
|
|
|
|
|
def register_vmap( |
|
op: _op_identifier, |
|
func: Optional[Callable] = None, |
|
/, |
|
*, |
|
lib=None, |
|
): |
|
r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. |
|
|
|
This API may be used as a decorator (see examples). |
|
|
|
In order for an operator to work with :func:`torch.vmap`, you may need to register a |
|
    vmap implementation with the following signature:
|
|
|
``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, |
|
|
|
where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. |
|
We do not support kwarg-only Tensor args. |
|
|
|
    It specifies how we compute the batched version of ``op`` given inputs with an additional
|
dimension (specified by ``in_dims``). |
|
|
|
For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` |
|
if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer |
|
specifying what dimension of the Tensor is being vmapped over. |
|
|
|
``info`` is a collection of additional metadata that may be helpful: |
|
``info.batch_size`` specifies the size of the dimension being vmapped over, while |
|
``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. |
|
|
|
    The function ``func`` returns a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
|
``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` |
|
per output that specifies if the output has the vmapped dimension and what index it is in. |
|
|
|
Examples: |
|
>>> import torch |
|
>>> import numpy as np |
|
>>> from torch import Tensor |
|
>>> from typing import Tuple |
|
>>> |
|
>>> def to_numpy(tensor): |
|
>>> return tensor.cpu().numpy() |
|
>>> |
|
>>> lib = torch.library.Library("mylib", "FRAGMENT") |
|
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) |
|
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: |
|
>>> x_np = to_numpy(x) |
|
>>> dx = torch.tensor(3 * x_np ** 2, device=x.device) |
|
>>> return torch.tensor(x_np ** 3, device=x.device), dx |
|
>>> |
|
>>> def numpy_cube_vmap(info, in_dims, x): |
|
>>> result = numpy_cube(x) |
|
>>> return result, (in_dims[0], in_dims[0]) |
|
>>> |
|
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap) |
|
>>> |
|
>>> x = torch.randn(3) |
|
>>> torch.vmap(numpy_cube)(x) |
|
>>> |
|
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) |
|
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: |
|
>>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) |
|
>>> |
|
>>> @torch.library.register_vmap("mylib::numpy_mul") |
|
>>> def numpy_mul_vmap(info, in_dims, x, y): |
|
>>> x_bdim, y_bdim = in_dims |
|
>>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
|
>>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
|
>>> result = x * y |
|
>>> result = result.movedim(-1, 0) |
|
>>> return result, 0 |
|
>>> |
|
>>> |
|
>>> x = torch.randn(3) |
|
>>> y = torch.randn(3) |
|
>>> torch.vmap(numpy_mul)(x, y) |
|
|
|
.. note:: |
|
The vmap function should aim to preserve the semantics of the entire custom operator. |
|
That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``. |
|
|
|
If your custom operator has any custom behavior in the backward pass, please |
|
keep this in mind. |
|
|
|
""" |
|
if not isinstance( |
|
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
|
): |
|
raise ValueError(f"register_vmap({op}): got unexpected type for op: {type(op)}") |
|
if isinstance(op, torch._ops.OpOverload): |
|
op = op._name |
|
opdef = _maybe_get_opdef(op) |
|
if opdef is not None: |
|
return opdef.register_vmap(func) |
|
assert isinstance(op, str) |
|
qualname = op |
|
op = torch._library.utils.lookup_op(qualname) |
|
schema = op._schema |
|
if _library.utils.has_kwarg_only_tensors(schema): |
|
raise NotImplementedError( |
|
f"register_vmap with kwarg-only Tensor args. In the original " |
|
f"definition of the op, please make your tensors not kwarg-only. " |
|
f"Got: {schema}" |
|
) |
|
|
|
def register(func): |
|
nonlocal op, lib |
|
|
|
namespace, opname = torch._library.utils.parse_namespace(qualname) |
|
if lib is None: |
|
lib = Library(namespace, "FRAGMENT") |
|
_keep_alive.append(lib) |
|
|
|
from torch._functorch.autograd_function import custom_function_call_vmap_helper |
|
from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter |
|
|
|
def wrapped_func(keyset, *args, **kwargs): |
|
interpreter = retrieve_current_functorch_interpreter() |
|
return custom_function_call_vmap_helper( |
|
interpreter, func, op, *args, **kwargs |
|
) |
|
|
|
lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True) |
|
|
|
if func is None: |
|
return register |
|
else: |
|
return register(func) |
|
|
|
|
|
|
|
|
|
|
|
def _check_pystubs_once(func, qualname, actual_module_name): |
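    """Wrap ``func`` so that the first call verifies that a C++-defined
    operator's Python fake impl lives in the module declared via
    ``m.set_python_module`` (subsequent calls skip the check)."""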
|
checked = False |
|
|
|
def inner(*args, **kwargs): |
|
nonlocal checked |
|
if checked: |
|
return func(*args, **kwargs) |
|
|
|
op = torch._library.utils.lookup_op(qualname) |
|
if op._defined_in_python: |
|
checked = True |
|
return func(*args, **kwargs) |
|
|
|
maybe_pystub = torch._C._dispatch_pystub( |
|
op._schema.name, op._schema.overload_name |
|
) |
|
if maybe_pystub is None: |
|
if torch._library.utils.requires_set_python_module(): |
|
namespace = op.namespace |
|
cpp_filename = op._handle.debug() |
|
raise RuntimeError( |
|
f"Operator '{qualname}' was defined in C++ and has a Python " |
|
f"fake impl. In this situation, we require there to also be a " |
|
f'companion C++ `m.set_python_module("{actual_module_name}")` ' |
|
f"call, but we could not find one. Please add that to " |
|
f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the " |
|
f"operator was registered in ({cpp_filename})" |
|
) |
|
else: |
|
pystub_module = maybe_pystub[0] |
|
if actual_module_name != pystub_module: |
|
cpp_filename = op._handle.debug() |
|
raise RuntimeError( |
|
f"Operator '{qualname}' specified that its python fake impl " |
|
f"is in the Python module '{pystub_module}' but it was actually found " |
|
f"in '{actual_module_name}'. Please either move the fake impl " |
|
f"or correct the m.set_python_module call ({cpp_filename})" |
|
) |
|
checked = True |
|
return func(*args, **kwargs) |
|
|
|
return inner |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": |
|
"""get_ctx() returns the current AbstractImplCtx object. |
|
|
|
Calling ``get_ctx()`` is only valid inside of an fake impl |
|
(see :func:`torch.library.register_fake` for more usage details. |
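
    A sketch, mirroring the data-dependent-shape example in
    :func:`torch.library.register_fake` ("mylib::custom_nonzero" is illustrative)::

        >>> @torch.library.register_fake("mylib::custom_nonzero")
        >>> def _(x):
        >>>     ctx = torch.library.get_ctx()
        >>>     nnz = ctx.new_dynamic_size()
        >>>     return x.new_empty([nnz, x.dim()], dtype=torch.int64)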
|
""" |
|
return torch._library.fake_impl.global_ctx_getter() |
|
|
|
|
|
_OPCHECK_DEFAULT_UTILS = ( |
|
"test_schema", |
|
"test_autograd_registration", |
|
"test_faketensor", |
|
"test_aot_dispatch_dynamic", |
|
) |
|
|
|
|
|
def opcheck( |
|
op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], |
|
args: tuple[Any, ...], |
|
kwargs: Optional[dict[str, Any]] = None, |
|
*, |
|
test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS, |
|
raise_exception: bool = True, |
|
atol=None, |
|
rtol=None, |
|
) -> dict[str, str]: |
|
"""Given an operator and some sample arguments, tests if the operator is |
|
registered correctly. |
|
|
|
    That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
    custom op, you specify metadata (e.g. mutability info) about the custom op,
    and these APIs require that the functions you pass them satisfy certain
    properties (e.g. no data pointer access in the fake/meta/abstract kernel).
    ``opcheck`` tests that this metadata and these properties hold.
|
|
|
Concretely, we test the following: |
|
|
|
- test_schema: If the schema matches the implementation of |
|
the operator. For example: if the schema specifies a Tensor is mutated, |
|
then we check the implementation mutates the Tensor. If the schema |
|
specifies that we return a new Tensor, then we check that the |
|
implementation returns a new Tensor (instead of an existing one or |
|
a view of an existing one). |
|
- test_autograd_registration: If the operator supports training |
|
(autograd): we check that its autograd formula is registered via |
|
torch.library.register_autograd or a manual registration to one |
|
or more DispatchKey::Autograd keys. Any other DispatchKey-based |
|
registrations may lead to undefined behavior. |
|
- test_faketensor: If the operator has a FakeTensor kernel |
|
      (and if it is correct). The FakeTensor kernel is necessary
      (but not sufficient) for the operator to work with PyTorch compilation
|
APIs (torch.compile/export/FX). We check that a FakeTensor kernel |
|
(also sometimes known as a meta kernel) was registered for the |
|
operator and that it is correct. This test takes the result of |
|
running the operator on real tensors and the result of running |
|
the operator on FakeTensors and checks that they have the same |
|
Tensor metadata (sizes/strides/dtype/device/etc). |
|
- test_aot_dispatch_dynamic: If the operator has correct behavior |
|
with PyTorch compilation APIs (torch.compile/export/FX). |
|
This checks that the outputs (and gradients, if applicable) are the |
|
same under eager-mode PyTorch and torch.compile. |
|
This test is a superset of ``test_faketensor`` and is an e2e test; |
|
other things it tests are that the operator supports |
|
functionalization and that the backward pass (if it exists) also |
|
supports FakeTensor and functionalization. |
|
|
|
For best results, please call ``opcheck`` multiple times with a |
|
representative set of inputs. If your operator supports |
|
autograd, please use ``opcheck`` with inputs with ``requires_grad = True``; |
|
if your operator supports multiple devices (e.g. CPU and CUDA), please |
|
use ``opcheck`` with inputs on all supported devices. |
|
|
|
Args: |
|
op: The operator. Must either be a function decorated with |
|
:func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket |
|
found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo) |
|
args: The args to the operator |
|
kwargs: The kwargs to the operator |
|
test_utils: Tests that we should run. Default: all of them. |
|
Example: ("test_schema", "test_faketensor") |
|
raise_exception: If we should raise an exception on the first |
|
error. If False, we will return a dict with information |
|
on if each test passed or not. |
|
rtol (Optional[float]): Relative tolerance for floating point comparisons. |
|
If specified ``atol`` must also be specified. |
|
If omitted, default values based on the ``dtype`` are selected |
|
(see the table in :func:`torch.testing.assert_close`). |
|
atol (Optional[float]): Absolute tolerance for floating point comparisons. |
|
If specified ``rtol`` must also be specified. |
|
If omitted, default values based on the ``dtype`` are selected |
|
(see the table in :func:`torch.testing.assert_close`). |
|
|
|
.. warning:: |
|
|
|
opcheck and :func:`torch.autograd.gradcheck` test different things; |
|
opcheck tests if your usage of torch.library APIs is correct while |
|
:func:`torch.autograd.gradcheck` tests if your autograd formula is |
|
mathematically correct. Use both to test custom ops that support |
|
gradient computation. |
|
|
|
Example: |
|
|
|
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
|
>>> def numpy_mul(x: Tensor, y: float) -> Tensor: |
|
>>> x_np = x.numpy(force=True) |
|
>>> z_np = x_np * y |
|
>>> return torch.from_numpy(z_np).to(x.device) |
|
>>> |
|
>>> @numpy_mul.register_fake |
|
>>> def _(x, y): |
|
>>> return torch.empty_like(x) |
|
>>> |
|
>>> def setup_context(ctx, inputs, output): |
|
        >>>     _, y = inputs
|
>>> ctx.y = y |
|
>>> |
|
>>> def backward(ctx, grad): |
|
>>> return grad * ctx.y, None |
|
>>> |
|
>>> numpy_mul.register_autograd(backward, setup_context=setup_context) |
|
>>> |
|
>>> sample_inputs = [ |
|
>>> (torch.randn(3), 3.14), |
|
>>> (torch.randn(2, 3, device='cuda'), 2.718), |
|
>>> (torch.randn(1, 10, requires_grad=True), 1.234), |
|
>>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18), |
|
>>> ] |
|
>>> |
|
>>> for args in sample_inputs: |
|
>>> torch.library.opcheck(numpy_mul, args) |
|
|
|
""" |
|
import torch.testing._internal.optests as optests |
|
|
|
return optests.opcheck( |
|
op, |
|
args, |
|
kwargs, |
|
test_utils=test_utils, |
|
raise_exception=raise_exception, |
|
rtol=rtol, |
|
atol=atol, |
|
) |
|
|