import dataclasses
import inspect
import sys
import warnings
from collections.abc import Iterable, Iterator
from typing import Any, Callable, Union

import torch
import torch.utils._pytree as pytree
from torch import _C, _utils_internal
from torch._ops import OpOverload


def warn_deploy(stacklevel=3):
    warnings.warn(
        "Python torch.library APIs do nothing under torch::deploy (multipy). "
        "Please instead use C++ custom operator registration APIs.",
        RuntimeWarning,
        stacklevel=stacklevel,
    )


@dataclasses.dataclass
class Kernel:
    """Models a (function, source location)"""

    func: Callable
    source: str

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


class RegistrationHandle:
    """Does something when someone calls .destroy() on it"""

    def __init__(self, on_destroy: Callable):
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        self._on_destroy()


def get_source(stacklevel: int) -> str:
    """Get a string that represents the caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 to get the caller's source
    Use stacklevel=2 to get the caller's caller's source
    etc.
    """
    frame = inspect.getframeinfo(sys._getframe(stacklevel))
    source = f"{frame.filename}:{frame.lineno}"
    return source


def parse_namespace(qualname: str) -> tuple[str, str]:
    splits = qualname.split("::")
    if len(splits) != 2:
        raise ValueError(
            f"Expected `qualname` to be of the form "
            f'"namespace::name", but got {qualname}. '
            f"The qualname passed to the torch.library APIs must consist "
            f"of a namespace and a name, e.g. aten::sin"
        )
    return splits[0], splits[1]
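# Illustrative sketch (not part of the module's API surface): parse_namespace
# splits the "namespace::name" qualname format described in the error message above.
#
#     >>> parse_namespace("aten::sin")
#     ('aten', 'sin')
#     >>> parse_namespace("sin")   # no namespace -> ValueError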


def lookup_op(qualname: str) -> OpOverload:
    namespace, name = parse_namespace(qualname)
    if "." in name:
        name, overload = name.split(".")
    else:
        overload = "default"
    ns = getattr(torch.ops, namespace)
    packet = getattr(ns, name)
    return getattr(packet, overload)
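# Illustrative sketch (assumes the standard aten ops are registered): lookup_op
# resolves "namespace::name[.overload]" through torch.ops, falling back to the
# "default" overload when none is given.
#
#     >>> lookup_op("aten::sin")         # torch.ops.aten.sin.default
#     >>> lookup_op("aten::add.Tensor")  # torch.ops.aten.add.Tensor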


def is_builtin(op: OpOverload) -> bool:
    assert isinstance(op, OpOverload)
    return op.namespace in {"aten", "prim", "prims"}


def is_functional_schema(schema: Any) -> bool:
    """Check if the schema is functional.

    An operator is functional if:
    - it does not mutate any of its inputs
    - it does not return a view on any of its inputs
    - it has at least one return
    """

    def is_functional(schema):
        if schema.is_mutable:
            return False
        rets = schema.returns
        is_non_mutating_view = len(rets) > 0 and any(
            r.alias_info is not None and not r.alias_info.is_write for r in rets
        )
        if is_non_mutating_view:
            return False
        if not schema.returns:
            return False
        return True

    if isinstance(schema, torch._C.FunctionSchema):
        return is_functional(schema)

    # Lazy import: not all PyTorch builds ship torchgen.
    from torchgen.model import FunctionSchema

    if isinstance(schema, str):
        schema = FunctionSchema.parse(schema)
    assert isinstance(schema, FunctionSchema)
    return is_functional(schema)
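# Illustrative sketch of the three criteria above, using parsed _C schemas
# (the schema strings here are hypothetical spellings, not real registrations):
#
#     >>> is_functional_schema(torch._C.parse_schema("sin(Tensor self) -> Tensor"))
#     True
#     >>> is_functional_schema(torch._C.parse_schema("sin_(Tensor(a!) self) -> Tensor(a!)"))
#     False  # mutates its input
#     >>> is_functional_schema(torch._C.parse_schema("alias(Tensor(a) self) -> Tensor(a)"))
#     False  # returns a view of its input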


def is_tensorlist_like_type(typ: Any) -> bool:
    return (
        typ == _C.ListType(_C.TensorType.get())
        or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
        or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
        or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
    )


def is_tensor_like_type(typ: Any) -> bool:
    return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
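# Illustrative mapping (a sketch; these are the JIT schema spellings the _C
# types above correspond to):
#   is_tensor_like_type     -> "Tensor", "Tensor?"
#   is_tensorlist_like_type -> "Tensor[]", "Tensor?[]", "Tensor[]?", "Tensor?[]?"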


def mutates_and_returns_first_arg(op: OpOverload):
    """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.

    TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
    but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
    Figure this out.

    Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
    """
    if op.namespace != "aten":
        return False
    schema = op._schema
    if len(schema.returns) != 1:
        return False
    if schema.returns[0].alias_info is None:
        return False
    alias_set = schema.returns[0].alias_info.after_set
    if len(alias_set) != 1:
        return False
    loc = next(iter(alias_set))
    if len(schema.arguments) < 1:
        return False
    first_arg = schema.arguments[0]
    if first_arg.alias_info is None:
        return False
    if not first_arg.alias_info.is_write:
        return False
    alias_set = first_arg.alias_info.after_set
    if len(alias_set) != 1:
        return False
    if loc != next(iter(alias_set)):
        return False
    for arg in schema.arguments[1:]:
        if arg.alias_info is not None:
            return False
    return True
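# Illustrative sketch, reusing the docstring's example: an aten op with schema
# "add_(Tensor(a!) x, Tensor y) -> Tensor(a)" passes every check above (a single
# return aliased to the mutated first argument, no other aliasing), so this
# returns True, while an out-of-place op with no aliasing returns False.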


def fill_defaults(schema, args, kwargs):
    new_args = []
    new_kwargs = {}
    for i in range(len(schema.arguments)):
        info = schema.arguments[i]
        if info.kwarg_only:
            if info.name in kwargs:
                new_kwargs[info.name] = kwargs[info.name]
            else:
                new_kwargs[info.name] = info.default_value
        else:
            if i < len(args):
                new_args.append(args[i])
            else:
                new_args.append(info.default_value)
    return tuple(new_args), new_kwargs
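# Illustrative sketch with a hypothetical schema (not a real op):
#
#     foo(Tensor x, int n=2, *, bool flag=False) -> Tensor
#
# fill_defaults(schema, (t,), {}) would produce ((t, 2), {"flag": False}):
# positional arguments missing from `args` are filled from their schema
# defaults, and every kwarg-only argument shows up in the returned kwargs dict.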


def zip_schema(
    schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Iterable[tuple[_C.Argument, Any]]:
    """zips schema.arguments and (args, kwargs) together.

    Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
    that is, (args, kwargs) must be bindable to the schema (args, kwargs).
    """
    assert len(schema.arguments) >= len(args) + len(kwargs)
    for i in range(len(schema.arguments)):
        info = schema.arguments[i]
        if info.kwarg_only:
            if info.name in kwargs:
                yield info, kwargs[info.name]
            continue
        if i >= len(args):
            # Not passed positionally; it may still have been passed as a kwarg.
            if not info.kwarg_only and info.name in kwargs:
                yield info, kwargs[info.name]
            # Otherwise it took its default value and is skipped.
            continue
        yield info, args[i]
    return
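# Illustrative sketch, with the same hypothetical schema used above for
# fill_defaults ("foo(Tensor x, int n=2, *, bool flag=False) -> Tensor"):
#
#     list(zip_schema(schema, (t,), {"flag": True}))
#     # -> [(<Argument x>, t), (<Argument flag>, True)]
#
# Unlike fill_defaults, arguments left at their defaults (here `n`) are
# skipped rather than filled in.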


def hop_schema_from_fx_node(node):
    from torchgen.gen_schema_utils import FunctionSchemaGen

    hop = node.target
    if not isinstance(hop, torch._ops.HigherOrderOperator):
        raise RuntimeError("fx_node's target must be a hop.")

    def _collect_example_val(node):
        meta_val = node.meta.get("val", None)
        if meta_val is None:
            assert node.op == "get_attr"
            meta_val = getattr(node.graph.owning_module, node.target)
        return meta_val

    example_inputs = []
    for arg in node.args:
        if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
            example_inputs.append(_collect_example_val(arg))
        elif isinstance(
            arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
        ):
            example_inputs.append([_collect_example_val(x) for x in arg])
        else:
            raise RuntimeError(f"Unsupported arg type {type(arg)}")

    # Bind the example inputs to the HOP's Python signature so the generated
    # schema gets the right argument names.
    bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
        *example_inputs
    )

    # Use the node's example output (its "val" meta) to infer the return types.
    example_output = _collect_example_val(node)
    return FunctionSchemaGen.from_example(
        hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
    )


def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
    assert isinstance(op, OpOverload)
    if is_builtin(op):
        # Built-in ops get their fake/meta kernels through other mechanisms;
        # don't auto-generate anything for them here.
        return False
    schema = op._schema
    # Only ops that mutate their inputs and return nothing can get a trivial
    # (no-op) fake impl.
    if not schema.is_mutable:
        return False
    if len(schema.returns) > 0:
        return False
    return True


def requires_set_python_module() -> bool:
    """If an op was defined in C++ and extended from Python using the
    torch.library APIs, returns whether we require that there has been a
    m.set_python_module("mylib.ops") call from C++ that associates
    the C++ op with a python module.
    """
    return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)


def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
    assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
    args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
    # The "types" argument to __torch_dispatch__ is the list of types of the
    # tensor arguments that have the Python dispatch key set.
    overload_types = [
        type(a)
        for a in args_flattened
        if isinstance(a, torch.Tensor)
        and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python)
    ]
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)


def has_kwarg_only_args(schema: _C.FunctionSchema):
    return any(a.kwarg_only for a in schema.arguments)


def has_kwarg_only_tensors(schema: _C.FunctionSchema):
    for a in schema.arguments:
        if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
            continue
        if not a.kwarg_only:
            continue
        return True
    return False


def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
    """
    Given a schema, returns True if the schema has a Tensor arg.
    A Tensor arg is any arg with a type annotation that might involve Tensor.
    """
    return any(
        (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
        for a in schema.arguments
    )


def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
    """
    Given a schema, returns the index of the `device: torch.device` argument.
    If it does not exist, returns None.
    """
    for index, arg in enumerate(schema.arguments):
        if arg.type is _C.DeviceObjType.get() and arg.name == "device":
            return index
    return None


def iter_tensors(
    args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1
) -> Iterator[torch.Tensor]:
    def check(arg):
        if isinstance(arg, torch.Tensor):
            yield arg
        elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
            yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)

    for arg in args:
        yield from check(arg)
    for kwarg in kwargs.values():
        yield from check(kwarg)
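# Illustrative sketch: with tensors t1..t4,
#
#     list(iter_tensors((t1, [t2, t3]), {"w": t4}))
#     # -> [t1, t2, t3, t4]
#
# Lists/tuples are only traversed up to `allowed_nesting` levels deep (1 by
# default), so a tensor buried in a list of lists would not be yielded.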


def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
    """
    Custom operators' outputs must not alias any inputs or other outputs.
    """
    storages = {id(t.untyped_storage()) for t in prev if isinstance(t, torch.Tensor)}
    tuple_result = result
    if not isinstance(result, tuple):
        tuple_result = (result,)
    for tensor in iter_tensors(tuple_result, {}):
        key = id(tensor.untyped_storage())
        if key in storages:
            raise RuntimeError(
                f"{name} (with implementation in {get_module()}): "
                f"The output of this custom operator (1) must not "
                f"also be an input to this custom operator and "
                f"(2) may not alias any inputs to this custom operator "
                f"or other returns. "
                f"The most common way to trigger this error is if "
                f"we have y = custom_op(x) and y and x are the same Tensor. "
                f"Please instead return a clone of the offending output "
                f"tensor(s) (e.g. return x.clone()) or refactor the custom "
                f"operator to not return y."
            )
        storages.add(key)
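# Illustrative sketch of the failure mode described in the error message above:
# if a custom op's implementation simply returns one of its inputs,
#
#     def bad_impl(x):
#         return x  # output aliases the input
#
# then check_aliasing_constraint("mylib::bad", (x,), bad_impl(x)) raises,
# whereas returning x.clone() passes. ("mylib::bad" is a made-up name.)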


class MutationChecker:
    """
    Check if an operator mutated its arguments.
    Usage:

    checker = MutationChecker(op, flat_args, args_spec)
    op(*args, **kwargs)
    checker.check()
    """

    def __init__(self, op, flat_args, args_spec):
        self.op = op
        self.args_spec = args_spec
        self.flat_args = flat_args
        self.real_pre_hashes = [
            hash_tensor(a) if isinstance(a, torch.Tensor) else None for a in flat_args
        ]

    def check(self):
        real_post_hashes = [
            hash_tensor(a) if isinstance(a, torch.Tensor) else None
            for a in self.flat_args
        ]
        was_mutated = [
            not torch.equal(pre, post)
            and not (pre.isnan().all() and post.isnan().all())
            if isinstance(pre, torch.Tensor) and isinstance(post, torch.Tensor)
            else None
            for pre, post in zip(self.real_pre_hashes, real_post_hashes)
        ]
        was_mutated_args, was_mutated_kwargs = pytree.tree_unflatten(
            was_mutated, self.args_spec
        )
        for info, was_mutated in zip_schema(
            self.op._schema, was_mutated_args, was_mutated_kwargs
        ):

            def check_one(info, was_mutated):
                if info.is_write == was_mutated:
                    return
                raise RuntimeError(
                    f"{self.op._name}: for argument '{info.name}': the operator's schema "
                    f"{self.op._schema} specified that "
                    f"the operator {'mutates' if info.is_write else 'does not mutate'} "
                    f"the argument, but this seems to be empirically wrong. "
                    f"Please make the schema and operator behavior consistent. "
                    f"You can specify that an operator mutates a Tensor by "
                    f"e.g. changing its schema type from 'Tensor name' to 'Tensor(a!) name' "
                    f"(use different identifiers (a, b, c, ...) for different Tensors)"
                )

            if is_tensor_like_type(info.type):
                check_one(info, was_mutated)
            elif is_tensorlist_like_type(info.type):
                was_any_mutated = False if was_mutated is None else any(was_mutated)
                check_one(info, was_any_mutated)


def hash_tensor(t: torch.Tensor) -> torch.Tensor:
    """Some inexpensive hash. Used as a quick and dirty indicator for tensor mutation"""
    return t.detach().float().mean()
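# Note (illustrative): this hash is intentionally cheap and not collision-free.
# For example, tensors [0.0, 2.0] and [1.0, 1.0] both hash to 1.0, so an
# in-place mutation that happens to preserve the mean will not be flagged by
# MutationChecker above.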


def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
    """Returns whether an operator (that stays alive until FakeTensorMode) has a Fake kernel.
    Don't use this if the operator decomposes before FakeTensorMode.
    """
    if can_generate_trivial_fake_impl(op):
        return True
    name = op._name
    if torch._C._dispatch_has_kernel_for_dispatch_key(
        name, "CompositeImplicitAutograd"
    ):
        return True
    opdef = torch._library.custom_ops._maybe_get_opdef(name)
    if opdef is None:
        # The op was not created via the Python custom op API
        # (torch.library.custom_op); check the registered dispatch keys and
        # the Python fake impl registry.
        if torch._C._dispatch_has_kernel_for_dispatch_key(
            name, "CompositeExplicitAutograd"
        ):
            return True
        entry = torch._library.simple_registry.singleton.find(name)
        if entry.fake_impl.kernel is not None:
            return True
        if torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"):
            return True
    else:
        # The op was created via torch.library.custom_op; it has a fake kernel
        # iff an abstract/fake function was registered on it.
        if opdef._abstract_fn is not None:
            return True
    return False


def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]:
    idxs = []
    keys = []
    for i, info in enumerate(schema.arguments):
        if info.alias_info is not None and info.alias_info.is_write:
            if info.kwarg_only:
                keys.append(info.name)
            else:
                idxs.append(i)
    return idxs, keys
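# Illustrative sketch with a hypothetical schema (not a real op):
#
#     foo(Tensor(a!) x, Tensor y, *, Tensor(b!) out) -> ()
#
# mutated_args_kwargs(schema) would return ([0], ["out"]): mutated positional
# arguments are reported by index, mutated kwarg-only arguments by name.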