|
|
|
|
|
""" |
|
Core guard system for Dynamo that detects when compiled code needs to be recompiled due to |
|
changes in program state. Guards are conditions that must remain true for previously-compiled |
|
code to be valid for reuse. |
|
|
|
This module provides the infrastructure for creating, managing and checking guards, including: |
|
- Guard creation and composition |
|
- Guard state management and invalidation |
|
- Guard checking and failure handling |
|
- Utilities for guard optimization and debugging |
|
- Integration with Dynamo's compilation caching |
|
|
|
The guard system is critical for Dynamo's ability to efficiently reuse compiled code while |
|
maintaining correctness by detecting when recompilation is necessary due to changes in |
|
program state, tensor properties, or control flow. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import ast |
|
import builtins |
|
import collections |
|
import dataclasses |
|
import enum |
|
import functools |
|
import importlib |
|
import inspect |
|
import logging |
|
import math |
|
import sys |
|
import textwrap |
|
import types |
|
import warnings |
|
import weakref |
|
from contextlib import contextmanager |
|
from copy import deepcopy |
|
from inspect import currentframe |
|
from typing import Any, Callable, Optional, TYPE_CHECKING, Union |
|
from weakref import ReferenceType |
|
|
|
import torch |
|
import torch.overrides |
|
import torch.utils._device |
|
from torch._C._dynamo.eval_frame import code_framelocals_names |
|
from torch._C._dynamo.guards import ( |
|
check_obj_id, |
|
check_type_id, |
|
dict_version, |
|
DictGuardManager, |
|
install_no_tensor_aliasing_guard, |
|
install_object_aliasing_guard, |
|
install_storage_overlapping_guard, |
|
install_symbolic_shape_guard, |
|
profile_guard_manager, |
|
RootGuardManager, |
|
) |
|
from torch._dynamo.source import ( |
|
IndexedSource, |
|
is_from_flatten_script_object_source, |
|
is_from_local_source, |
|
is_from_optimizer_source, |
|
TensorProperty, |
|
TensorPropertySource, |
|
) |
|
from torch._dynamo.utils import CompileEventLogger |
|
from torch._guards import ( |
|
CompileContext, |
|
CompileId, |
|
DuplicateInputs, |
|
Guard, |
|
GuardBuilderBase, |
|
GuardEnvExpr, |
|
GuardSource, |
|
Source, |
|
StorageOverlap, |
|
) |
|
from torch._logging import structured |
|
from torch._utils_internal import justknobs_check |
|
from torch.fx.experimental.symbolic_shapes import ( |
|
EqualityConstraint, |
|
is_symbolic, |
|
SYMPY_INTERP, |
|
) |
|
from torch.utils._ordered_set import OrderedSet |
|
from torch.utils._traceback import format_frame, report_compile_source_on_error |
|
from torch.utils.weak import TensorWeakRef |
|
|
|
from . import config, convert_frame, exc, mutation_guard |
|
from .eval_frame import set_guard_error_hook |
|
from .source import ( |
|
AttrProxySource, |
|
AttrSource, |
|
CallFunctionNoArgsSource, |
|
CallMethodItemSource, |
|
ChainedSource, |
|
ConstantSource, |
|
ConstDictKeySource, |
|
DefaultsSource, |
|
DictGetItemSource, |
|
FlattenScriptObjectSource, |
|
FloatTensorSource, |
|
FSDPNNModuleSource, |
|
GenericAttrSource, |
|
GetItemSource, |
|
GlobalSource, |
|
GlobalStateSource, |
|
GlobalWeakRefSource, |
|
GradSource, |
|
ListGetItemSource, |
|
LocalSource, |
|
NNModuleSource, |
|
NumpyTensorSource, |
|
OptimizerSource, |
|
ScriptObjectQualifiedNameSource, |
|
ShapeEnvSource, |
|
SubclassAttrListSource, |
|
TorchFunctionModeStackSource, |
|
TupleIteratorGetItemSource, |
|
TypeSource, |
|
UnspecializedBuiltinNNModuleSource, |
|
UnspecializedNNModuleSource, |
|
UnspecializedParamBufferSource, |
|
WeakRefCallSource, |
|
) |
|
from .types import ( |
|
CacheEntry, |
|
DynamoFrameType, |
|
ExtraState, |
|
GuardedCode, |
|
GuardFail, |
|
GuardFn, |
|
) |
|
from .utils import ( |
|
builtin_dict_keys, |
|
common_constant_types, |
|
dict_keys, |
|
get_custom_getattr, |
|
get_torch_function_mode_stack, |
|
get_torch_function_mode_stack_at, |
|
guard_failures, |
|
istype, |
|
key_is_id, |
|
key_to_id, |
|
normalize_range_iter, |
|
orig_code_map, |
|
tensor_always_has_static_shape, |
|
tuple_iterator_getitem, |
|
tuple_iterator_len, |
|
unpatched_nn_module_getattr, |
|
verify_guard_fn_signature, |
|
) |
|
|
|
|
|
# Optional two-argument callable, unset (None) by default. Nothing in this part
# of the module invokes it; presumably tests install a callable here - TODO confirm.
guard_manager_testing_hook_fn: Optional[Callable[[Any, Any], Any]] = None

# numpy is an optional dependency: `np` is None when it cannot be imported.
try:
    import numpy as np
except ModuleNotFoundError:
    np = None


# Only needed for annotations, so it is imported for type checkers alone.
if TYPE_CHECKING:
    from sympy import Symbol


# Regular module logger plus one artifact logger per named logging artifact
# ("guards", "recompiles", "recompiles_verbose", "verbose_guards").
log = logging.getLogger(__name__)
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
recompiles_verbose_log = torch._logging.getArtifactLogger(
    __name__, "recompiles_verbose"
)
verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
|
|
|
|
|
class GuardManagerWrapper: |
|
""" |
|
A helper class that contains the root guard manager. An instance of this |
|
class is stored in the Dynamo cache entry, so that the cache entry can |
|
access the RootGuardManager stored in the "root" attribute and directly call |
|
the check_nopybind from C++. |
|
""" |
|
|
|
def __init__(self, root=None): |
|
if root is None: |
|
self.root = RootGuardManager() |
|
else: |
|
self.root = root |
|
|
|
self.diff_guard_root = None |
|
self.closure_vars = None |
|
self.args = None |
|
self.code_parts = [] |
|
self.verbose_code_parts = None |
|
self.global_scope = None |
|
self.guard_fail_fn = None |
|
self.cache_entry = None |
|
self.extra_state = None |
|
self.id_matched_objs = {} |
|
self.no_tensor_aliasing_sources = [] |
|
|
|
self.print_no_tensor_aliasing_guard = True |
|
|
|
self.diff_guard_sources: OrderedSet[str] = OrderedSet() |
|
|
|
@contextmanager |
|
def _preserve_print_no_tensor_aliasing_flag(self): |
|
self.print_no_tensor_aliasing_guard = True |
|
try: |
|
yield |
|
finally: |
|
self.print_no_tensor_aliasing_guard = True |
|
|
|
def collect_diff_guard_sources(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def visit_dict_manager(node): |
|
is_diff_guard_node = ( |
|
node.get_source() in self.diff_guard_sources or node.fail_count() > 0 |
|
) |
|
for idx, (key_mgr, val_mgr) in sorted( |
|
node.get_key_value_managers().items() |
|
): |
|
is_diff_guard_node |= visit(key_mgr) | visit(val_mgr) |
|
|
|
if is_diff_guard_node: |
|
self.diff_guard_sources.add(node.get_source()) |
|
|
|
return is_diff_guard_node |
|
|
|
def visit_manager(node): |
|
assert not isinstance(node, DictGuardManager) |
|
|
|
is_diff_guard_node = ( |
|
node.get_source() in self.diff_guard_sources or node.fail_count() > 0 |
|
) |
|
for child_mgr in node.get_child_managers(): |
|
is_diff_guard_node |= visit(child_mgr) |
|
|
|
if is_diff_guard_node: |
|
self.diff_guard_sources.add(node.get_source()) |
|
|
|
return is_diff_guard_node |
|
|
|
def visit(node): |
|
if node is None: |
|
return False |
|
if isinstance(node, DictGuardManager): |
|
return visit_dict_manager(node) |
|
return visit_manager(node) |
|
|
|
visit(self.root) |
|
|
|
return self.diff_guard_sources |
|
|
|
def finalize(self): |
|
self.collect_diff_guard_sources() |
|
self.populate_diff_guard_manager() |
|
|
|
def populate_diff_guard_manager(self): |
|
self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.cache_entry: |
|
self.cache_entry.update_diff_guard_root_manager() |
|
|
|
def clone_with_chosen_sources(self, chosen_sources): |
|
def filter_fn(node_mgr): |
|
return node_mgr.get_source() in chosen_sources |
|
|
|
return self.root.clone_manager(filter_fn) |
|
|
|
def get_guard_lines(self, guard): |
|
guard_name = guard.__class__.__name__ |
|
parts = guard.verbose_code_parts() |
|
parts = [guard_name + ": " + part for part in parts] |
|
return parts |
|
|
|
def get_manager_line(self, guard_manager, accessor_str=None): |
|
source = guard_manager.get_source() |
|
t = guard_manager.__class__.__name__ |
|
s = t + ": source=" + source |
|
if accessor_str: |
|
s += ", " + accessor_str |
|
return s |
|
|
|
def construct_dict_manager_string(self, mgr, body): |
|
for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()): |
|
body.writeline(f"KeyValueManager pair at index={idx}") |
|
with body.indent(): |
|
if key_mgr: |
|
body.writeline(f"KeyManager: {self.get_manager_line(key_mgr)}") |
|
self.construct_manager_string(key_mgr, body) |
|
|
|
if val_mgr: |
|
body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}") |
|
self.construct_manager_string(val_mgr, body) |
|
|
|
def construct_manager_string(self, mgr, body): |
|
with body.indent(): |
|
for guard in mgr.get_leaf_guards(): |
|
if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): |
|
if self.print_no_tensor_aliasing_guard: |
|
self.print_no_tensor_aliasing_guard = False |
|
body.writelines(self.get_guard_lines(guard)) |
|
else: |
|
body.writelines( |
|
[ |
|
guard.__class__.__name__, |
|
] |
|
) |
|
else: |
|
body.writelines(self.get_guard_lines(guard)) |
|
|
|
|
|
if isinstance(mgr, DictGuardManager): |
|
self.construct_dict_manager_string(mgr, body) |
|
|
|
|
|
for accessor, child_mgr in zip( |
|
mgr.get_accessors(), mgr.get_child_managers() |
|
): |
|
body.writeline( |
|
self.get_manager_line(child_mgr, f"accessed_by={accessor.repr()}") |
|
) |
|
self.construct_manager_string(child_mgr, body) |
|
|
|
def __str__(self): |
|
from torch._inductor.utils import IndentedBuffer |
|
|
|
class IndentedBufferWithPrefix(IndentedBuffer): |
|
def prefix(self): |
|
return "| " * (self._indent * self.tabwidth) |
|
|
|
def writeline(self, line, skip_prefix=False): |
|
if skip_prefix: |
|
super().writeline(line) |
|
else: |
|
super().writeline("+- " + line) |
|
|
|
with self._preserve_print_no_tensor_aliasing_flag(): |
|
body = IndentedBufferWithPrefix() |
|
body.tabwidth = 1 |
|
body.writeline("", skip_prefix=True) |
|
body.writeline("TREE_GUARD_MANAGER:", skip_prefix=True) |
|
body.writeline("RootGuardManager") |
|
self.construct_manager_string(self.root, body) |
|
if hasattr(self.root, "get_epilogue_lambda_guards"): |
|
for guard in self.root.get_epilogue_lambda_guards(): |
|
body.writelines(self.get_guard_lines(guard)) |
|
return body.getvalue() |
|
|
|
def check(self, x): |
|
|
|
return self.root.check(x) |
|
|
|
def check_verbose(self, x): |
|
|
|
return self.root.check_verbose(x) |
|
|
|
def populate_code_parts_for_debugging(self): |
|
|
|
tensor_aliasing_guard_seen = False |
|
|
|
def get_code_parts(leaf_guard): |
|
code_parts = [] |
|
for verbose_code_part in leaf_guard.verbose_code_parts(): |
|
code_part = verbose_code_part.split("#")[0].rstrip() |
|
code_parts.append(code_part) |
|
return code_parts |
|
|
|
def visit(mgr): |
|
nonlocal tensor_aliasing_guard_seen |
|
for guard in mgr.get_leaf_guards(): |
|
if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): |
|
if not tensor_aliasing_guard_seen: |
|
self.code_parts.extend(get_code_parts(guard)) |
|
tensor_aliasing_guard_seen = True |
|
else: |
|
self.code_parts.extend(get_code_parts(guard)) |
|
|
|
for child_mgr in mgr.get_child_managers(): |
|
visit(child_mgr) |
|
|
|
visit(self.root) |
|
|
|
|
|
def from_numpy(a): |
|
|
|
|
|
|
|
with torch.overrides._enable_torch_function(): |
|
return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a |
|
|
|
|
|
|
|
@functools.lru_cache(None) |
|
def uninteresting_files(): |
|
import torch._dynamo.external_utils |
|
import torch._dynamo.polyfills |
|
|
|
mods = [torch._dynamo.external_utils, torch._dynamo.polyfills] |
|
|
|
from torch._dynamo.polyfills.loader import POLYFILLED_MODULES |
|
|
|
mods.extend(POLYFILLED_MODULES) |
|
|
|
return {inspect.getfile(m) for m in mods} |
|
|
|
|
|
# Lazily built by _get_closure_vars(); stays None until the first call.
_CLOSURE_VARS: Optional[dict[str, object]] = None


def _get_closure_vars():
    """
    Return the shared name -> object mapping of helpers, building it on first use.

    The same dict object is returned on every call.
    """
    global _CLOSURE_VARS
    if _CLOSURE_VARS is not None:
        return _CLOSURE_VARS

    # numpy is optional, so its isnan helper may be unavailable.
    numpy_isnan = np.isnan if np is not None else None

    _CLOSURE_VARS = {
        "___check_type_id": check_type_id,
        "___check_obj_id": check_obj_id,
        "___odict_getitem": collections.OrderedDict.__getitem__,
        "___key_to_id": key_to_id,
        "___dict_version": dict_version,
        "___dict_contains": lambda a, b: dict.__contains__(b, a),
        "___tuple_iterator_len": tuple_iterator_len,
        "___normalize_range_iter": normalize_range_iter,
        "___tuple_iterator_getitem": tuple_iterator_getitem,
        "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
        "__math_isnan": math.isnan,
        "__numpy_isnan": numpy_isnan,
        "inf": float("inf"),
        "__load_module": importlib.import_module,
        "utils_device": torch.utils._device,
        "device": torch.device,
        "___from_numpy": from_numpy,
        "___as_tensor": torch._as_tensor_fullprec,
        "torch": torch,
        "inspect": inspect,
    }
    return _CLOSURE_VARS
|
|
|
|
|
def _ast_unparse(node: ast.AST) -> str: |
|
return ast.unparse(node).replace("\n", "") |
|
|
|
|
|
# Module-level alias for the binding exposed as torch._C._dynamo.strip_function_call.
strip_function_call = torch._C._dynamo.strip_function_call
|
|
|
|
|
def get_verbose_code_part(code_part: str, guard: Guard) -> str:
    """
    Pad ``code_part`` to 60 columns and append a ``#`` note locating the guard.

    Args:
        code_part: The guard expression text.
        guard: The guard the expression belongs to, or None for no annotation.

    Returns:
        ``code_part`` left-justified to 60 characters, followed by a frame
        annotation when one is available. With a user stack, the innermost
        frame whose file is not in ``uninteresting_files()`` is used (no
        annotation if every frame is filtered out). Otherwise the last entry of
        the guard's stack summary is used, or ``<unknown>`` if that summary is
        empty.
    """
    extra = ""
    if guard is not None:
        if guard.user_stack:
            for fs in reversed(guard.user_stack):
                if fs.filename not in uninteresting_files():
                    extra = f" # {format_frame(fs, line=True)}"
                    break
        elif guard.stack:
            summary = guard.stack.summary()
            # A truthy stack object can still yield an empty summary; indexing
            # [-1] unconditionally would raise IndexError in that case.
            if len(summary) > 0:
                extra = f" # {format_frame(summary[-1])}"
            else:
                extra = " # <unknown>"
    return f"{code_part:<60}{extra}"
|
|
|
|
|
def get_verbose_code_parts(
    code_parts: Union[str | list[str]], guard: Guard
) -> list[str]:
    """
    Annotate one or several code parts with the guard's location info.

    Anything that is not a list is treated as a single code part.
    """
    parts = code_parts if isinstance(code_parts, list) else [code_parts]
    annotated = []
    for part in parts:
        annotated.append(get_verbose_code_part(part, guard))
    return annotated
|
|
|
|
|
def convert_to_concrete_values(size_or_stride): |
|
converted: list[Optional[int]] = [] |
|
for dim in size_or_stride: |
|
if not is_symbolic(dim): |
|
converted.append(dim) |
|
else: |
|
assert isinstance(dim, torch.SymInt) |
|
converted.append(dim.node.maybe_as_int()) |
|
return converted |
|
|
|
|
|
def get_tensor_guard_code_part(value, name, sizes, strides):
    """
    Build the ``check_tensor(...)`` text describing a tensor guard.

    The dispatch key set is the tensor's own keys plus the thread-local include
    set, minus the thread-local exclude set. ``sizes`` and ``strides`` are
    interpolated as given.
    """
    included_keys = (
        torch._C._dispatch_keys(value) | torch._C._dispatch_tls_local_include_set()
    )
    dispatch_key = included_keys - torch._C._dispatch_tls_local_exclude_set()
    type_name = type(value).__qualname__

    head = f"check_tensor({name}, {type_name}, {dispatch_key}, {value.dtype}, "
    tail = (
        f"device={value.device.index}, requires_grad={value.requires_grad}, "
        f"size={sizes}, stride={strides})"
    )
    return head + tail
|
|
|
|
|
def get_key_index(dct, key): |
|
|
|
|
|
|
|
|
|
return list(builtin_dict_keys(dct)).index(key) |
|
|
|
|
|
def get_key_index_source(source, index):
    """Return the expression text selecting the ``index``-th key of dict ``source``."""
    keys_expr = f"list(dict.keys({source}))"
    return f"{keys_expr}[{index}]"
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class NNModuleAttrAccessorInfo:
    """
    Immutable record of how an nn.Module attribute is reached via ``__dict__``.

    Produced and consumed by ``GuardBuilder.getattr_on_nn_module``.
    """

    # True when the attribute was found in the module's __dict__, either
    # directly or inside one of the containers named by l1_key.
    present_in_generic_dict: bool = False

    # First-level key into the module's __dict__: the attribute name itself, or
    # one of "_parameters", "_buffers", "_modules".
    l1_key: Optional[str] = None

    # Second-level key into the container selected by l1_key (the attribute
    # name); None when the attribute sits directly in __dict__.
    l2_key: Optional[str] = None
|
|
|
|
|
def getitem_on_dict_manager(
    source, base_guard_manager, base_example_value, example_value, guard_manager_enum
):
    """
    Return the value manager for one dict entry under a key-order-aware manager.

    The entry index comes from ``source.index`` directly when it is a
    ``ConstDictKeySource``; otherwise it is found by locating the key in
    ``base_example_value``, and an equality guard on the key at that index is
    installed as well.
    """
    base_source_name = source.base.name()
    index_is_const_key = isinstance(source.index, ConstDictKeySource)

    if index_is_const_key:
        index = source.index.index
    else:
        assert isinstance(base_example_value, dict)
        index = get_key_index(base_example_value, source.index)

    key_source = get_key_index_source(base_source_name, index)

    # Enumerate through builtin_dict_keys rather than base_example_value.keys().
    key_example_value = list(builtin_dict_keys(base_example_value))[index]

    # int/str keys are spelled literally in the value source; any other key is
    # referred to through its positional key expression.
    if isinstance(key_example_value, (int, str)):
        subscript = repr(key_example_value)
    else:
        subscript = key_source
    value_source = f"{base_source_name}[{subscript}]"

    if not index_is_const_key:
        key_manager = base_guard_manager.get_key_manager(
            index=index,
            source=key_source,
            example_value=source.index,
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
        )
        key_manager.add_equals_match_guard(
            source.index, [f"{key_source} == {key_example_value!r}"]
        )

    return base_guard_manager.get_value_manager(
        index=index,
        source=value_source,
        example_value=example_value,
        guard_manager_enum=guard_manager_enum,
    )
|
|
|
|
|
def match_on_id_for_tensor(guard): |
|
source = guard.originating_source |
|
|
|
|
|
if isinstance(source, NumpyTensorSource): |
|
return False |
|
|
|
if guard.is_specialized_nn_module(): |
|
return True |
|
|
|
return source.is_dict_key() and not isinstance(source, GradSource) |
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class GuardCodeList:
    """A list of code strings bundled with a Guard."""

    # Code strings; presumably the expressions generated for `guard` - TODO confirm.
    code_list: list[str]

    # The Guard stored alongside the code strings.
    guard: Guard
|
|
|
|
|
class GuardManagerType(enum.Enum):
    """Kind of guard manager to request, passed around as ``guard_manager_enum``."""

    # Default kind, used wherever key order does not need guarding.
    GUARD_MANAGER = 1

    # Selected by GuardBuilder.get_guard_manager_type when the source requires
    # key-order guarding.
    DICT_GUARD_MANAGER = 2
|
|
|
|
|
@functools.lru_cache(None) |
|
def code_framelocals_names_reversed_cached(code: types.CodeType): |
|
return list(reversed(code_framelocals_names(code))) |
|
|
|
|
|
class GuardBuilder(GuardBuilderBase): |
|
def __init__(
    self,
    f_code: types.CodeType,
    id_ref: Callable[[Any, str], str],
    source_ref: Callable[[Source], str],
    lookup_weakrefs: Callable[[object], ReferenceType[object]],
    local_scope: dict[str, object],
    global_scope: dict[str, object],
    guard_manager: GuardManagerWrapper,
    check_fn_manager: CheckFunctionManager,
):
    """
    Store the builder's collaborators and set up its bookkeeping containers.

    The evaluation scope maps "L" to the local scope, "G" to the global scope
    and "__builtins__" to a copy of the builtins dict. Modules imported through
    torch.package are added under a sanitized name, both to the scope itself
    and to its builtins copy.
    """
    self.f_code = f_code
    self.id_ref = id_ref
    self.source_ref = source_ref
    self.lookup_weakrefs = lookup_weakrefs

    self.scope: dict[str, dict[str, object]] = {
        "L": local_scope,
        "G": global_scope,
        "__builtins__": builtins.__dict__.copy(),
    }
    packaged_modules = torch.package.package_importer._package_imported_modules
    for mangled_name, package_module in packaged_modules.items():
        # Replace the characters that cannot appear in an identifier.
        name = (
            mangled_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
        )
        self.scope["__builtins__"][name] = package_module
        self.scope[name] = package_module

    self.guard_manager = guard_manager

    self.argnames: list[str] = []
    self.code: list[GuardCodeList] = []
    self.shape_env_code: list[GuardCodeList] = []

    self.no_tensor_aliasing_names: list[str] = []
    self.no_tensor_aliasing_guard_managers: list[GuardManagerWrapper] = []

    self.check_fn_manager: CheckFunctionManager = check_fn_manager

    # ids of the objects whose sources the output graph marked for key-order
    # guarding; consulted by requires_key_order_guarding.
    self.key_order_guarded_dict_ids = set()
    key_order_sources = self.check_fn_manager.output_graph.guard_on_key_order
    self.key_order_guarded_dict_ids.update(
        id(self.get(source_name)) for source_name in key_order_sources
    )

    self.id_matched_objs: dict[str, ReferenceType[object]] = {}

    # Guard managers already built, keyed by source name.
    self._cached_guard_managers: dict[
        str, torch._C._dynamo.guards.GuardManager
    ] = {}
    self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
|
|
|
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
    """
    Install a ``dict_getitem_manager`` for every key of ``example_value``.

    Raises NotImplementedError when the guard's manager is a DictGuardManager.
    """
    mgr = self.get_guard_manager(guard)
    if isinstance(mgr, DictGuardManager):
        raise NotImplementedError(
            "Not expecting a DictGuardManager. Seems like Dynamo incorrectly "
            f"added the dict to tx.output.guard_on_key_order for {guard.name}"
        )

    dict_source = guard.originating_source.name()

    # Keys are enumerated through builtin_dict_keys, not example_value.keys().
    for key in builtin_dict_keys(example_value):
        item_value = example_value[key]
        item_source = DictGetItemSource(guard.originating_source, index=key)
        item_manager_enum = self.get_guard_manager_type(item_source, example_value)
        mgr.dict_getitem_manager(
            key=key,
            source=f"{dict_source}[{key!r}]",
            example_value=item_value,
            guard_manager_enum=item_manager_enum,
        )
|
|
|
def guard_on_dict_keys_and_order(self, value, guard):
    """
    Install one key manager per position of ``value`` and guard each key.

    Keys for which ``key_is_id`` holds get an id-match guard; all other keys get
    an equality guard. Raises NotImplementedError when the guard's manager is
    not a DictGuardManager.
    """
    dict_mgr = self.get_guard_manager(guard)
    if not isinstance(dict_mgr, DictGuardManager):
        raise NotImplementedError(
            "Expecting a DictGuardManager. Seems like Dynamo forgot "
            f"to set the right guard manager enum for {guard.name}"
        )
    assert isinstance(dict_mgr, DictGuardManager)

    # Keys are enumerated through builtin_dict_keys, not value.keys().
    for position, key in enumerate(builtin_dict_keys(value)):
        key_source = get_key_index_source(guard.name, position)
        key_manager = dict_mgr.get_key_manager(
            index=position,
            source=key_source,
            example_value=key,
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
        )

        if not key_is_id(key):
            equality_parts = get_verbose_code_parts(
                f"{key_source} == {key!r}", guard
            )
            key_manager.add_equals_match_guard(key, equality_parts)
            continue

        id_val = self.id_ref(key, key_source)
        id_parts = get_verbose_code_parts(
            f"__check_obj_id({key_source}, {id_val})", guard
        )
        key_manager.add_id_match_guard(id_val, id_parts)
|
|
|
@staticmethod |
|
def _get_generic_dict_manager_example_value(example_value): |
|
|
|
|
|
|
|
|
|
if ( |
|
config.issue_3_13_0_warning |
|
and sys.version_info >= (3, 13) |
|
and sys.version_info < (3, 13, 1) |
|
): |
|
warnings.warn( |
|
"Guards may run slower on Python 3.13.0. Consider upgrading to Python 3.13.1+.", |
|
RuntimeWarning, |
|
) |
|
return None |
|
return example_value |
|
|
|
def getattr_on_nn_module(
    self,
    source,
    base_guard_manager,
    base_example_value,
    example_value,
    base_source_name,
    source_name,
    guard_manager_enum,
):
    """
    Build the guard manager for an attribute of an nn module, preferring
    ``__dict__`` access over the costly custom nn module getattr.

    Two layouts are recognised:
    1) the attribute is stored directly in ``__dict__`` (e.g. ``training``);
    2) the attribute is a parameter, buffer or submodule, reachable through the
       ``_parameters`` / ``_buffers`` / ``_modules`` entries of ``__dict__``,
       e.g. ``mod.__dict__["_parameters"]["linear"]`` for ``mod.linear``.

    Anything else (descriptors, for instance) falls back to a plain getattr
    manager, i.e. PyObject_GetAttr.

    The typical expensive guard this targets has the shape
    ``mod.submod1.submod2.submod3.training``.
    """

    def getitem_on_dict_mgr(
        mgr, key, source_name, base_example_value, example_value, guard_manager_enum
    ):
        if not isinstance(mgr, DictGuardManager):
            return mgr.dict_getitem_manager(
                key=key,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )

        # Key-order-aware manager: address the entry by position, guard that
        # the key at that position is unchanged, then return the value manager.
        index = get_key_index(base_example_value, key)
        key_source = f"list(dict.keys({source_name}))[{index!r}]"
        key_mgr = mgr.get_key_manager(
            index=index,
            source=key_source,
            example_value=key,
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
        )
        key_mgr.add_equals_match_guard(key, [f"{key_source} == {key!r}"])

        return mgr.get_value_manager(
            index=index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )

    attr_name = source.member
    mod_dict = base_example_value.__dict__

    # Every name defined on any class in the module's MRO.
    all_class_attribute_names: set[str] = {
        name
        for klass in inspect.getmro(base_example_value.__class__)
        for name in klass.__dict__.keys()
    }

    def locate_attribute():
        if attr_name in mod_dict:
            return NNModuleAttrAccessorInfo(True, attr_name, None)
        for container in ("_parameters", "_buffers"):
            if container in mod_dict and attr_name in mod_dict[container]:
                return NNModuleAttrAccessorInfo(True, container, attr_name)
        # A submodule is only used when no class in the MRO defines the name.
        if (
            attr_name not in all_class_attribute_names
            and "_modules" in mod_dict
            and attr_name in mod_dict["_modules"]
        ):
            return NNModuleAttrAccessorInfo(True, "_modules", attr_name)
        return NNModuleAttrAccessorInfo(False, None, None)

    accessor_info = locate_attribute()

    if not accessor_info.present_in_generic_dict:
        # Not reachable via __dict__: use a regular getattr manager.
        return base_guard_manager.getattr_manager(
            attr=source.member,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )

    assert accessor_info.l1_key
    l1_key = accessor_info.l1_key
    l2_key = accessor_info.l2_key

    mod_dict_source = f"{base_source_name}.__dict__"

    if l2_key:
        # Two hops: __dict__[l1_key] is the container, [l2_key] the attribute.
        l1_source = AttrSource(source.base, l1_key)
        l1_source_name = l1_source.name()
        l1_value = mod_dict[l1_key]
        l1_guard_manager_enum = self.get_guard_manager_type(l1_source, l1_value)
        l2_guard_manager_enum = self.get_guard_manager_type(source, example_value)
    else:
        # One hop: __dict__[l1_key] is the attribute itself.
        l1_source_name = source_name
        l1_value = example_value
        l1_guard_manager_enum = self.get_guard_manager_type(source, example_value)
        l2_guard_manager_enum = None

    mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager(
        source=mod_dict_source,
        example_value=self._get_generic_dict_manager_example_value(mod_dict),
        guard_manager_enum=GuardManagerType.GUARD_MANAGER,
    )

    l1_mgr = getitem_on_dict_mgr(
        mgr=mod_generic_dict_manager,
        key=l1_key,
        source_name=l1_source_name,
        base_example_value=mod_dict,
        example_value=l1_value,
        guard_manager_enum=l1_guard_manager_enum,
    )
    if not l2_key:
        return l1_mgr

    return getitem_on_dict_mgr(
        mgr=l1_mgr,
        key=l2_key,
        source_name=source_name,
        base_example_value=l1_value,
        example_value=example_value,
        guard_manager_enum=l2_guard_manager_enum,
    )
|
|
|
def requires_key_order_guarding(self, source): |
|
source_name = source.name() |
|
if source_name == "": |
|
return False |
|
obj_id = id(self.get(source_name)) |
|
return obj_id in self.key_order_guarded_dict_ids |
|
|
|
def get_guard_manager_type(self, source, example_value):
    """
    Pick the guard manager kind for ``source``.

    Returns DICT_GUARD_MANAGER when the source requires key-order guarding,
    GUARD_MANAGER otherwise. In the key-order case, an ``example_value`` that is
    not a ``dict_keys`` instance is asserted to be a dict.
    """
    if not self.requires_key_order_guarding(source):
        return GuardManagerType.GUARD_MANAGER

    if not isinstance(example_value, dict_keys):
        assert isinstance(example_value, dict)
    return GuardManagerType.DICT_GUARD_MANAGER
|
|
|
def manager_guards_on_keys(self, mgr_enum):
    """Return True when ``mgr_enum`` is the DICT_GUARD_MANAGER kind."""
    key_guarding_kind = GuardManagerType.DICT_GUARD_MANAGER
    return mgr_enum == key_guarding_kind
|
|
|
def get_global_guard_manager(self):
    """Return the root's globals-dict manager for the "G" scope."""
    module_globals = self.scope["G"]
    root = self.guard_manager.root
    return root.globals_dict_manager(
        f_globals=module_globals,
        source="G",
        example_value=module_globals,
        guard_manager_enum=GuardManagerType.GUARD_MANAGER,
    )
|
|
|
def get_guard_manager_from_source(self, source): |
|
root_guard_manager = self.guard_manager.root |
|
|
|
example_value = None |
|
source_name = source.name() |
|
|
|
if source_name != "" and source_name in self._cached_guard_managers: |
|
return self._cached_guard_managers[source_name] |
|
|
|
if source_name != "": |
|
example_value = self.get(source_name) |
|
|
|
guard_manager_enum = self.get_guard_manager_type(source, example_value) |
|
|
|
|
|
base_source_name = None |
|
base_example_value = None |
|
base_guard_manager = None |
|
base_guard_manager_enum = GuardManagerType.GUARD_MANAGER |
|
if isinstance(source, ChainedSource): |
|
base_source_name = source.base.name() |
|
base_example_value = self.get(base_source_name) |
|
base_guard_manager = self.get_guard_manager_from_source(source.base) |
|
base_guard_manager_enum = self.get_guard_manager_type( |
|
source.base, base_example_value |
|
) |
|
|
|
|
|
if istype(source, LocalSource): |
|
|
|
|
|
|
|
|
|
|
|
|
|
if config.enable_cpp_framelocals_guard_eval: |
|
framelocals_names_reversed = code_framelocals_names_reversed_cached( |
|
self.f_code |
|
) |
|
framelocals_idx = ( |
|
len(framelocals_names_reversed) |
|
- framelocals_names_reversed.index(source.local_name) |
|
- 1 |
|
) |
|
out = root_guard_manager.framelocals_manager( |
|
key=(source.local_name, framelocals_idx), |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
else: |
|
out = root_guard_manager.dict_getitem_manager( |
|
key=source.local_name, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, GlobalSource): |
|
|
|
|
|
|
|
out = self.get_global_guard_manager().dict_getitem_manager( |
|
key=source.global_name, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, GlobalWeakRefSource): |
|
out = self.get_global_guard_manager().global_weakref_manager( |
|
global_name=source.global_name, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, GlobalStateSource): |
|
|
|
|
|
return root_guard_manager |
|
elif istype(source, ShapeEnvSource): |
|
return root_guard_manager |
|
elif istype(source, TypeSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.type_manager( |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype( |
|
source, |
|
( |
|
OptimizerSource, |
|
NNModuleSource, |
|
UnspecializedNNModuleSource, |
|
UnspecializedBuiltinNNModuleSource, |
|
FSDPNNModuleSource, |
|
), |
|
): |
|
assert base_guard_manager |
|
out = base_guard_manager |
|
elif istype(source, TorchFunctionModeStackSource): |
|
out = root_guard_manager.lambda_manager( |
|
python_lambda=lambda _: get_torch_function_mode_stack_at( |
|
source._get_index() |
|
), |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, GradSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.grad_manager( |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, GenericAttrSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.generic_getattr_manager( |
|
attr=source.member, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, (AttrSource, UnspecializedParamBufferSource)): |
|
assert base_guard_manager |
|
|
|
if ( |
|
isinstance(base_example_value, torch.nn.Module) |
|
and get_custom_getattr(base_example_value) |
|
is unpatched_nn_module_getattr |
|
): |
|
out = self.getattr_on_nn_module( |
|
source, |
|
base_guard_manager, |
|
base_example_value, |
|
example_value, |
|
base_source_name, |
|
source_name, |
|
guard_manager_enum, |
|
) |
|
else: |
|
out = base_guard_manager.getattr_manager( |
|
attr=source.member, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, DictGetItemSource): |
|
assert base_guard_manager |
|
assert isinstance(base_example_value, (dict, collections.OrderedDict)) |
|
if isinstance(base_guard_manager, DictGuardManager): |
|
assert self.manager_guards_on_keys(base_guard_manager_enum) |
|
out = getitem_on_dict_manager( |
|
source, |
|
base_guard_manager, |
|
base_example_value, |
|
example_value, |
|
guard_manager_enum, |
|
) |
|
else: |
|
if isinstance(source.index, ConstDictKeySource): |
|
raise RuntimeError( |
|
"Expecting clean index here. Likely Dynamo forgot to mark" |
|
" a dict as guard_on_key_order" |
|
) |
|
out = base_guard_manager.dict_getitem_manager( |
|
key=source.index, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, TensorPropertySource): |
|
out = getattr( |
|
base_guard_manager, |
|
f"tensor_property_{source.prop.name.lower()}_manager", |
|
)( |
|
idx=source.idx, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, IndexedSource): |
|
assert base_guard_manager |
|
|
|
out = base_guard_manager.indexed_manager( |
|
idx=source.idx, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, ListGetItemSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.list_getitem_manager( |
|
key=source.index, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, GetItemSource): |
|
assert base_guard_manager |
|
assert not isinstance( |
|
base_example_value, (dict, collections.OrderedDict) |
|
), "Use DictGetItemSource" |
|
if isinstance(base_example_value, list) and not source.index_is_slice: |
|
out = base_guard_manager.list_getitem_manager( |
|
key=source.index, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif isinstance(base_example_value, tuple) and not source.index_is_slice: |
|
out = base_guard_manager.tuple_getitem_manager( |
|
key=source.index, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
else: |
|
index = source.index |
|
if source.index_is_slice: |
|
index = source.unpack_slice() |
|
out = base_guard_manager.getitem_manager( |
|
key=index, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, DefaultsSource): |
|
assert base_guard_manager |
|
assert callable(base_example_value) |
|
if not source.is_kw: |
|
out = base_guard_manager.func_defaults_manager( |
|
source=base_source_name, |
|
example_value=base_example_value.__defaults__, |
|
guard_manager_enum=GuardManagerType.GUARD_MANAGER, |
|
).getitem_manager( |
|
key=source.idx_key, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
else: |
|
|
|
kwdefaults = base_example_value.__kwdefaults__ |
|
assert base_source_name is not None |
|
kw_source = base_source_name + ".__kwdefaults__" |
|
|
|
|
|
dict_mgr = base_guard_manager.func_kwdefaults_manager( |
|
source=kw_source, |
|
example_value=kwdefaults, |
|
guard_manager_enum=GuardManagerType.GUARD_MANAGER, |
|
) |
|
assert not isinstance(dict_mgr, DictGuardManager) |
|
|
|
out = dict_mgr.dict_getitem_manager( |
|
key=source.idx_key, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, NumpyTensorSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.lambda_manager( |
|
python_lambda=from_numpy, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, SubclassAttrListSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.lambda_manager( |
|
python_lambda=lambda x: x.__tensor_flatten__()[0], |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, FlattenScriptObjectSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.lambda_manager( |
|
python_lambda=lambda x: x.__obj_flatten__(), |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, ScriptObjectQualifiedNameSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.lambda_manager( |
|
python_lambda=lambda x: x._type().qualified_name(), |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, AttrProxySource): |
|
assert base_guard_manager |
|
out = base_guard_manager.lambda_manager( |
|
python_lambda=lambda x: x.get_base(), |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, CallMethodItemSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.lambda_manager( |
|
python_lambda=lambda x: x.item(), |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, FloatTensorSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.lambda_manager( |
|
python_lambda=lambda x: torch._as_tensor_fullprec(x), |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, TupleIteratorGetItemSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.tuple_iterator_getitem_manager( |
|
index=source.index, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif isinstance(source, ConstDictKeySource): |
|
if not isinstance(base_guard_manager, DictGuardManager): |
|
raise AssertionError( |
|
"ConstDictKeySource can only work on DictGuardManager" |
|
) |
|
out = base_guard_manager.get_key_manager( |
|
index=source.index, |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, WeakRefCallSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.weakref_call_manager( |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
elif istype(source, CallFunctionNoArgsSource): |
|
assert base_guard_manager |
|
out = base_guard_manager.call_function_no_args_manager( |
|
source=source_name, |
|
example_value=example_value, |
|
guard_manager_enum=guard_manager_enum, |
|
) |
|
else: |
|
raise AssertionError( |
|
f"missing guard manager builder {source} - {source.name()}" |
|
) |
|
|
|
self._cached_guard_managers[source.name()] = out |
|
return out |
|
|
|
def get_guard_manager(self, guard: Guard):
    """Return the guard manager for the source that ``guard`` originates from.

    Thin convenience wrapper: forwards ``guard.originating_source`` to
    ``get_guard_manager_from_source`` and returns its result unchanged.
    """
    origin = guard.originating_source
    return self.get_guard_manager_from_source(origin)
|
|
|
def add_python_lambda_leaf_guard_to_root(
    self,
    code_parts,
    verbose_code_parts,
    closure_vars=None,
    is_epilogue=True,
):
    """Compile ``code_parts`` into a Python guard function and attach it to the root.

    Args:
        code_parts: guard expressions handed to ``build_guard_function``.
        verbose_code_parts: human-readable parts passed along when the
            guard is installed.
        closure_vars: mapping of names made available to the generated
            function; defaults to ``_get_closure_vars()``.
        is_epilogue: when true the guard is installed with
            ``add_epilogue_lambda_guard``, otherwise with ``add_lambda_guard``.
    """
    if closure_vars is None:
        closure_vars = _get_closure_vars()

    # The generated factory takes one parameter per closure variable, in the
    # mapping's iteration order; the values are passed in that same order below.
    factory_params = ", ".join(closure_vars.keys())
    _unused_body, pycode = build_guard_function(code_parts, factory_params)

    exec_locals: dict[str, Any] = {}
    # Only ``G`` from the builder scope is exposed as a global of the guard.
    exec_globals = {"G": self.scope["G"]}
    guards_log.debug("Python shape guard function:\n%s", pycode)
    exec(pycode, exec_globals, exec_locals)
    guard_fn = exec_locals["___make_guard_fn"](*closure_vars.values())

    root = self.guard_manager.root
    install = root.add_epilogue_lambda_guard if is_epilogue else root.add_lambda_guard
    install(guard_fn, verbose_code_parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get(self, name: str, closure_vars: Optional[dict[str, Any]] = None) -> Any:
    """Evaluate the expression ``name`` and return the resulting object.

    ``self.scope`` is used as the globals of the evaluation and
    ``closure_vars`` (or ``_get_closure_vars()`` when omitted) as the locals.
    """
    local_ns = _get_closure_vars() if closure_vars is None else closure_vars
    return eval(name, self.scope, local_ns)
|
|
|
|
|
|
|
|
|
|
|
|
|
def arg_ref(self, guard: Union[str, Guard]) -> str:
    """Return the name to use for ``guard`` in guard code.

    ``guard`` may be a plain name or a ``Guard`` (whose ``name`` is used).
    As a side effect, the base of the name (after ``strip_function_call``)
    is appended to ``self.argnames`` if it is not there yet and
    ``is_valid_var_name`` reports a truthy result; a result of ``2``
    additionally logs a warning.
    """
    name: str = guard if isinstance(guard, str) else guard.name
    base = strip_function_call(name)
    if base in self.argnames:
        return name

    validity = torch._C._dynamo.is_valid_var_name(base)
    if validity:
        if validity == 2:
            log.warning("invalid var name: %s", guard)
        self.argnames.append(base)

    return name
|
|
|
def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn):
    """Create a guard of kind ``guard_fn`` on attribute ``attr_name`` of ``guard``'s source.

    The derived guard inherits ``stack`` and ``user_stack`` from ``guard``
    and is created against this builder immediately.
    """
    Guard(
        AttrSource(guard.originating_source, attr_name),
        guard_fn,
        stack=guard.stack,
        user_stack=guard.user_stack,
    ).create(self)
|
|
|
|
|
def HASATTR(self, guard: Guard):
    """Guard on whether the base object has (or lacks) the guarded attribute.

    The originating source must be an ``AttrSource`` (optionally wrapped in
    an ``NNModuleSource``). Whether the attribute currently exists decides
    which guard is installed:

    * present: a getattr manager is created on the base manager (through
      ``getattr_on_nn_module`` when the base is an ``nn.Module`` whose
      custom getattr is ``unpatched_nn_module_getattr``);
    * absent: a no-hasattr guard is added to the base manager.
    """
    source = guard.originating_source
    if isinstance(source, NNModuleSource):
        source = source.base
    assert isinstance(source, AttrSource), f"invalid source {guard.name}"

    base_source = source.base
    base = base_source.name()
    attr = source.member

    ref = self.arg_ref(base)
    val = hasattr(self.get(base), attr)
    code = f"hasattr({ref}, {attr!r})" if val else f"not hasattr({ref}, {attr!r})"
    # The export info is recorded against the base object, not the attribute.
    self._set_guard_export_info(
        guard, [code], provided_guarded_object=self.get(base)
    )

    base_manager = self.get_guard_manager_from_source(base_source)
    if not val:
        base_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard))
        return

    example_value = self.get(source.name())
    base_example_value = self.get(base)
    guard_manager_enum = self.get_guard_manager_type(source, example_value)

    uses_default_module_getattr = (
        isinstance(base_example_value, torch.nn.Module)
        and get_custom_getattr(base_example_value) is unpatched_nn_module_getattr
    )
    if uses_default_module_getattr:
        return self.getattr_on_nn_module(
            source,
            base_manager,
            base_example_value,
            example_value,
            base,
            source.name(),
            guard_manager_enum,
        )

    # No leaf guard is added here: only the getattr manager is created.
    base_manager.getattr_manager(
        attr=attr,
        source=guard.name,
        example_value=example_value,
        guard_manager_enum=guard_manager_enum,
    )
|
|
|
def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
    """Guard that ``attr`` is not a key of the guarded module's ``__dict__``.

    ``attr`` is required despite its default; the guarded value must be an
    ``nn.Module``.
    """
    assert attr is not None
    ref = self.arg_ref(guard)
    module = self.get(guard.name)
    assert isinstance(module, torch.nn.Module)

    base_manager = self.get_guard_manager(guard)

    dict_manager = base_manager.get_generic_dict_manager(
        source=f"{guard.name}.__dict__",
        example_value=self._get_generic_dict_manager_example_value(module.__dict__),
        guard_manager_enum=GuardManagerType.GUARD_MANAGER,
    )

    code = f"not ___dict_contains({attr!r}, {ref}.__dict__)"
    # ``False`` selects the "must not contain" flavour of the guard.
    dict_manager.add_dict_contains_guard(
        False, attr, get_verbose_code_parts(code, guard)
    )
|
|
|
def TYPE_MATCH(self, guard: Guard) -> None:
    """Guard that the guarded value's type is the one seen now.

    The type is registered through ``id_ref`` and the resulting id is what
    the installed type-match guard compares against.
    """
    value_type = type(self.get(guard.name))
    type_id = self.id_ref(value_type, f"type({guard.name})")
    code = f"___check_type_id({self.arg_ref(guard)}, {type_id})"
    self._set_guard_export_info(guard, [code])

    manager = self.get_guard_manager(guard)
    manager.add_type_match_guard(type_id, get_verbose_code_parts(code, guard))
|
|
|
def DICT_VERSION(self, guard: Guard):
    """Guard on the version of the guarded dict as reported by ``dict_version``.

    The guarded object is evaluated once and that same object is used both
    to read the version embedded in the guard code and as the argument to
    the installed guard, so the two cannot refer to different objects.
    """
    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    version = dict_version(val)
    code = f"___dict_version({ref}) == {version}"
    self._set_guard_export_info(guard, [code])

    self.get_guard_manager(guard).add_dict_version_guard(
        val, get_verbose_code_parts(code, guard)
    )
|
|
|
def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool):
    """Guard that the guarded dict contains ``key`` (or does not, when ``invert``)."""
    dict_ref = self.arg_ref(guard)

    prefix = "not " if invert else ""
    code = f"{prefix}___dict_contains({key!r}, {dict_ref})"
    self._set_guard_export_info(guard, [code])

    should_contain = not invert
    self.get_guard_manager(guard).add_dict_contains_guard(
        should_contain, key, get_verbose_code_parts(code, guard)
    )
|
|
|
def ID_MATCH(self, guard: Guard):
    """Guard on the identity of the guarded object.

    An id match on a ``TypeSource`` is delegated to ``TYPE_MATCH`` on the
    source's base. Otherwise an id-match guard is installed; additionally,
    for an ``nn.Module`` held in a local, the weak id returned by
    ``lookup_weakrefs`` (if any) is recorded in ``self.id_matched_objs``
    under the local's name.
    """
    origin = guard.originating_source
    if isinstance(origin, TypeSource):
        return self.TYPE_MATCH(Guard(origin.base, GuardBuilder.TYPE_MATCH))

    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    id_val = self.id_ref(val, guard.name)
    code = f"___check_obj_id({ref}, {id_val})"
    self._set_guard_export_info(guard, [code])

    self.get_guard_manager(guard).add_id_match_guard(
        id_val, get_verbose_code_parts(code, guard)
    )

    origin = guard.originating_source
    if isinstance(origin, LocalSource) and isinstance(val, torch.nn.Module):
        local_name = origin.local_name
        weak_id = self.lookup_weakrefs(val)
        if weak_id is not None:
            self.id_matched_objs[local_name] = weak_id
|
|
|
def NOT_NONE_MATCH(self, guard: Guard, value=None):
    """Guard that the guarded object is not ``None``.

    The current value must be a ``torch.Tensor``. ``value`` is accepted but
    not used by this method.
    """
    ref = self.arg_ref(guard)
    current = self.get(guard.name)
    assert isinstance(current, torch.Tensor)

    code = f"{ref} is not None"
    self._set_guard_export_info(guard, [code])

    manager = self.get_guard_manager(guard)
    manager.add_not_none_guard(get_verbose_code_parts(code, guard))
|
|
|
def DISPATCH_KEY_SET_MATCH(self, guard: Guard):
    """Guard that the guarded ``DispatchKeySet`` equals the one seen now.

    Unlike most sibling guards, no export info is recorded here.
    """
    ref = self.arg_ref(guard)
    key_set = self.get(guard.name)
    assert isinstance(key_set, torch._C.DispatchKeySet)

    description = f"{ref}.raw_repr() == {key_set!r}.raw_repr()"
    manager = self.get_guard_manager(guard)
    manager.add_dispatch_key_set_guard(
        key_set, get_verbose_code_parts(description, guard)
    )
|
|
|
def NAME_MATCH(self, guard: Guard):
    """Guard on the guarded object's ``__name__`` with an ``EQUALS_MATCH``."""
    self._guard_on_attribute(
        guard,
        "__name__",
        GuardBuilder.EQUALS_MATCH,
    )
|
|
|
def DATA_PTR_MATCH(self, guard: Guard):
    """Guard that the guarded object's ``data_ptr()`` equals the current one."""
    obj = self.get(guard.name)
    ref = self.arg_ref(guard)
    code = f"{ref}.data_ptr() == {obj.data_ptr()}"
    self._set_guard_export_info(guard, [code])

    manager = self.get_guard_manager(guard)
    manager.add_data_ptr_guard(obj, get_verbose_code_parts(code, guard))
|
|
|
def DUAL_LEVEL(self, guard: Guard):
    """Guard that the forward-AD dual level is the one observed now.

    The level is read from ``torch.autograd.forward_ad._current_level`` when
    the guard is built and compared against the live value by a Python
    lambda guard installed on the root manager.
    """
    dual_level = torch.autograd.forward_ad._current_level
    code = [f"torch.autograd.forward_ad._current_level == {dual_level}"]
    # ``code`` is already a list of code strings; pass it as-is (wrapping it
    # again would record a nested list as the guard's code list).
    self._set_guard_export_info(guard, code)

    forward_ad = torch.autograd.forward_ad

    def fn(x):
        # ``x`` is the value the root manager passes in; it is not needed.
        return forward_ad._current_level == dual_level

    self.guard_manager.root.add_lambda_guard(
        fn, get_verbose_code_parts(code, guard)
    )
|
|
|
def FUNCTORCH_STACK_MATCH(self, guard: Guard):
    """Guard that the functorch interpreter states match those seen now.

    The state of every interpreter returned by
    ``retrieve_all_functorch_interpreters`` is captured at build time and
    later checked with ``compare_functorch_state`` by a lambda guard on the
    root manager.
    """
    pyfunctorch = torch._functorch.pyfunctorch
    interpreters = pyfunctorch.retrieve_all_functorch_interpreters()
    states = [interp.get_state() for interp in interpreters]
    code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
    self._set_guard_export_info(guard, code)

    compare_fn = torch._functorch.pyfunctorch.compare_functorch_state

    def fn(x):
        # ``x`` is the value the root manager passes in; it is not needed.
        return compare_fn(states)

    root = self.guard_manager.root
    root.add_lambda_guard(fn, get_verbose_code_parts(code, guard))
|
|
|
def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard):
    """Guard on the metadata returned by ``__tensor_flatten__()[1]``.

    A deep copy of the current metadata is kept. If the value defines
    ``__metadata_guard__`` (after ``verify_guard_fn_signature``), that hook
    decides whether new metadata is acceptable; otherwise the metadata is
    compared for equality.
    """
    value = self.get(guard.name)
    original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1])

    def _check_with_hook(x):
        return value.__metadata_guard__(
            original_metadata, x.__tensor_flatten__()[1]
        )

    def _check_by_equality(x):
        return x.__tensor_flatten__()[1] == original_metadata

    if hasattr(value, "__metadata_guard__"):
        verify_guard_fn_signature(value)
        metadata_checker = _check_with_hook
    else:
        metadata_checker = _check_by_equality

    # The name only serves as the verbose description of the lambda guard.
    global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}"
    manager = self.get_guard_manager(guard)
    manager.add_lambda_guard(
        metadata_checker, get_verbose_code_parts(global_name, guard)
    )
|
|
|
def EQUALS_MATCH(self, guard: Guard):
    """Guard that the guarded value compares equal to the value seen now.

    Only an allow-list of types (or pytree constant classes) is accepted.
    NaN values are special-cased because ``nan == nan`` is false: they get a
    ``TYPE_MATCH`` plus an is-NaN lambda guard instead of an equality guard.
    Mutable values (``list``/``set``) are deep-copied before being stored in
    the guard.
    """
    ref = self.arg_ref(guard)
    val = self.get(guard.name)

    np_types: tuple[type[Any], ...] = (
        (
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.uint8,
            np.uint16,
            np.uint32,
            np.uint64,
            np.float16,
            np.float32,
            np.float64,
        )
        if np
        else ()
    )

    ok_mutable_types = (list, set)

    extra_types = {
        type,
        tuple,
        frozenset,
        slice,
        range,
        dict_keys,
        torch.Size,
        *np_types,
        *ok_mutable_types,
    }
    ok_types = tuple(common_constant_types | extra_types)

    if torch.distributed.is_available():
        from torch.distributed.device_mesh import DeviceMesh
        from torch.distributed.tensor.placement_types import (
            Partial,
            Replicate,
            Shard,
        )

        ok_types = ok_types + (
            Shard,
            Replicate,
            Partial,
            DeviceMesh,
        )

    import torch.utils._pytree as pytree

    assert istype(val, ok_types) or pytree.is_constant_class(type(val)), (
        f"Unexpected type {type(val)}"
    )

    # Float NaN: equality can never succeed, so guard on type + math isnan.
    if istype(val, float) and math.isnan(val):
        self.TYPE_MATCH(guard)
        code = [f"__math_isnan({ref})"]
        self._set_guard_export_info(guard, code)

        self.get_guard_manager(guard).add_lambda_guard(
            _get_closure_vars()["__math_isnan"],
            get_verbose_code_parts(code, guard),
        )
        return

    # Complex NaN: same idea, using the numpy predicate.
    if istype(val, complex) and np.isnan(val):
        self.TYPE_MATCH(guard)
        code = [f"__numpy_isnan({ref})"]
        self._set_guard_export_info(guard, code)

        self.get_guard_manager(guard).add_lambda_guard(
            _get_closure_vars()["__numpy_isnan"],
            get_verbose_code_parts(code, guard),
        )
        return

    # The code string is rendered from the value before any copy is taken.
    code = [f"{ref} == {val!r}"]
    if istype(val, ok_mutable_types):
        # Snapshot mutable values so later in-place changes to the original
        # object do not alter what the guard compares against.
        val = deepcopy(val)
    self.get_guard_manager(guard).add_equals_match_guard(
        val, get_verbose_code_parts(code, guard)
    )
    self._set_guard_export_info(guard, code)
    return
|
|
|
def CONSTANT_MATCH(self, guard: Guard):
    """Guard on a constant: by identity for ``bool``/``None``/code objects, else by equality."""
    val = self.get(guard.name)
    identity_types = (bool, type(None), types.CodeType)
    if not istype(val, identity_types):
        self.EQUALS_MATCH(guard)
    else:
        self.ID_MATCH(guard)
|
|
|
def NN_MODULE(self, guard: Guard):
    """Guard on an ``nn.Module``: identity plus its ``training`` flag.

    A module without a ``training`` attribute is reported through
    ``exc.unimplemented_v2`` as an uninitialized module.
    """
    self.ID_MATCH(guard)
    module = self.get(guard.name)
    if not hasattr(module, "training"):
        exc.unimplemented_v2(
            gb_type="Attempted to guard on uninitialized nn.Module",
            context="",
            explanation="Attempted to setup an NN_MODULE guard on uninitialized "
            f"nn.Module subclass `{type(module)}`.",
            hints=[
                "Ensure the `nn.Module` subclass instance has called `super().__init__()`.",
            ],
        )
        return

    assert istype(module.training, bool)
    self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH)
|
|
|
def FUNCTION_MATCH(self, guard: Guard):
    """Guard on a function object (e.g. ``torch.add`` or a user function) by identity."""
    result = self.ID_MATCH(guard)
    return result
|
|
|
def CLOSURE_MATCH(self, guard: Guard):
    """Guard on a closure through its ``__code__`` object.

    For plain Python functions, ``__code__`` gets a ``HASATTR`` guard
    followed by a ``FUNCTION_MATCH`` guard; anything else falls back to
    ``FUNCTION_MATCH`` on the object itself.
    """
    val = self.get(guard.name)

    is_plain_function = type(val) == types.FunctionType and hasattr(val, "__code__")
    if not is_plain_function:
        self.FUNCTION_MATCH(guard)
        return

    for guard_kind in (GuardBuilder.HASATTR, GuardBuilder.FUNCTION_MATCH):
        self._guard_on_attribute(guard, "__code__", guard_kind)
|
|
|
def BUILTIN_MATCH(self, guard: Guard):
    """Guard on a builtin; handled exactly like ``FUNCTION_MATCH``."""
    result = self.FUNCTION_MATCH(guard)
    return result
|
|
|
def PYMODULE_MATCH(self, guard: Guard):
    """Guard on a Python module; handled exactly like ``FUNCTION_MATCH``."""
    result = self.FUNCTION_MATCH(guard)
    return result
|
|
|
def SEQUENCE_LENGTH(self, guard):
    """Guard on the length of the guarded container.

    Non-dict values additionally get a ``TYPE_MATCH``. Dicts use the
    dict-specific length guard, everything else the generic length guard.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)
    is_dict = isinstance(value, dict)

    if not is_dict:
        self.TYPE_MATCH(guard)

    length = len(value)
    # An empty container is expressed through truthiness in the code string.
    code = [f"not {ref}"] if length == 0 else [f"len({ref}) == {length}"]

    self._set_guard_export_info(guard, code)
    manager = self.get_guard_manager(guard)
    if is_dict:
        manager.add_dict_length_check_guard(
            length, get_verbose_code_parts(code, guard)
        )
    else:
        manager.add_length_check_guard(
            length, get_verbose_code_parts(code, guard)
        )
|
|
|
def TUPLE_ITERATOR_LEN(self, guard):
    """Guard on the remaining length (and type) of a tuple iterator.

    The length is computed once with ``tuple_iterator_len`` and used both
    in the recorded code string and in the installed guard; the iterator's
    type is registered through ``id_ref`` and passed to the guard as well.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)
    length = tuple_iterator_len(value)

    code = [f"___tuple_iterator_len({ref}) == {length}"]
    self._set_guard_export_info(guard, code)

    obj_id = self.id_ref(type(value), f"type({guard.name})")

    self.get_guard_manager(guard).add_tuple_iterator_length_guard(
        length, obj_id, get_verbose_code_parts(code, guard)
    )
|
|
|
def RANGE_ITERATOR_MATCH(self, guard):
    """Guard on the (start, stop, step) and type of a range iterator.

    ``normalize_range_iter`` supplies the triple that is both rendered into
    the recorded code string and unpacked into the installed guard; the
    iterator's type is registered through ``id_ref``.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    normalized_range_iter = normalize_range_iter(value)
    code = [f"___normalize_range_iter({ref}) == {normalized_range_iter}"]
    self._set_guard_export_info(guard, code)

    obj_id = self.id_ref(type(value), f"type({guard.name})")

    start, stop, step = normalized_range_iter
    self.get_guard_manager(guard).add_range_iterator_match_guard(
        start, stop, step, obj_id, get_verbose_code_parts(code, guard)
    )
|
|
|
|
|
def DUPLICATE_INPUT(self, guard, source_b):
    """Guard that ``guard``'s object and ``source_b``'s object are the same object.

    Nothing is installed when either side comes from an optimizer source,
    or when the pair (in either order) was already handled.
    """
    ref_a = self.arg_ref(guard)
    ref_b = self.arg_ref(source_b.name())

    involves_optimizer = is_from_optimizer_source(
        guard.originating_source
    ) or is_from_optimizer_source(source_b)
    if involves_optimizer:
        return

    seen = self._cached_duplicate_input_guards
    if (ref_a, ref_b) in seen:
        return

    # Record both orderings so the mirrored request is deduplicated too.
    for pair in ((ref_a, ref_b), (ref_b, ref_a)):
        self._cached_duplicate_input_guards.add(pair)

    code = [f"{ref_b} is {ref_a}"]
    self._set_guard_export_info(guard, code)

    install_object_aliasing_guard(
        self.get_guard_manager(guard),
        self.get_guard_manager_from_source(source_b),
        get_verbose_code_parts(code, guard),
    )
|
|
|
def WEAKREF_ALIVE(self, guard):
    """Guard that the guarded reference is not ``None`` (installed as a not-None guard)."""
    ref = self.arg_ref(guard)
    code = [f"{ref} is not None"]

    self._set_guard_export_info(guard, code)
    manager = self.get_guard_manager(guard)
    manager.add_not_none_guard(get_verbose_code_parts(code, guard))
|
|
|
def MAPPING_KEYS_CHECK(self, guard):
    """Guard on the keys (and their order) of a ``types.MappingProxyType`` object.

    The recorded code compares ``list(<ref>.keys())`` with the list of keys
    seen now. The guard is installed with verbose code parts derived from
    ``guard`` like the other guards of this builder.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    code = [f"list({ref}.keys()) == {list(value.keys())}"]
    self._set_guard_export_info(guard, code)
    self.get_guard_manager(guard).add_mapping_keys_guard(
        value, get_verbose_code_parts(code, guard)
    )
|
|
|
def DICT_KEYS_MATCH(self, guard):
    """Guard that the keys of the guarded dict stay the same.

    ``torch.utils._pytree.SUPPORTED_NODES`` is special-cased with a
    ``DICT_VERSION`` guard. Otherwise a length guard is installed first,
    followed by a key guard that is order-sensitive only when
    ``requires_key_order_guarding`` says so for the originating source.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    if value is torch.utils._pytree.SUPPORTED_NODES:
        self.DICT_VERSION(guard)
        return

    self.SEQUENCE_LENGTH(guard)

    # Keys are listed through the builtin dict machinery rather than through
    # any ``keys`` override on the object.
    code = [f"list(dict.keys({ref})) == {list(builtin_dict_keys(value))!r}"]
    self._set_guard_export_info(guard, code)

    if not self.requires_key_order_guarding(guard.originating_source):
        self.guard_on_dict_keys_and_ignore_order(value, guard)
    else:
        self.guard_on_dict_keys_and_order(value, guard)
|
|
|
def EMPTY_NN_MODULE_HOOKS_DICT(self, guard):
    """Length-guard a hooks dict unless ``config.skip_nnmodule_hook_guards`` is set."""
    if not config.skip_nnmodule_hook_guards:
        self.SEQUENCE_LENGTH(guard)
|
|
|
def OBJECT_MUTATION(self, guard: Guard):
    """Register the guarded object with ``mutation_guard.watch`` for this check-fn manager."""
    watched = self.get(guard.name)
    mutation_guard.watch(watched, self.check_fn_manager)
|
|
|
def GRAD_MODE(self, guard: Guard):
    """Intentionally install nothing for this guard kind.

    NOTE(review): presumably grad mode is checked by a mechanism outside
    this method - confirm.
    """
    return None
|
|
|
def DETERMINISTIC_ALGORITHMS(self, guard: Guard):
    """Intentionally install nothing for this guard kind.

    NOTE(review): presumably this state is checked by a mechanism outside
    this method - confirm.
    """
    return None
|
|
|
def TORCH_FUNCTION_STATE(self, guard: Guard):
    """Intentionally install nothing for this guard kind.

    NOTE(review): presumably this state is checked by a mechanism outside
    this method - confirm.
    """
    return None
|
|
|
def FSDP_TRAINING_STATE(self, guard: Guard):
    """Intentionally install nothing for this guard kind.

    NOTE(review): presumably this state is checked by a mechanism outside
    this method - confirm.
    """
    return None
|
|
|
def DEFAULT_DEVICE(self, guard: Guard):
    """Guard on ``CURRENT_DEVICE`` of ``torch.utils._device``.

    Only valid for guards whose source is ``GuardSource.GLOBAL``.
    """
    assert guard.source is GuardSource.GLOBAL
    import torch.utils._device as device_module

    current = device_module.CURRENT_DEVICE
    code = [f"utils_device.CURRENT_DEVICE == {current!r}"]
    self._set_guard_export_info(guard, code)

    manager = self.get_guard_manager(guard)
    manager.add_default_device_guard(get_verbose_code_parts(code, guard))
|
|
|
def SHAPE_ENV(self, guard: Guard):
    """Install the symbolic-shape guards produced by the shape environment.

    Steps:

    1. Collect the tracked fake tensors and, if export constraints exist,
       turn them into an ``EqualityConstraint``.
    2. Ask the shape env for guard expressions (Python, verbose Python and,
       when ``config.enable_cpp_symbolic_shape_guards`` is set, C++).
    3. Freeze the shape env when not exporting, record export info and
       publish the verbose expressions on the current compile context.
    4. Try to install a compiled C++ guard; if that is not possible (a
       constant source, a non int/float symbol value, or no usable C++
       compiler) fall back to a Python lambda guard on the root manager.
    """
    assert guard.name == ""
    output_graph = self.check_fn_manager.output_graph

    tracked = output_graph.tracked_fakes
    input_contexts = [entry.symbolic_context for entry in tracked]

    def get_sources(t_id, dim):
        # One size-property source for every source referring to ``t_id``.
        return [
            TensorPropertySource(src, TensorProperty.SIZE, dim)
            for src in output_graph.tracked_fakes_id_to_source[t_id]
        ]

    equalities_inputs = None
    if output_graph.export_constraints:
        names: dict[str, tuple[int, int]] = {}
        source_pairs: list[tuple[Source, Source]] = []
        derived_equalities: list[
            tuple[Source, Union[Source, Symbol], Callable]
        ] = []
        phantom_symbols: dict[str, Symbol] = {}
        relaxed_sources: set[Source] = set()
        for constraint in output_graph.export_constraints:
            if constraint.t_id not in output_graph.tracked_fakes_id_to_source:
                log.warning("Untracked tensor used in export constraints")
                continue
            # Fills the accumulators above in place.
            torch.export.dynamic_shapes._process_equalities(
                constraint,
                get_sources,
                output_graph.shape_env,
                names,
                source_pairs,
                derived_equalities,
                phantom_symbols,
                relaxed_sources,
            )
        equalities_inputs = EqualityConstraint(
            source_pairs=source_pairs,
            derived_equalities=derived_equalities,
            phantom_symbols=list(phantom_symbols.values()),
            relaxed_sources=relaxed_sources,
            warn_only=False,
        )

    def _get_code_parts(langs):
        # Returns one result object per requested language, in order.
        return output_graph.shape_env.produce_guards_verbose(
            [entry.fake for entry in tracked],
            [entry.source for entry in tracked],
            input_contexts=input_contexts,
            equalities_inputs=equalities_inputs,
            source_ref=self.source_ref,
            ignore_static=(not self.check_fn_manager.output_graph.export),
            langs=langs,
        )

    if config.enable_cpp_symbolic_shape_guards:
        python_code_parts, verbose_code_parts, cpp_code_parts = _get_code_parts(
            ("python", "verbose_python", "cpp")
        )
    else:
        python_code_parts, verbose_code_parts = _get_code_parts(
            ("python", "verbose_python")
        )

    if not self.check_fn_manager.output_graph.export:
        output_graph.shape_env.freeze()

    for expr in python_code_parts.exprs:
        self._set_guard_export_info(guard, [expr])

    if compile_context := CompileContext.try_get():
        compile_context.shape_env_guards.extend(verbose_code_parts.exprs)

    if config.enable_cpp_symbolic_shape_guards:
        import ctypes

        from torch._inductor.codecache import CppCodeCache

        assert cpp_code_parts
        code_parts = cpp_code_parts.exprs
        source_to_symbol = cpp_code_parts.source_to_symbol

        if not code_parts:
            return

        int_source_to_symbol = []
        float_source_to_symbol = []

        # The C++ guard only receives int64/double inputs; anything else
        # forces the Python fallback below.
        python_fallback = False
        for source, symbol in source_to_symbol.items():
            if isinstance(source, ConstantSource):
                python_fallback = True
                continue
            example_value = self.get(
                source.name(),
                closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
            )
            if isinstance(example_value, int):
                int_source_to_symbol.append((source, symbol))
            elif isinstance(example_value, float):
                float_source_to_symbol.append((source, symbol))
            else:
                python_fallback = True

        if not python_fallback:
            # Ints first, then floats: this is the order the guard managers
            # are created in.
            source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol)
            try:
                guard_managers = [
                    self.get_guard_manager_from_source(IndexedSource(source, i))
                    for i, source in enumerate(source_to_symbol)
                ]

                int_symbols_str = ", ".join(
                    f"{symbol} = int_values[{i}]"
                    for i, (_, symbol) in enumerate(int_source_to_symbol)
                )
                float_symbols_str = ", ".join(
                    f"{symbol} = float_values[{i}]"
                    for i, (_, symbol) in enumerate(float_source_to_symbol)
                )

                if int_symbols_str:
                    int_symbols_str = f"int64_t {int_symbols_str};"
                if float_symbols_str:
                    float_symbols_str = f"double {float_symbols_str};"

                condition = ") && (".join(code_parts)
                func_str = textwrap.dedent(
                    f"""
                #include <cstdint>
                #include <cmath>
                #include <c10/util/generic_math.h>

                extern "C" int8_t guard(int64_t *int_values, double *float_values) {{
                  {int_symbols_str}
                  {float_symbols_str}
                  return ({condition});
                }}
                """
                )
                guards_log.debug(
                    "C++ shape guard function: %s %s",
                    func_str,
                    verbose_code_parts.exprs,
                )
                clib = CppCodeCache.load(func_str)
                cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value
                assert cguard
            except torch._inductor.exc.InvalidCxxCompiler:
                # Deliberate best-effort: without a C++ compiler, fall
                # through to the Python guard below.
                pass
            else:
                install_symbolic_shape_guard(
                    guard_managers,
                    len(int_source_to_symbol),
                    len(float_source_to_symbol),
                    cguard,
                    clib,
                    verbose_code_parts.exprs,
                )
                return

    # Python fallback (also the only path when C++ guards are disabled).
    if python_code_parts.exprs:
        self.add_python_lambda_leaf_guard_to_root(
            python_code_parts.exprs,
            verbose_code_parts.exprs,
            closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
        )
|
|
|
def TENSOR_MATCH(self, guard: Guard, value=None):
    """Guard on the properties of a tensor input.

    * FSDP module guards are skipped when
      ``config._unsafe_skip_fsdp_module_guards`` is set.
    * Tensors selected by ``match_on_id_for_tensor`` get an ``ID_MATCH``.
    * In export mode a ``TYPE_MATCH`` is installed and dtype / device /
      requires_grad / ndimension are recorded as code strings.
    * Otherwise a tensor-match guard with concrete size and stride is
      installed and the tensor may be registered for the no-aliasing guard.
    * Finally, for tensors that are not always static, a guard on
      ``_dynamo_dynamic_indices`` is added.

    Args:
        guard: the guard being built.
        value: optional tensor (or ``TensorWeakRef``) to use instead of
            evaluating ``guard.name``.
    """
    if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module():
        return

    if match_on_id_for_tensor(guard):
        self.ID_MATCH(guard)
        return

    if isinstance(value, TensorWeakRef):
        value = value()

    # A dead weak reference (or no value at all) falls back to the scope.
    if value is None:
        value = self.get(guard.name)
    assert isinstance(value, torch.Tensor)

    tensor_name = self.arg_ref(guard)

    code: list[str] = []
    if self.check_fn_manager.output_graph.export:
        self.TYPE_MATCH(guard)
        for term in ("dtype", "device", "requires_grad", "ndimension()"):
            real_value = self.get(tensor_name + "." + term)
            if istype(real_value, (torch.device, torch.dtype)):
                # Devices and dtypes are compared through their string form.
                code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}")
            else:
                code.append(f"{tensor_name}.{term} == {real_value}")
    else:
        guard_manager = self.get_guard_manager(guard)

        skip_aliasing_for_param = (
            config.skip_no_tensor_aliasing_guards_on_parameters
            and istype(value, torch.nn.Parameter)
        )
        if not skip_aliasing_for_param and not isinstance(
            guard.originating_source, NumpyTensorSource
        ):
            self.no_tensor_aliasing_names.append(tensor_name)
            self.no_tensor_aliasing_guard_managers.append(guard_manager)

        output_graph = self.check_fn_manager.output_graph
        metadata = output_graph.input_source_to_sizes_strides[
            guard.originating_source
        ]
        size = convert_to_concrete_values(metadata["size"])
        stride = convert_to_concrete_values(metadata["stride"])

        verbose_code_parts = get_verbose_code_parts(
            get_tensor_guard_code_part(value, tensor_name, size, stride),
            guard,
        )
        guard_manager.add_tensor_match_guard(
            value,
            size,
            stride,
            tensor_name,
            verbose_code_parts,
        )

        # Non-parameter tensors are added to the diff guard sources.
        if not isinstance(value, torch.nn.Parameter):
            self.check_fn_manager.guard_manager.diff_guard_sources.add(
                guard.name
            )

    assert guard.source is not None
    static, _reason = tensor_always_has_static_shape(
        value, is_tensor=True, tensor_source=guard.originating_source
    )

    if not static:
        if hasattr(value, "_dynamo_dynamic_indices"):
            # New dynamic indices must be a subset of the ones seen now.
            dynamic_indices = value._dynamo_dynamic_indices
            code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)"
            code.append(code_part)
            self.get_guard_manager(guard).add_dynamic_indices_guard(
                dynamic_indices, get_verbose_code_parts(code_part, guard)
            )
        else:
            # No dynamic indices now: the attribute must stay absent.
            code_part = (
                f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
            )
            code.append(code_part)
            self.get_guard_manager(guard).add_no_hasattr_guard(
                "_dynamo_dynamic_indices",
                get_verbose_code_parts(code_part, guard),
            )

    if len(code) > 0:
        self._set_guard_export_info(guard, code)
|
|
|
|
|
def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None):
    """Record on ``guard`` which builder method produced it and for which object.

    The name of the *calling* function is taken from the call stack, so
    this must be invoked directly from a method defined on the builder
    class itself (not from a nested helper or an inherited method).

    Args:
        guard: the guard to annotate through ``guard.set_export_info``.
        code_list: the guard's code strings.
        provided_guarded_object: object to record instead of evaluating
            ``guard.name``; when omitted and the name is empty, no object
            is recorded.
    """
    cur_frame = currentframe()
    assert cur_frame is not None
    caller = cur_frame.f_back
    del cur_frame
    assert caller is not None
    func_name = caller.f_code.co_name
    # Drop the frame reference promptly to avoid keeping the frame alive.
    del caller

    assert func_name in self.__class__.__dict__, (
        f"_set_guard_export_info must be called from inside a method of "
        f"{self.__class__.__name__}. Called from {func_name}"
    )

    if provided_guarded_object is None:
        name = guard.name
        guarded_object = None if not name else self.get(name)
    else:
        guarded_object = provided_guarded_object

    guarded_object_type = (
        weakref.ref(type(guarded_object)) if guarded_object is not None else None
    )
    obj_ref = None

    # Only classes with a non-zero ``__weakrefoffset__`` are treated as
    # weak-referenceable.
    supports_weakref = (
        getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
    )

    # Enum members and tuples are excluded even when their class qualifies.
    if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)):
        obj_ref = weakref.ref(guarded_object)

    guard.set_export_info(
        func_name,
        guarded_object_type,
        code_list,
        obj_ref,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PyExprCSEPass:
    """Common sub-expression elimination over Python expression strings.

    Usage is two-phase: ``count`` tallies how often each eligible
    sub-expression occurs across a list of expressions, then ``replace``
    rewrites one expression, hoisting every sub-expression whose count
    exceeds ``USE_THRESHOLD`` into a generated variable. Variable names and
    counts are shared across calls on the same instance.
    """

    # A sub-expression is hoisted only if it occurs more often than this.
    USE_THRESHOLD = 1

    # Only these AST node kinds are counted and considered for hoisting.
    ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript)

    @dataclasses.dataclass
    class Config:
        # Unparsed sub-expression -> number of occurrences seen by ``count``.
        expr_count: dict[str, int]
        # Unparsed sub-expression -> name of the variable that replaces it.
        expr_to_name: dict[str, str]

    class ExprCounter(ast.NodeVisitor):
        """Visitor that counts eligible sub-expressions into the shared config."""

        def __init__(self, config: PyExprCSEPass.Config) -> None:
            self._config = config

        def visit(self, node: ast.AST) -> Any:
            if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
                key = _ast_unparse(node)
                self._config.expr_count[key] += 1
            super().visit(node)

    class Replacer(ast.NodeTransformer):
        """Transformer that swaps repeated sub-expressions for variable names.

        Assignments defining newly introduced variables are collected in
        ``preface`` in the order they must be executed.
        """

        def __init__(
            self,
            config: PyExprCSEPass.Config,
            gen_name: Callable[[], str],
        ) -> None:
            super().__init__()
            self._config = config
            self._gen_name = gen_name
            self.preface: list[str] = []

        def visit(self, node: ast.AST) -> Any:
            if not isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
                return super().visit(node)

            expr = _ast_unparse(node)
            if self._config.expr_count[expr] <= PyExprCSEPass.USE_THRESHOLD:
                return super().visit(node)

            known_names = self._config.expr_to_name
            if expr in known_names:
                var_name = known_names[expr]
            else:
                # Rewrite the children first so the new variable is defined
                # in terms of already-hoisted inner sub-expressions, whose
                # definitions therefore precede it in ``preface``.
                rewritten = super().visit(node)
                rewritten_src = _ast_unparse(rewritten)
                var_name = self._gen_name()
                self.preface.append(f"{var_name} = {rewritten_src}")
                known_names[expr] = var_name
            return ast.Name(var_name, ast.Load())

    def __init__(self) -> None:
        self._counter = 0
        self._config = self.Config(
            expr_count=collections.defaultdict(int), expr_to_name={}
        )

    def _new_var(self, prefix: str = "_var") -> str:
        """Return a fresh variable name of the form ``<prefix><n>``."""
        index = self._counter
        self._counter = index + 1
        return f"{prefix}{index}"

    def count(self, exprs: list[str]) -> None:
        """Tally the eligible sub-expressions of every expression in ``exprs``.

        Raises:
            SyntaxError: re-raised (after logging) if an expression does
                not parse.
        """
        counter = self.ExprCounter(self._config)
        for source in exprs:
            try:
                counter.visit(ast.parse(source))
            except SyntaxError as ex:
                log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, source)
                raise

    def replace(self, expr: str) -> tuple[list[str], str]:
        """Rewrite ``expr``; return ``(new variable definitions, rewritten expression)``."""
        replacer = self.Replacer(self._config, self._new_var)
        rewritten = replacer.visit(ast.parse(expr))
        return replacer.preface, _ast_unparse(rewritten)
|
|
|
|
|
def must_add_nn_module_guards(guard):
    """Tell whether ``guard`` has to be kept for nn.Module handling.

    The result is truthy when the guard originates from a ``DefaultsSource``,
    or when ``config.guard_nn_modules_using_dict_tags`` is set and the guard
    was created by ``GuardBuilder.NN_MODULE``.
    """
    if isinstance(guard.originating_source, DefaultsSource):
        return True
    return (
        config.guard_nn_modules_using_dict_tags
        and guard.create_fn is GuardBuilder.NN_MODULE
    )
|
|
|
|
|
class DeletedGuardManagerWrapper(GuardManagerWrapper):
    """Guard manager wrapper that stands in for an invalidated one.

    It only remembers why the invalidation happened, in
    ``invalidation_reason``.
    """

    def __init__(self, reason):
        super().__init__()
        # Text explaining why the original guard manager was invalidated.
        self.invalidation_reason = reason

    def populate_diff_guard_manager(self):
        """Reset ``diff_guard_root`` to None; nothing else is built here."""
        self.diff_guard_root = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CheckFunctionManager:
    """Builds the guard manager for one compiled frame.

    Construction creates a ``GuardManagerWrapper``, runs every guard through a
    ``GuardBuilder``, finishes the manager in ``compile_check_fn``, sanity
    checks it against the traced locals and finally drops the references it
    no longer needs. ``id_ref`` / ``lookup_weakrefs`` / ``invalidate`` are the
    callbacks handed to the builder for ID-matched objects.
    """

    def __init__(
        self,
        f_code,
        output_graph=None,
        cache_entry=None,
        guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
    ):
        """Create and validate the guard manager for ``f_code``.

        Args:
            f_code: code object passed through to ``GuardBuilder``.
            output_graph: provides ``guards``, ``local_scope``,
                ``global_scope``, ``torch_function_mode_stack`` and ``export``.
                NOTE(review): although it defaults to None, its scopes are
                read unconditionally below — confirm whether None is ever
                a supported value.
            cache_entry: head of the existing cache-entry chain; its guard
                managers get their diff-guard sources refreshed.
            guard_fail_fn: optional callback stored on the guard manager.

        Raises:
            AssertionError: if the freshly built guards do not pass on
                ``output_graph.local_scope`` (skipped when exporting).
        """
        guards = output_graph.guards if output_graph else None
        # id(obj) -> weakref, filled in by id_ref().
        self._weakrefs: dict[int, ReferenceType[object]] = {}

        existing_diff_guard_sources = (
            update_diff_guard_managers_for_existing_cache_entries(cache_entry)
        )
        self.guard_manager = GuardManagerWrapper()
        self.guard_manager.diff_guard_sources = existing_diff_guard_sources
        self.output_graph = output_graph
        # Assigned further down; source_ref() closes over this name.
        w_builder = None

        # Snapshot of the torch function mode stack taken from the output
        # graph; consumed (and cleared) by compile_check_fn().
        self.torch_function_mode_stack = (
            output_graph.torch_function_mode_stack if output_graph else None
        )

        def source_ref(source):
            guard_source = source.guard_source()
            if guard_source is GuardSource.CONSTANT:
                # Constants are referred to by name without going through
                # the builder.
                return source.name()
            assert w_builder
            r_builder = w_builder()
            assert r_builder is not None
            return r_builder.arg_ref(source.name())

        builder = GuardBuilder(
            f_code,
            self.id_ref,
            source_ref,
            self.lookup_weakrefs,
            output_graph.local_scope,
            output_graph.global_scope,
            self.guard_manager,
            self,
        )

        # NOTE(review): per the weakref documentation the referent is no
        # longer available when a weakref callback runs, so `b` is expected
        # to be None here — confirm this callback still does useful work.
        def cleanup_builder(weak_b):
            b = weak_b()
            if b:
                b.scope = None

        # source_ref() only holds the builder weakly.
        w_builder = weakref.ref(builder, cleanup_builder)

        # Query the killswitch once and reuse the answer for both the flag
        # and the warning.
        nn_module_knob_enabled = justknobs_check("pytorch/compiler:guard_nn_modules")
        guard_on_nn_modules = config.guard_nn_modules and nn_module_knob_enabled

        if not nn_module_knob_enabled:
            log.warning("guard_nn_modules is turned off using justknobs killswitch")

        for guard in sorted(guards or (), key=Guard.sort_key):
            if (
                not guard_on_nn_modules
                and guard.is_specialized_nn_module()
                # Guards on function defaults are never skipped.
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
            ):
                continue

            guard.create(builder)

        self.compile_check_fn(builder, guards, guard_fail_fn)

        # Expose the builder's ID-matched objects on the guard manager.
        self.guard_manager.id_matched_objs = builder.id_matched_objs

        guards_log.debug("%s", self.guard_manager)

        # Validate the new guards against the locals they were built from;
        # this is skipped entirely when exporting.
        latency = 0.0
        if not output_graph.export:
            if not self.guard_manager.check(output_graph.local_scope):
                reasons = get_guard_fail_reason_helper(
                    self.guard_manager,
                    output_graph.local_scope,
                    CompileContext.current_compile_id(),
                )
                raise AssertionError(f"Guard check failed: {reasons}")

            if guard_manager_testing_hook_fn is not None:
                guard_manager_testing_hook_fn(
                    self.guard_manager, output_graph.local_scope
                )

            latency = profile_guard_manager(
                self.guard_manager.root, output_graph.local_scope, 50
            )
            guards_log.debug("Guard eval latency = %s us", f"{latency:.2f}")

            CompileEventLogger.increment_toplevel("guard_latency_us", int(latency))

        torch._logging.trace_structured(
            "dynamo_cpp_guards_str",
            payload_fn=lambda: f"{self.guard_manager}\nGuard latency = {latency:.2f} us",
        )

        # Release what was only needed while building.
        self._weakrefs.clear()
        self.output_graph = None

    def compile_check_fn(self, builder, guards_out, guard_fail_fn):
        """Install the remaining guards and finish ``self.guard_manager``.

        Adds the global-state and torch-function-mode-stack guards, logs every
        code part produced by ``builder``, installs tensor-aliasing and
        AOTAutograd guards, finalizes the manager and stores the debugging
        attributes (``closure_vars``, ``args``, ``verbose_code_parts`` ...)
        on it.

        Args:
            builder: the ``GuardBuilder`` that already processed all guards.
            guards_out: not used by this method.
            guard_fail_fn: stored on the guard manager as ``guard_fail_fn``.

        Raises:
            RuntimeError: for an AOTAutograd guard of an unknown kind.
        """
        # NOTE(review): this is the builder's own list, extended in place, and
        # later stored as guard_manager.args — confirm the aliasing is wanted.
        largs = builder.argnames
        largs += ["**___kwargs_ignored"]

        guards_log.debug("GUARDS:")

        code_parts = []
        verbose_code_parts = []
        structured_guard_fns: list[Callable[[], dict[str, Any]]] = []

        torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
            self.torch_function_mode_stack
        )

        # Global-state guard on the root.
        self.guard_manager.root.add_global_state_guard(["___check_global_state()"])

        self.guard_manager.root.add_torch_function_mode_stack_guard(
            self.torch_function_mode_stack,
            ["___check_torch_function_mode_stack()"],
        )

        # The stack has been handed over; stop referencing it from here.
        self.torch_function_mode_stack = None

        def add_code_part(code_part, guard, log_only=False):
            verbose_code_part = get_verbose_code_part(code_part, guard)
            guards_log.debug("%s", verbose_code_part)

            # Deferred so the structured payload is only built if requested.
            structured_guard_fns.append(
                lambda: {
                    "code": code_part,
                    "stack": (
                        structured.from_traceback(guard.stack.summary())
                        if guard and guard.stack
                        else None
                    ),
                    "user_stack": (
                        structured.from_traceback(guard.user_stack)
                        if guard and guard.user_stack
                        else None
                    ),
                }
            )

            if verbose_guards_log.isEnabledFor(logging.DEBUG):
                maybe_stack = ""
                maybe_user_stack = ""
                if guard is not None:
                    if guard.stack:
                        maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}"
                    if guard.user_stack:
                        maybe_user_stack = (
                            f"\nUser stack:\n{''.join(guard.user_stack.format())}"
                        )
                verbose_guards_log.debug(
                    "Guard: %s%s%s",
                    code_part,
                    maybe_stack,
                    maybe_user_stack,
                )

            if not log_only:
                code_parts.append(code_part)
                verbose_code_parts.append(verbose_code_part)

        # Log each distinct code string once; nothing is added to code_parts.
        seen = set()
        for gcl in builder.code:
            for code in gcl.code_list:
                if code not in seen:
                    add_code_part(code, gcl.guard, True)
                    seen.add(code)

        no_tensor_aliasing_names = builder.no_tensor_aliasing_names
        check_tensors_fn = None
        check_tensors_verbose_fn = None

        # An aliasing guard only makes sense with at least two tensors.
        if len(no_tensor_aliasing_names) > 1:
            install_no_tensor_aliasing_guard(
                builder.no_tensor_aliasing_guard_managers,
                no_tensor_aliasing_names,
                ["check_no_aliasing(" + ", ".join(no_tensor_aliasing_names) + ")"],
            )

        aotautograd_guards: list[GuardEnvExpr] = (
            self.output_graph.tracing_context.guards_context.aotautograd_guards
            if self.output_graph
            else []
        )

        for guard in aotautograd_guards:
            if isinstance(guard, DuplicateInputs):
                source_a = guard.input_source_a
                source_b = guard.input_source_b
                code_part = f"{source_a.name()} is {source_b.name()}"
                install_object_aliasing_guard(
                    builder.get_guard_manager_from_source(source_a),
                    builder.get_guard_manager_from_source(source_b),
                    [code_part],
                )
                add_code_part(code_part, None, True)
            elif isinstance(guard, StorageOverlap):
                overlapping_guard_managers = [
                    builder.get_guard_manager_from_source(s)
                    for s in guard.overlapping_sources
                ]
                non_overlapping_guard_managers = [
                    builder.get_guard_manager_from_source(s)
                    for s in guard.non_overlapping_sources
                ]
                code_part = (
                    """check_overlapping("""
                    f"""overlapping=[{", ".join(s.name() for s in guard.overlapping_sources)}], """
                    f"""non_overlapping=[{", ".join(s.name() for s in guard.non_overlapping_sources)}])"""
                )
                install_storage_overlapping_guard(
                    overlapping_guard_managers,
                    non_overlapping_guard_managers,
                    [code_part],
                )
                add_code_part(code_part, None, True)
            else:
                raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")

        # Shape-env code parts are logged only as well.
        for gcl in builder.shape_env_code:
            for code in gcl.code_list:
                add_code_part(code, gcl.guard, True)

        if structured_guard_fns:
            torch._logging.trace_structured(
                "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
            )

        global_state = convert_frame.initial_global_state
        if global_state is None:
            # No recorded initial state: fall back to a fresh guard object.
            global_state = convert_frame.GlobalStateGuard()
        closure_vars = {
            "___check_tensors": check_tensors_fn,
            "___check_tensors_verbose": check_tensors_verbose_fn,
            "___check_global_state": global_state.check,
            "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
            **SYMPY_INTERP,
            **_get_closure_vars(),
        }

        self.guard_manager.finalize()

        globals_for_guard_fn = {"G": builder.scope["G"]}

        # Every add_code_part call above was log-only, so nothing may have
        # ended up in code_parts.
        assert len(code_parts) == 0

        self.guard_manager.closure_vars = closure_vars
        self.guard_manager.args = largs
        self.guard_manager.populate_code_parts_for_debugging()
        self.guard_manager.verbose_code_parts = verbose_code_parts

        # Only the "G" entry of the builder scope is kept.
        self.guard_manager.global_scope = globals_for_guard_fn
        self.guard_manager.guard_fail_fn = guard_fail_fn

        # Start out empty; invalidate() is a no-op until both are set.
        self.guard_manager.cache_entry = None
        self.guard_manager.extra_state = None
        self.guard_manager.no_tensor_aliasing_sources = no_tensor_aliasing_names

    def invalidate(self, obj_str):
        """Invalidate the cache entry guarded by this manager.

        Does nothing unless ``self.guard_manager`` exists, is not already a
        ``DeletedGuardManagerWrapper`` and has both ``cache_entry`` and
        ``extra_state`` set. Otherwise the guard manager is replaced by a
        ``DeletedGuardManagerWrapper`` whose reason mentions ``obj_str``.
        """
        if (
            hasattr(self, "guard_manager")
            and not isinstance(self.guard_manager, DeletedGuardManagerWrapper)
            and (cache_entry := self.guard_manager.cache_entry) is not None
            and (extra_state := self.guard_manager.extra_state) is not None
        ):
            assert isinstance(cache_entry, CacheEntry)
            assert isinstance(extra_state, ExtraState)
            reason = f"Cache line invalidated because {obj_str} got deallocated"
            deleted_guard_manager = DeletedGuardManagerWrapper(reason)
            extra_state.invalidate(cache_entry, deleted_guard_manager)
            self.guard_manager = deleted_guard_manager

    def id_ref(self, obj, obj_str):
        """add a weakref, return the id"""
        obj_id = id(obj)
        try:
            if obj_id not in self._weakrefs:
                self._weakrefs[obj_id] = weakref.ref(obj)
                # Invalidate this manager's cache entry once obj dies.
                weakref.finalize(
                    obj, functools.partial(self.invalidate, obj_str=obj_str)
                )
        except TypeError:
            # obj cannot be weakly referenced; only its id is reported.
            pass
        return obj_id

    def lookup_weakrefs(self, obj):
        """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects"""
        obj_id = id(obj)
        if obj_id in self._weakrefs:
            return self._weakrefs[obj_id]
        return None
|
|
|
|
|
def build_guard_function(code_parts, closure_args) -> tuple[str, str]:
    """Generate Python source for a guard function from ``code_parts``.

    Each code part becomes an ``if not (...): return False`` check, after
    repeated sub-expressions have been hoisted into variables by
    ``PyExprCSEPass``.

    Returns:
        A pair ``(guard_body_source, make_guard_fn_source)``: the bare checks,
        and a ``___make_guard_fn(<closure_args>)`` factory whose inner
        ``guard(L)`` runs those checks and then returns True.
    """
    from torch._inductor.utils import IndentedBuffer

    cse = PyExprCSEPass()
    cse.count(code_parts)

    # One early-exit check per code part, preceded by any hoisted assignments.
    checks = IndentedBuffer()
    for part in code_parts:
        hoisted_lines, rewritten = cse.replace(part)
        checks.writelines(hoisted_lines)
        checks.writeline(f"if not ({rewritten}):")
        with checks.indent():
            checks.writeline("return False")

    # Wrap the checks into `def guard(L)`.
    guard_def = IndentedBuffer()
    guard_def.writeline("def guard(L):")
    with guard_def.indent():
        guard_def.splice(checks)
        guard_def.writeline("return True")

    # Wrap the guard into a factory taking the closure arguments.
    factory_def = IndentedBuffer()
    factory_def.writeline(f"def ___make_guard_fn({closure_args}):")
    with factory_def.indent():
        factory_def.splice(guard_def)
        factory_def.writeline("return guard")

    return checks.getvalue(), factory_def.getvalue()
|
|
|
|
|
def is_recompiles_enabled(): |
|
return torch._logging._internal.log_state.is_artifact_enabled("recompiles") |
|
|
|
|
|
def is_recompiles_verbose_enabled(): |
|
return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose") |
|
|
|
|
|
|
|
def make_torch_function_mode_stack_guard(intial_stack):
    """Build a zero-argument check for the torch function mode stack.

    The types of the entries in ``intial_stack`` are captured immediately.
    The returned callable fetches the current stack with
    ``get_torch_function_mode_stack()`` and returns True only when it has the
    same length and the same entry types, position by position.
    """
    expected_types = list(map(type, intial_stack))

    def check_torch_function_mode_stack():
        current = get_torch_function_mode_stack()

        if len(current) != len(expected_types):
            return False

        return not any(
            expected != type(mode) for expected, mode in zip(expected_types, current)
        )

    return check_torch_function_mode_stack
|
|
|
|
|
def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
    """Describe which no-aliasing tensor sources resolve to the same object.

    Every expression in ``guard_manager.no_tensor_aliasing_sources`` is
    evaluated (globals: a copy of ``guard_manager.global_scope``; locals:
    ``scope``) and grouped by the ``id`` of its value.

    Returns:
        A one-element list with a message listing each group of sources that
        share an object.
    """
    eval_globals = dict(guard_manager.global_scope)
    sources_by_id = collections.defaultdict(list)
    for source_expr in guard_manager.no_tensor_aliasing_sources:
        # Record the expression being evaluated in the globals used for eval.
        eval_globals["__compile_source__"] = source_expr
        # NOTE(review): eval() of guard source strings — these are expected to
        # be generated internally, never user-supplied; confirm.
        value = eval(source_expr, eval_globals, scope)
        sources_by_id[id(value)].append(source_expr)

    shared_groups = [
        f"{sources}" for sources in sources_by_id.values() if len(sources) > 1
    ]

    reason = ", ".join(shared_groups)
    return [f"Duplicate tensors found: {reason}"]
|
|
|
|
|
def strip_local_scope(s: str) -> str:
    """
    Rewrite every ``L['name']`` / ``L["name"]`` in ``s`` to plain ``name``.

    Whitespace between the brackets and the quotes is tolerated.

    Used to make recompilation messages easier to read for users.
    """
    import re

    local_lookup = re.compile(r"L\[\s*['\"](.*?)['\"]\s*\]")
    return local_lookup.sub(r"\1", s)
|
|
|
|
|
def get_guard_fail_reason_helper(
    guard_manager: GuardFn,
    f_locals: dict[str, object],
    compile_id: CompileId,
) -> str:
    """
    Explain why ``guard_manager`` rejects ``f_locals``.

    The explanation is prefixed with ``compile_id`` and has ``L[...]``
    lookups stripped. Unless verbose recompile logging is enabled, only the
    first failing check is reported.
    """
    scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
    scope.update(guard_manager.closure_vars)
    reasons: list[str] = []

    aliasing_guard_failed = False

    # Code parts that still have to be evaluated one by one to find the
    # failing one.
    candidate_parts: list[str] = []
    debug_info = guard_manager.check_verbose(f_locals)

    if not debug_info.result:
        candidate_parts = debug_info.verbose_code_parts

        # A single part is taken as the reason itself, unless it is the
        # duplicate-tensor message, which gets a dedicated explanation below.
        if len(candidate_parts) == 1:
            if "Duplicate tensor found" in candidate_parts[0]:
                aliasing_guard_failed = True
            else:
                reasons = candidate_parts
                candidate_parts = []

    if aliasing_guard_failed:
        reasons = recompilation_reason_for_no_tensor_aliasing_guard(
            guard_manager, scope
        )
    else:
        for part in candidate_parts:
            eval_globals = dict(guard_manager.global_scope)
            eval_globals["__compile_source__"] = part
            with report_compile_source_on_error():
                try:
                    fail_reason = eval(part, eval_globals, scope)
                except Exception:
                    # In verbose mode a part that cannot be evaluated is
                    # skipped; otherwise the error propagates.
                    if not is_recompiles_verbose_enabled():
                        raise
                    continue

            # A part evaluating to False is reported as its own source text;
            # a part evaluating to a string is reported as that string.
            if fail_reason is False:
                fail_reason = part
            if isinstance(fail_reason, str):
                reasons.append(fail_reason)
                if not is_recompiles_verbose_enabled():
                    break

    reason_str = f"{compile_id}: " + "; ".join(reasons)
    return strip_local_scope(reason_str)
|
|
|
|
|
def get_guard_fail_reason(
    guard_manager: GuardFn,
    code: types.CodeType,
    f_locals: dict[str, object],
    compile_id: CompileId,
) -> str:
    """Return the failure reason for ``guard_manager`` and record it.

    For a ``DeletedGuardManagerWrapper`` the stored invalidation reason is
    returned directly. Otherwise the reason is computed, appended to
    ``guard_failures`` under the original code object, and passed to the
    manager's ``guard_fail_fn`` (if any). Exceptions raised by that callback
    are logged and not propagated.
    """
    if isinstance(guard_manager, DeletedGuardManagerWrapper):
        return f"{compile_id}: {guard_manager.invalidation_reason}"

    reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
    guard_failures[orig_code_map[code]].append(reason_str)

    try:
        on_fail = guard_manager.guard_fail_fn
        if on_fail is not None:
            failure = GuardFail(reason_str or "unknown reason", orig_code_map[code])
            on_fail(failure)
    except Exception:
        log.exception(
            "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
        )

    return reason_str
|
|
|
|
|
def get_and_maybe_log_recompilation_reasons(
    cache_entry, frame: DynamoFrameType
) -> list[str]:
    """
    Collect the guard failure reason of every entry in the ``cache_entry`` chain.

    The reasons are logged when `recompiles` or `recompiles_verbose` logging
    is enabled, and a RecompileError is raised when
    `config.error_on_recompile` is set.
    """
    reasons = []
    entry = cache_entry
    while entry is not None:
        reason = get_guard_fail_reason(
            entry.guard_manager,
            entry.code,
            frame.f_locals,
            entry.compile_id,
        )
        if reason:
            reasons.append(reason)
        entry = entry.next

    code = frame.f_code

    do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled()

    # The message is only assembled when something will consume it.
    if do_recompiles_log or config.error_on_recompile:
        verbose = is_recompiles_verbose_enabled()
        if verbose:
            per_guard = [
                f"guard {i} failures:\n" + textwrap.indent(reason, "- ")
                for i, reason in enumerate(reasons)
            ]
            failures = "\n\n".join(per_guard)
        else:
            failures = textwrap.indent("\n".join(reasons), "- ")

        guard_failure_details = (
            f"triggered by the following guard failure(s):\n{failures}"
        )
        header = f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n"
        message = header + textwrap.indent(guard_failure_details, ' ')

        if do_recompiles_log:
            target_log = recompiles_verbose_log if verbose else recompiles_log
            target_log.debug(message)
        if config.error_on_recompile:
            raise exc.RecompileError(message)

    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "recompile_reasons",
            "encoding": "json",
        },
        payload_fn=lambda: reasons,
    )

    return reasons
|
|
|
|
|
def update_diff_guard_managers_for_existing_cache_entries(cache_entry):
    """Share one accumulated set of diff-guard sources across a cache chain.

    First pass: union the result of ``collect_diff_guard_sources()`` of every
    guard manager reachable through ``.next``. Second pass: store that set as
    ``diff_guard_sources`` on each guard manager and call its
    ``populate_diff_guard_manager()``.

    Returns:
        The accumulated set (empty when ``cache_entry`` is None).
    """

    def walk(head):
        # Follow the singly linked chain until it ends.
        node = head
        while node is not None:
            yield node
            node = node.next

    acc_diff_guard_sources = set()
    for entry in walk(cache_entry):
        acc_diff_guard_sources.update(
            entry.guard_manager.collect_diff_guard_sources()
        )

    # Every manager receives the same set object, then rebuilds from it.
    for entry in walk(cache_entry):
        manager = entry.guard_manager
        manager.diff_guard_sources = acc_diff_guard_sources
        entry.guard_manager.populate_diff_guard_manager()

    return acc_diff_guard_sources
|
|
|
|
|
def guard_error_hook( |
|
guard_manager: GuardFn, |
|
code: types.CodeType, |
|
f_locals: dict[str, object], |
|
index: int, |
|
last: bool, |
|
): |
|
print( |
|
f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" |
|
) |
|
print("lambda " + ", ".join(guard_manager.args) + ":") |
|
print(" ", " and\n ".join(guard_manager.code_parts)) |
|
|
|
print(guard_manager) |
|
|
|
local_scope = {"L": f_locals, **guard_manager.closure_vars} |
|
for guard in guard_manager.code_parts: |
|
try: |
|
eval(guard, guard_manager.global_scope, local_scope) |
|
except: |
|
print(f"Malformed guard:\n{guard}") |
|
|
|
|
|
# Module-level side effect: hand guard_error_hook to set_guard_error_hook when
# this module is imported (presumably registering it as the callback used when
# guard evaluation errors — confirm against set_guard_error_hook).
set_guard_error_hook(guard_error_hook)
|
|
|
|
|
def unique(seq):
    """Lazily yield the items of ``seq`` in order, skipping repeats.

    Items must be hashable; only the first occurrence of each is produced.
    """
    already_yielded = set()
    for item in seq:
        if item in already_yielded:
            continue
        yield item
        already_yielded.add(item)
|
|
|
|
|
def make_dupe_guard(obj_source, dupe_source):
    """Return a guard factory tying ``obj_source`` to ``dupe_source``, or None.

    None is returned when ``dupe_source`` is falsy, when both sources compare
    equal, or when exactly one of the two is a local source. Otherwise the
    result is ``GuardBuilder.DUPLICATE_INPUT`` partially applied with
    ``source_b=dupe_source``.

    Raises:
        exc.UnsafeScriptObjectError: if either source comes from a flattened
            script object.
    """
    # Nothing to guard without a distinct second source.
    if not dupe_source or not (dupe_source != obj_source):
        return None

    ser_source_is_local = is_from_local_source(dupe_source)
    source_is_local = is_from_local_source(obj_source)

    involves_script_object = is_from_flatten_script_object_source(
        dupe_source
    ) or is_from_flatten_script_object_source(obj_source)
    if involves_script_object:
        raise exc.UnsafeScriptObjectError(
            f"{obj_source.name()} is alising {dupe_source.name()}. This is not supported."
            f" Please do a clone for corresponding input."
        )

    # A duplicate-input guard is only produced when both sources agree on
    # being local (or on not being local).
    if ser_source_is_local != source_is_local:
        return None
    return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
|
|
|
|
|
def install_guard(*guards, skip=0):
    """
    Register dynamo guards on the current tracing context.

    Args:
        guards: guard(s) to add; each must be a ``Guard``
        skip: number of stack frames to ignore for debug stack trace
    """
    from torch._guards import TracingContext

    # Debug stacks are only collected when one of the guard loggers would
    # actually emit them.
    collect_debug_stack = any(
        logger.isEnabledFor(logging.DEBUG)
        for logger in (guards_log, verbose_guards_log)
    )
    dynamo_guards = TracingContext.get().guards_context.dynamo_guards
    for guard in guards:
        assert isinstance(guard, Guard)
        # +1 so this function's own frame is skipped as well.
        dynamo_guards.add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1)
|
|