|
from __future__ import annotations |
|
|
|
|
|
"""Exception handling and error reporting for TorchDynamo. |
|
|
|
This module provides a comprehensive set of exception classes and utilities for error |
|
handling in TorchDynamo. It includes: |
|
|
|
Base Exceptions: |
|
- TorchDynamoException: Base class for all TorchDynamo-specific exceptions |
|
- Various specialized subclasses for different error scenarios |
|
|
|
User Error Handling: |
|
- UserError: Exceptions for user-facing errors in TorchDynamo usage |
|
- UserErrorType: Enumeration of different categories of user errors |
|
- Formatted error messages with debugging information |
|
|
|
Observed Exceptions: |
|
- Classes for handling exceptions observed during tracing |
|
- Special handling for StopIteration, LookupError, etc. |
|
- Exception state management during compilation |
|
|
|
Error Formatting: |
|
- Stack trace filtering and formatting |
|
- Error message augmentation |
|
- Debugging utilities for error reporting |
|
""" |
|
|
|
import logging |
|
import os |
|
import re |
|
import textwrap |
|
import typing |
|
from enum import auto, Enum |
|
from traceback import extract_stack, format_exc, format_list, StackSummary |
|
from typing import Any, NoReturn, Optional, TYPE_CHECKING |
|
|
|
import torch._guards |
|
|
|
from . import config |
|
from .utils import counters |
|
|
|
|
|
if TYPE_CHECKING: |
|
import types |
|
|
|
from torch._guards import CompileId |
|
|
|
from .symbolic_convert import InstructionTranslatorBase |
|
from .types import DynamoFrameType |
|
|
|
|
|
def exportdb_error_message(case_name: str) -> str:
    """Return a sentence pointing at the exportdb page entry for ``case_name``.

    Underscores in ``case_name`` are replaced by hyphens to form the URL
    fragment appended to the exportdb index page.
    """
    anchor = case_name.replace("_", "-")
    return (
        "For more information about this error, see: "
        f"https://pytorch.org/docs/main/generated/exportdb/index.html#{anchor}"
    )
|
|
|
|
|
# Standard logger named after this module.
log = logging.getLogger(__name__)
# Artifact logger for this module registered under the "graph_breaks" artifact name.
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
|
|
|
|
class TorchDynamoException(RuntimeError):
    """Base class for TorchDynamo-specific exceptions.

    A plain ``RuntimeError`` subclass that adds no state or behaviour.
    """
|
|
|
|
|
class InternalTorchDynamoError(TorchDynamoException):
    """``TorchDynamoException`` subclass with no added state or behaviour.

    NOTE(review): the name suggests it marks failures inside TorchDynamo
    itself -- confirm at the raise sites.
    """
|
|
|
|
|
class RestartAnalysis(TorchDynamoException):
    """Exception that carries an optional textual restart reason.

    Args:
        *args: Forwarded unchanged to the base exception constructor.
        restart_reason: Keyword-only text stored on ``self.restart_reason``;
            ``None`` when not supplied.
    """

    restart_reason: Optional[str]

    def __init__(self, *args: Any, restart_reason: Optional[str] = None) -> None:
        super().__init__(*args)
        self.restart_reason = restart_reason
|
|
|
|
|
class SpeculationRestartAnalysis(RestartAnalysis):
    """``RestartAnalysis`` variant distinguishable only by its type."""
|
|
|
|
|
class UnspecializeRestartAnalysis(RestartAnalysis):
    """``RestartAnalysis`` variant distinguishable only by its type."""
|
|
|
|
|
class CompileCollectiveRestartAnalysis(RestartAnalysis):
    """``RestartAnalysis`` variant distinguishable only by its type."""
|
|
|
|
|
class TensorifyScalarRestartAnalysis(RestartAnalysis):
    """``RestartAnalysis`` variant distinguishable only by its type."""
|
|
|
|
|
class SkipFrame(TorchDynamoException):
    """``TorchDynamoException`` subclass with no added state or behaviour.

    NOTE(review): presumably signals that the current frame should be
    skipped -- confirm at the raise/catch sites.
    """
|
|
|
|
|
class TorchRuntimeError(TorchDynamoException):
    """``TorchDynamoException`` subclass with no added state or behaviour."""
|
|
|
|
|
class InvalidBackend(TorchDynamoException):
    """Exception whose message reports an invalid backend name.

    Args:
        name: The backend name; it is embedded in the message via ``repr``.
    """

    def __init__(self, name: str) -> None:
        message = (
            f"Invalid backend: {name!r}, "
            "see `torch._dynamo.list_backends()` for available backends."
        )
        super().__init__(message)
|
|
|
|
|
class ResetRequired(TorchDynamoException):
    """Exception with a fixed message asking for ``torch._dynamo.reset()``."""

    def __init__(self) -> None:
        # dedent() strips the common indentation, so the message starts with a
        # newline followed by the unindented text.
        message = textwrap.dedent(
            """
            Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
            `torch.compile()` with a different backend compiler arguments.
            """
        )
        super().__init__(message)
|
|
|
|
|
class ShortenTraceback(TorchDynamoException):
    """Exception able to trim its own traceback down to a chosen frame.

    Args:
        *args: Forwarded to the base exception constructor.
        first_useful_frame: Required keyword-only frame at which the trimmed
            traceback should start; ``None`` disables trimming.
        **kwargs: Forwarded to the base exception constructor.
    """

    def __init__(
        self, *args: Any, first_useful_frame: Optional[types.FrameType], **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        self.first_useful_frame = first_useful_frame

    def remove_dynamo_frames(self) -> typing.Self:
        """Return this exception with the leading traceback entries dropped.

        The exception is returned untouched when no target frame was recorded,
        when it has no traceback, or when ``config.verbose`` is set. Otherwise
        the traceback is advanced to the entry whose frame is
        ``first_useful_frame``; running off the end trips an assertion.
        """
        target = self.first_useful_frame
        cursor = self.__traceback__
        if target is None or cursor is None or config.verbose:
            return self
        while True:
            if cursor.tb_frame is target:
                return self.with_traceback(cursor)
            cursor = cursor.tb_next
            assert cursor is not None, "internal error, please report a bug"
|
|
|
|
|
class BackendCompilerFailed(ShortenTraceback):
    """Wraps an exception associated with a backend function.

    Args:
        backend_fn: The backend; its ``__name__`` (or ``"?"`` when absent) is
            stored as ``backend_name`` and used in the message.
        inner_exception: The wrapped exception, stored as ``inner_exception``.
        first_useful_frame: Passed through to ``ShortenTraceback``.
    """

    def __init__(
        self,
        backend_fn: Any,
        inner_exception: Exception,
        first_useful_frame: Optional[types.FrameType],
    ) -> None:
        backend_name = getattr(backend_fn, "__name__", "?")
        self.backend_name = backend_name
        self.inner_exception = inner_exception
        inner_type_name = type(inner_exception).__name__
        super().__init__(
            f"backend={backend_name!r} raised:\n{inner_type_name}: {inner_exception}",
            first_useful_frame=first_useful_frame,
        )
|
|
|
|
|
class Unsupported(TorchDynamoException):
    """Exception that records a user stack and tallies itself in ``counters``.

    Args:
        msg: Error message; also used as the key in the ``counters`` tally.
        case_name: Optional keyword-only name stored on ``self.case_name``.

    Attributes set on construction: ``real_stack`` (result of
    ``torch._guards.TracingContext.extract_stack()``), ``msg``, ``category``
    and ``case_name``.
    """

    def __init__(self, msg: str, *, case_name: Optional[str] = None) -> None:
        super().__init__(msg)
        self.real_stack = torch._guards.TracingContext.extract_stack()
        self.msg = msg
        self.category: Optional[str] = None
        # Registers the message under the default "unimplemented" category.
        self.add_to_stats()
        self.case_name: Optional[str] = case_name

    def remove_from_stats(self) -> None:
        """Decrement this message's tally, deleting the entry once it is <= 0."""
        assert self.category is not None
        tally = counters[self.category]
        tally[self.msg] -= 1
        if tally[self.msg] <= 0:
            del tally[self.msg]

    def add_to_stats(self, category: str = "unimplemented") -> None:
        """Remember ``category`` and increment this message's tally under it."""
        self.category = category
        counters[category][self.msg] += 1
|
|
|
|
|
class UnknownPropertiesDuringBackwardTrace(Unsupported):
    """``Unsupported`` subclass that adds no state or behaviour."""
|
|
|
|
|
class RecompileError(TorchDynamoException):
    """``TorchDynamoException`` subclass with no added state or behaviour."""
|
|
|
|
|
class ArgsMismatchError(Unsupported):
    """``Unsupported`` subclass constructed from a message only.

    The constructor does not expose the ``case_name`` keyword of
    ``Unsupported``.
    """

    def __init__(self, msg: str) -> None:
        super().__init__(msg)
|
|
|
|
|
class AttributeMutationError(Unsupported):
    """``Unsupported`` subclass constructed from a message only.

    The constructor does not expose the ``case_name`` keyword of
    ``Unsupported``.
    """

    def __init__(self, msg: str) -> None:
        super().__init__(msg)
|
|
|
|
|
class InfiniteGeneratorError(Unsupported):
    """``Unsupported`` subclass constructed from a message only.

    NOTE(review): the name suggests it is raised for generators that never
    terminate -- confirm at the raise sites.
    """

    def __init__(self, msg: str) -> None:
        super().__init__(msg)
|
|
|
|
|
class SideEffectsError(Unsupported):
    """``Unsupported`` subclass constructed from a message only.

    The constructor does not expose the ``case_name`` keyword of
    ``Unsupported``.
    """

    def __init__(self, msg: str) -> None:
        super().__init__(msg)
|
|
|
|
|
class CondOpArgsMismatchError(ArgsMismatchError):
    """
    Internal error from cond() due to arguments mismatch.

    Constructed from a message only, like ``ArgsMismatchError``.
    """

    def __init__(self, msg: str) -> None:
        super().__init__(msg)
|
|
|
|
|
class UserErrorType(Enum):
    """Categories passed as ``error_type`` when constructing a ``UserError``.

    Member values are generated by ``auto()``.
    """

    DYNAMIC_CONTROL_FLOW = auto()
    ANTI_PATTERN = auto()
    STANDARD_LIBRARY = auto()
    CONSTRAINT_VIOLATION = auto()
    DYNAMIC_DIM = auto()
    INVALID_INPUT = auto()
    INVALID_OUTPUT = auto()
    UNSUPPORTED_ALIASED_MUTATED_DYNAMIC_INPUTS = auto()
|
|
|
|
|
class UserError(Unsupported):
    """User-facing error carrying a ``UserErrorType`` category.

    Attributes:
        error_type: The category supplied by the caller.
        message: The final message, including the exportdb pointer if any.
    """

    def __init__(
        self, error_type: UserErrorType, msg: str, case_name: Optional[str] = None
    ) -> None:
        """
        Type of errors that would be valid in Eager, but not supported in TorchDynamo.
        The error message should tell user about next actions.

        error_type: Type of user error
        msg: Actionable error message
        case_name: (Optional) Unique name (snake case) for the usage example in exportdb.
        """
        if case_name is not None:
            assert isinstance(case_name, str)
            # Continue on the same line after a full stop, otherwise start a
            # new line for the exportdb pointer.
            separator = " " if msg.endswith(".") else "\n"
            msg = f"{msg}{separator}{exportdb_error_message(case_name)}"
        # Forward case_name so that the ``case_name`` attribute maintained by
        # Unsupported reflects the value given here instead of always being None.
        super().__init__(msg, case_name=case_name)
        self.error_type = error_type
        self.message = msg
|
|
|
|
|
class SkipCodeRecursiveException(TorchDynamoException):
    """``TorchDynamoException`` subclass with no added state or behaviour."""
|
|
|
|
|
class RecompileLimitExceeded(Unsupported):
    """``Unsupported`` subclass that adds no state or behaviour."""
|
|
|
|
|
class UnsafeScriptObjectError(TorchDynamoException):
    """``TorchDynamoException`` subclass with no added state or behaviour."""
|
|
|
|
|
class UncapturedHigherOrderOpError(TorchDynamoException):
    """``TorchDynamoException`` subclass with no added state or behaviour."""
|
|
|
|
|
class IncorrectUsage(Exception):
    """Plain ``Exception`` subclass (not a ``TorchDynamoException``)."""
|
|
|
|
|
|
|
|
|
|
|
class FailOnRecompileLimitHit(Exception):
    """Plain ``Exception`` subclass (not a ``TorchDynamoException``)."""
|
|
|
|
|
class ObservedException(TorchDynamoException):
    """Common base of the ``Observed*`` exception classes in this module."""
|
|
|
|
|
class ObservedUserStopIteration(ObservedException):
    """Observed counterpart of ``StopIteration`` with a fixed message.

    Args:
        *args: Only the first positional argument is used; it is stored as
            ``value`` (``None`` when no positional argument is given).
        **kwargs: Accepted and ignored.
    """

    value: Optional[Any]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__("unhandled `raise StopIteration`")
        self.value = args[0] if args else None
|
|
|
|
|
class ObservedLookupError(ObservedException):
    """Observed counterpart of the builtin ``LookupError``."""
|
|
|
|
|
class ObservedIndexError(ObservedLookupError):
    """Observed counterpart of the builtin ``IndexError``."""
|
|
|
|
|
class ObservedKeyError(ObservedLookupError):
    """Observed counterpart of the builtin ``KeyError``."""
|
|
|
|
|
class ObservedGeneratorExit(ObservedException):
    """Observed counterpart of the builtin ``GeneratorExit``."""
|
|
|
|
|
class ObservedAttributeError(ObservedException):
    """Observed counterpart of the builtin ``AttributeError``."""
|
|
|
|
|
class ObservedRuntimeError(ObservedException):
    """Observed counterpart of the builtin ``RuntimeError``."""
|
|
|
|
|
class ObservedNotImplementedError(ObservedException):
    """Observed counterpart of the builtin ``NotImplementedError``."""
|
|
|
|
|
class ObservedTypeError(ObservedException):
    """Observed counterpart of the builtin ``TypeError``."""
|
|
|
|
|
# Maps a builtin exception type to the Observed* class raised in its place.
# get_dynamo_observed_exception() adds further entries to this dict on demand.
observed_exception_map = {
    StopIteration: ObservedUserStopIteration,
    LookupError: ObservedLookupError,
    IndexError: ObservedIndexError,
    GeneratorExit: ObservedGeneratorExit,
    KeyError: ObservedKeyError,
    AttributeError: ObservedAttributeError,
    RuntimeError: ObservedRuntimeError,
    NotImplementedError: ObservedNotImplementedError,
    TypeError: ObservedTypeError,
}
|
|
|
|
|
def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
    """Return the observed exception class registered for ``exc_type``.

    When ``exc_type`` has no entry in ``observed_exception_map``, a new
    ``ObservedException`` subclass named ``Observed<name>Error`` is created,
    stored in the map, and returned.
    """
    try:
        return observed_exception_map[exc_type]
    except KeyError:
        pass
    name = getattr(exc_type, "__name__", str(exc_type))
    observed_cls = type(f"Observed{name}Error", (ObservedException,), {})
    observed_exception_map[exc_type] = observed_cls
    return observed_cls
|
|
|
|
|
def raise_observed_exception(
    exc_type: type[Exception],
    tx: InstructionTranslatorBase,
    *,
    args: Optional[list[Any]] = None,
    kwargs: Optional[dict[str, Any]] = None,
) -> NoReturn:
    """Record an exception of ``exc_type`` on ``tx`` and raise its observed class.

    Args:
        exc_type: Exception type to construct through ``BuiltinVariable``.
        tx: Translator whose ``exn_vt_stack`` receives the constructed value.
        args: Positional arguments for the construction (default: none).
        kwargs: Keyword arguments for the construction (default: none).

    Raises:
        ObservedException: The class associated with ``exc_type``.
    """
    from .variables import BuiltinVariable

    exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {})
    tx.exn_vt_stack.set_current_exception(exception_vt)
    # Resolve through get_dynamo_observed_exception so that exception types
    # without a pre-registered entry get an observed class created for them
    # instead of failing with an unrelated KeyError on the map lookup.
    raise get_dynamo_observed_exception(exc_type)
|
|
|
|
|
def handle_observed_exception(tx: Any) -> None:
    """Clear the current exception tracked on ``tx.exn_vt_stack``.

    NOTE(review): presumably called once an observed exception has been dealt
    with so it does not linger on the translator -- confirm with callers.
    """
    tx.exn_vt_stack.clear_current_exception()
|
|
|
|
|
|
|
# Tuple of fake-tensor exception types.
# NOTE(review): presumably consumed in ``except``/``isinstance`` checks that
# decide whether to fall back rather than fail -- confirm against the users
# of this name.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)
|
|
|
|
|
def unimplemented_with_warning(
    e: Exception, code: types.CodeType, msg: str
) -> NoReturn:
    """Log details about ``e`` and then raise ``Unsupported`` via ``unimplemented``.

    The verbose report for ``e``/``code`` is emitted as a structured
    "dynamo_graph_break_reason" artifact and at debug level on the
    graph-breaks logger; ``msg`` is logged as a warning and then used for the
    raised error, which is chained from ``e``.
    """
    verbose_report = format_error_msg_verbose(e, code)

    def _artifact_metadata() -> dict[str, str]:
        return {
            "name": "dynamo_graph_break_reason",
            "encoding": "string",
        }

    def _artifact_payload() -> str:
        return verbose_report

    torch._logging.trace_structured(
        "artifact",
        metadata_fn=_artifact_metadata,
        payload_fn=_artifact_payload,
    )
    graph_breaks_log.debug("%s", verbose_report)
    log.warning(msg)
    unimplemented(msg, from_exc=e)
|
|
|
|
|
# Sentinel default for ``from_exc`` parameters: lets callers pass ``None``
# explicitly while still detecting "argument not supplied" by identity.
_NOTHING = object()
|
|
|
|
|
def unimplemented(
    msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None
) -> NoReturn:
    """Raise ``Unsupported(msg, case_name=case_name)``.

    Args:
        msg: Message for the raised error. An assertion fails when it equals
            the value of the ``BREAK`` environment variable.
        from_exc: When supplied (including ``None``), used as the ``from``
            clause of the raise.
        case_name: Forwarded to ``Unsupported``.
    """
    assert msg != os.environ.get("BREAK", False)
    error = Unsupported(msg, case_name=case_name)
    if from_exc is _NOTHING:
        raise error
    raise error from from_exc
|
|
|
|
|
def unimplemented_v2_with_warning(
    e: Exception,
    code: types.CodeType,
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
) -> NoReturn:
    """Log details about ``e`` and then raise via ``unimplemented_v2``.

    The verbose report for ``e``/``code`` is emitted as a structured
    "dynamo_graph_break_reason" artifact and at debug level on the
    graph-breaks logger. ``unimplemented_v2`` is then called with
    ``from_exc=e`` and ``log_warning=True``.
    """
    verbose_report = format_error_msg_verbose(e, code)

    def _artifact_metadata() -> dict[str, str]:
        return {
            "name": "dynamo_graph_break_reason",
            "encoding": "string",
        }

    def _artifact_payload() -> str:
        return verbose_report

    torch._logging.trace_structured(
        "artifact",
        metadata_fn=_artifact_metadata,
        payload_fn=_artifact_payload,
    )
    graph_breaks_log.debug("%s", verbose_report)
    unimplemented_v2(gb_type, context, explanation, hints, from_exc=e, log_warning=True)
|
|
|
|
|
def format_graph_break_message(
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
) -> str:
    """Assemble the multi-line graph break message.

    The result lists ``gb_type``, the explanation, one ``Hint:`` line per
    entry of ``hints``, a blank line, and the developer debug context.
    Multi-line inputs have their lines indented, with the leading whitespace
    of the first line removed.
    """

    def _as_continuation(text: str) -> str:
        # Indent every line, then strip so the first line sits right after its label.
        return textwrap.indent(text, " ").lstrip()

    explanation = _as_continuation(explanation)
    hint_lines = [" Hint: " + _as_continuation(hint) for hint in hints]
    hints_str = "\n".join(hint_lines)
    context = _as_continuation(context)

    return f"""\
{gb_type}
Explanation: {explanation}
{hints_str}

Developer debug context: {context}
"""
|
|
|
|
|
|
|
def unimplemented_v2(
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
    *,
    from_exc: Any = _NOTHING,
    log_warning: bool = False,
) -> NoReturn:
    """
    Called within dynamo to cause a graph break.
    Args:
        gb_type: Context-free graph break type. It should be a short string without any
            information specific to the tracing context (i.e. no dynamically-generated strings)
        context: Developer context for the graph break. It can contain tracing context/dynamic strings.
        explanation: User-facing context-dependent explanation for the graph break. Can be dynamic.
        hints: List of user-facing hints for the graph break.
        from_exc: When supplied (including ``None``), used as the ``from`` clause of the raise.
        log_warning: When true, the formatted message is also logged as a warning.
    """
    msg = format_graph_break_message(gb_type, context, explanation, hints)
    if log_warning:
        log.warning(msg)
    error = Unsupported(msg)
    if from_exc is _NOTHING:
        raise error
    raise error from from_exc
|
|
|
|
|
def warning(msg: str) -> None:
    """Count ``msg`` under ``counters["warnings"]``.

    After counting, an assertion fails when ``msg`` equals the value of the
    ``BREAK`` environment variable.
    """
    warning_counts = counters["warnings"]
    warning_counts[msg] += 1
    assert msg != os.environ.get("BREAK", False)
|
|
|
|
|
|
|
|
|
class KeyErrorMsg:
    """Wrapper whose ``repr`` is the plain ``str`` of the wrapped value.

    ``KeyError`` renders a single argument with ``repr``; wrapping the message
    in this class keeps it unquoted and unescaped.
    """

    def __init__(self, value: Any) -> None:
        self.value = value

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return str(self.value)
|
|
|
|
|
def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
    """Append debugging information to the first argument of ``exc`` in place.

    Args:
        exc: Exception to mutate. ``exc.innermost_user_frame_summary`` is set
            to the last frame of its real stack, or ``None`` without one.
        msg: Initial text of the appended suffix.
        export: Not used by this function.

    The suffix may include the formatted user stack, the replay record
    location, a hint about verbose mode, and the minifier script location.
    For ``KeyError`` the new message is wrapped in ``KeyErrorMsg``.
    """
    exc.innermost_user_frame_summary = None  # type: ignore[attr-defined]

    real_stack = get_real_stack(exc)
    if real_stack:
        exc.innermost_user_frame_summary = real_stack[-1]  # type: ignore[attr-defined]
        msg += f"\nfrom user code:\n {''.join(format_list(real_stack))}"

    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        # Adjacent literals with an explicit separating space: a backslash
        # continuation inside the string would make the message depend on the
        # indentation of the source line.
        msg += (
            f"\nLast frame execution written to {exc.record_filename}. "
            "To run only this frame while debugging, run "
            f"torch._dynamo.replay('{exc.record_filename}').\n"
        )

    if not config.verbose and hasattr(exc, "real_stack"):
        msg += (
            "\nSet TORCHDYNAMO_VERBOSE=1 for the internal stack trace "
            "(please do this especially if you're reporting a bug to PyTorch). "
            'For even more developer context, set TORCH_LOGS="+dynamo"\n'
        )

    inner = getattr(exc, "inner_exception", None)
    if hasattr(exc, "inner_exception") and hasattr(inner, "minifier_path"):
        if hasattr(inner, "buck_command"):
            msg += (
                f"\nMinifier script written to {inner.minifier_path}. Run "
                f"this buck command to find the smallest traced graph "
                f"which reproduces this error: {inner.buck_command}\n"
            )
        else:
            msg += (
                f"\nMinifier script written to {inner.minifier_path}. Run "
                "this script to find the smallest traced graph which reproduces this error.\n"
            )

    old_msg = str(exc.args[0]) if exc.args else ""
    new_first_arg: Any = old_msg + msg
    if isinstance(exc, KeyError):
        # KeyError shows repr() of its argument; the wrapper keeps newlines readable.
        new_first_arg = KeyErrorMsg(new_first_arg)
    exc.args = (new_first_arg, *exc.args[1:])
|
|
|
|
|
def get_exc_message(
    e: Exception, compile_id: CompileId
) -> tuple[Optional[str], Optional[int]]:
    """Return the user-code location recorded on ``e`` and tag it with ``compile_id``.

    Args:
        e: Exception to inspect; ``e.compile_id`` is set as a side effect.
        compile_id: Value stored on ``e.compile_id``.

    Returns:
        ``(filename, lineno)`` taken from ``e.innermost_user_frame_summary``,
        or ``(None, None)`` when that attribute is missing or ``None``.
    """
    # The attribute only exists after augment_exc_message() has processed the
    # exception, so a missing attribute is treated like an explicit None.
    summary = getattr(e, "innermost_user_frame_summary", None)
    filename = None
    lineno = None
    if summary is not None:
        filename = summary.filename
        lineno = summary.lineno
    e.compile_id = compile_id  # type: ignore[attr-defined]
    return filename, lineno
|
|
|
|
|
def get_stack_above_dynamo() -> StackSummary:
    """Return the current call stack after passing it through ``filter_stack``."""
    current_stack = extract_stack()
    return filter_stack(current_stack)
|
|
|
|
|
def get_real_stack(
    exc: Exception, frame: Optional[DynamoFrameType] = None
) -> Optional[StackSummary]:
    """Return the stack stored on ``exc.real_stack``, optionally with a prefix.

    Args:
        exc: Exception that may carry a ``real_stack`` attribute.
        frame: Only tested against ``None``; when given, the result is
            prefixed with ``get_stack_above_dynamo()``.

    Returns:
        ``None`` when ``exc`` has no ``real_stack`` (or it is ``None``),
        otherwise a new ``StackSummary``.
    """
    recorded = getattr(exc, "real_stack", None)
    if recorded is None:
        return None

    prefix = StackSummary() if frame is None else get_stack_above_dynamo()
    return StackSummary.from_list(prefix + recorded)
|
|
|
|
|
|
|
def filter_stack(stack: StackSummary) -> StackSummary:
    """Return the frames of ``stack`` that precede the first ``convert_frame`` entry.

    Frames without a filename are ignored. Iteration stops at the first frame
    whose filename contains ``"convert_frame"``. Frames whose filename
    contains ``"eval_frame"`` or whose source line contains
    ``"torch._dynamo.optimize("`` are left out.
    """
    user_stack = StackSummary()
    for entry in stack:
        path = entry.filename
        if path is None:
            continue
        if "convert_frame" in path:
            break
        if "eval_frame" in path:
            continue
        source_line = entry.line
        if source_line and "torch._dynamo.optimize(" in source_line:
            continue
        user_stack.append(entry)
    return user_stack
|
|
|
|
|
def remove_resume_prefix(name: str) -> Optional[str]:
    """Extract the wrapped function name from a resume-function name.

    Returns the captured name when ``name`` starts with
    ``<TORCH_DYNAMO_RESUME_IN_PREFIX>_<function>_at_<digits>``, else ``None``.
    """
    from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX

    pattern = f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_(\\w+)_at_\\d+"
    found = re.match(pattern, name)
    return found.group(1) if found else None
|
|
|
|
|
def collapse_resume_frames(stack: StackSummary) -> StackSummary:
    """
    When we graph break, we create a resume function and make a regular Python call
    to it, which gets intercepted by Dynamo. This behavior is normally shown in the
    traceback, which can be confusing to a user. So we can filter out resume frames
    for better traceback clarity.

    Example:
    File "..." line 3, in f
    <line 3>
    File "..." line 5, in torch_dynamo_resume_in_f_at_80
    <line 5>
    File "..." line 10, in torch_dynamo_resume_in_f_at_120
    <line 10>

    becomes
    File "..." line 10, in f
    <line 10>

    Frames without a filename are dropped. A collapsed resume frame is renamed
    in place to the name of the function it resumes.
    """
    collapsed = StackSummary()
    for entry in stack:
        if entry.filename is None:
            continue
        resumed_name = remove_resume_prefix(entry.name)
        if collapsed and resumed_name and collapsed[-1].name == resumed_name:
            # This resume frame continues the function already on top of the
            # stack, so it takes that entry's place under the original name.
            entry.name = resumed_name
            collapsed[-1] = entry
        else:
            collapsed.append(entry)
    return collapsed
|
|
|
|
|
def format_error_msg_verbose(
    exc: Exception,
    code: types.CodeType,
    record_filename: Optional[str] = None,
    frame: Optional[DynamoFrameType] = None,
) -> str:
    """Build the verbose "WON'T CONVERT" report for ``code``.

    The report contains a header naming ``code``, the output of
    ``format_exc()``, and, when ``get_real_stack(exc, frame)`` yields a
    stack, that stack formatted under a separate banner.
    ``record_filename`` is not used here.
    """
    bar = "=" * 10
    parts = [
        f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n",
        f"{bar} TorchDynamo Stack Trace {bar}\n",
        format_exc(),
    ]
    real_stack = get_real_stack(exc, frame)
    if real_stack is not None:
        parts.append(
            f"\n{bar} The above exception occurred while processing the following code {bar}\n\n"
        )
        parts.append("".join(format_list(real_stack)))
        parts.append("\n")
        parts.append(bar)
    return "".join(parts)
|
|
|
|
|
def format_error_msg(
    exc: Exception,
    code: types.CodeType,
    record_filename: Optional[str] = None,
    frame: Optional[DynamoFrameType] = None,
) -> str:
    """Build the "WON'T CONVERT" message for ``code``.

    With ``config.verbose`` set this defers to ``format_error_msg_verbose``;
    otherwise it returns a short header naming ``code`` followed by the
    output of ``format_exc()``.
    """
    if config.verbose:
        return format_error_msg_verbose(exc, code, record_filename, frame)
    # Adjacent literals with an explicit space between filename and "line",
    # matching the verbose header; a backslash continuation inside the string
    # would make the spacing depend on the indentation of the source line.
    return (
        f"WON'T CONVERT {code.co_name} {code.co_filename} "
        f"line {code.co_firstlineno} \ndue to: \n{format_exc()}"
    )
|
|