import itertools
import unittest.mock
from collections.abc import Iterator
from contextlib import contextmanager

import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree
|
|
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

CROSSREF_FUNCTIONALIZE = False
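
# Example usage (a minimal sketch, assuming a standard PyTorch build where
# these torch._C context managers exist; the tensors and ops are arbitrary,
# illustrative choices):
#
#   with enable_python_dispatcher():
#       # Ops in this block dispatch through the Python dispatcher, so
#       # Python-level kernels and decompositions can intercept them.
#       y = torch.ones(2) + 1
#
#   with no_python_dispatcher():
#       # Force the default C++ dispatcher even if a caller enabled the
#       # Python dispatcher higher up the stack.
#       z = torch.ones(2).mul(2)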
|
|
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle. It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point). This is DIFFERENT
    from the set of registered operators, which will in general be a larger set,
    as that would include every operator for which we ran C++ static initializers
    or Python operator registration. This does not eagerly populate the list on
    torch.ops.aten; this list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python. We use it for cache invalidation, but
    don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't
    guaranteed to be complete either, as a subsequent lazy load of a library
    which triggers more registrations could add more things to the set.
    """
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
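
# Example (illustrative; the set yielded depends on which torch.ops entries
# this process has already accessed):
#
#   torch.ops.aten.add.Tensor  # touch an overload so it is loaded in Python
#   for overload in all_py_loaded_overloads():
#       print(overload)  # e.g. aten.add.Tensor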
|
|
@contextmanager
def suspend_functionalization():
    # Record whether the Functionalize dispatch key is currently active, along
    # with the reapply_views TLS flag, so both can be restored on exit.
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        # Only re-enable functionalization if it was on when we entered.
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
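
# Example usage (a minimal sketch; in practice this is used inside
# functionalization-aware code where the ambient TLS state matters):
#
#   with suspend_functionalization():
#       # Code here sees plain, non-functionalized tensor behavior even if a
#       # caller enabled functionalization; prior state is restored on exit.
#       ...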
|
|
def check_tensor_metadata_matches(nv, rv, desc):
    # desc is a thunk so that the (potentially expensive) description string
    # is only built if an assertion actually fires.
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert same_strides, (
        f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
    )
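
# Example (illustrative): matching sizes and dtypes but differing significant
# strides fail the check:
#
#   a = torch.ones(2, 3)      # strides (3, 1)
#   b = torch.ones(3, 2).t()  # strides (1, 3)
#   check_tensor_metadata_matches(a, b, lambda: "demo")  # AssertionError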
|
|
def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, _n_spec = pytree.tree_flatten(n)
    r_vals, _r_spec = pytree.tree_flatten(r)
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        # Only tensor leaves carry metadata worth cross-checking.
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
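
# Example (illustrative): n and r may be arbitrary pytrees; non-tensor leaves
# are skipped and tensor leaves are compared pairwise:
#
#   n = [torch.ones(2), 3]
#   r = [torch.zeros(2), "ignored"]
#   check_metadata_matches(n, r, lambda: "demo")  # passes: metadata matches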
|
|
class Lit:
    # Wraps a string so that repr() returns it verbatim, letting us splice
    # pre-formatted code fragments into the repr() of nested structures.
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(
            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
        )
    else:
        return a
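
# Example (illustrative): _fmt swaps tensors for a Lit whose repr() is a
# constructor expression, so failure messages show how to rebuild each input:
#
#   >>> pytree.tree_map(_fmt, [torch.ones(2, 3), "tag"])
#   [torch.empty_strided((2, 3), (3, 1), dtype=torch.float32), 'tag']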
|
|
def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird; suppress it for now.
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    # Unwrap the functional wrapper; the wrapper is expected
                    # to agree with the unwrapped tensor on metadata.
                    r = torch._from_functional_tensor(t)
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # Fakeify the inputs with functionalization (and any active dispatch
        # modes) suspended, keeping detached copies so the original inputs can
        # be printed if the cross-check fails.
        with (
            torch.utils._python_dispatch._disable_current_modes(),
            suspend_functionalization(),
        ):
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            # Re-run the op on fake tensors to compute reference metadata.
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        # Run the real kernel at the dispatch key we were asked to cross-check.
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        # The fake result must agree with the real result on sizes, dtypes,
        # and significant strides.
        check_metadata_matches(f_r, r, desc)
        return r

    return handler
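
# Example usage (a minimal sketch; final_key is normally chosen by the Python
# dispatcher, and CompositeExplicitAutograd here is an arbitrary stand-in):
#
#   op = torch.ops.aten.add.Tensor
#   handler = make_crossref_functionalize(
#       op, torch._C.DispatchKey.CompositeExplicitAutograd
#   )
#   # handler(*args, **kwargs) runs the real kernel, re-runs the op on fake
#   # tensors, and asserts that both results agree on metadata.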
|
|
@contextmanager
def enable_crossref_functionalize():
    # Evict any cached Functionalize kernels so that the patched value of
    # CROSSREF_FUNCTIONALIZE is observed, both on entry and on exit.
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with (
            enable_python_dispatcher(),
            unittest.mock.patch("torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True),
        ):
            yield
    finally:
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
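
# Example usage (a minimal sketch; ops that dispatch through the Functionalize
# key inside this block get the crossref handler installed above):
#
#   with enable_crossref_functionalize():
#       # e.g. run a functionalized graph or test here; each functionalized
#       # op's output metadata is cross-checked against a fake-tensor re-run.
#       ...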
|
|