File size: 10,274 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 |
# mypy: allow-untyped-defs
from types import TracebackType
from typing import Optional
import tempfile
import traceback
import contextlib
import inspect
import os.path
# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.
#
# The constraints:
#
# - We don't have control over the user exception printer (in particular,
# we cannot assume the linecache trick will work, cf.
# https://stackoverflow.com/q/50515651/23845 )
#
# - We don't want to create temporary files every time we compile()
# some code; file creation should happen lazily, only at exception
# time. Arguably, you *should* be willing to write out your
# generated Python code to the file system, but in some situations
# (esp. library code) it would violate user expectations to write
# to the file system, so we try to avoid it. In particular, we want
# to keep the files around so users can open the files mentioned
# in the trace; files that never show up in a trace would just
# clog up the filesystem.
#
# If this is not a constraint for you, there is a substantially simpler
# way to implement the functionality in this file: instead of using
# eval/exec directly, just always write a Python file to the filesystem
# and compile that.
#
# - You have control over a context where the compiled code will get
# executed, so that we can interpose while the stack is unwinding
# (otherwise, we have no way to interpose on the exception printing
# process.)
#
# There are two things you have to do to make use of the utilities here:
#
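# - When you compile your source code, you must save its string source
# in its f_globals under the magic name "__compile_source__", and the
# code must be compiled with the filename "<string>" (what exec()/eval()
# use for string input); frames with any other filename are left untouched.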
#
# - Before running the compiled code, enter the
# report_compile_source_on_error() context manager.
@contextlib.contextmanager
def report_compile_source_on_error():
    """
    Re-raise exceptions from the managed block so that frames of
    exec()/eval()'d code point at real source files.

    For every traceback entry whose code has ``co_filename == "<string>"``
    and whose frame globals define ``__compile_source__``, that source is
    written to a temporary ``.py`` file and the entry is replaced by a
    synthetic one whose code object names that file.  This lets the default
    exception printer show the offending source lines.  All other entries are
    kept as-is.

    Temporary files are deliberately not deleted so the user can open them.
    Each distinct ``__compile_source__`` string is written at most once per
    handled exception, however many frames share it.

    Only ``Exception`` subclasses are handled.  The same exception object is
    re-raised with the rewritten traceback; it is not wrapped.
    """
    try:
        yield
    except Exception as exc:
        tb = exc.__traceback__

        # Walk the traceback, looking for frames that have
        # source attached
        stack = []
        # Maps each distinct __compile_source__ string to the temporary file it
        # was written to, so frames sharing one source also share one file.
        written_sources = {}
        while tb is not None:
            filename = tb.tb_frame.f_code.co_filename
            source = tb.tb_frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                # What black magic are we doing here?  Intuitively, what
                # we would like to do is overwrite the co_filename on any
                # frames that were generated from exec/eval so that they
                # point to a temporary file that has the actual line
                # information, so Python's default error printer can print
                # useful line information on it.
                #
                # Writing out the temporary file is easy.  But overwriting
                # co_filename is not!  You can't modify the code object
                # associated with a frame.  You can, however, reconstruct
                # a traceback with entirely new frames from scratch, so that's
                # what we do.  But there's another problem, which is how to
                # make the frame?
                #
                # The black magic is we make a frankenstein frame and code
                # object which resembles the original frame/code enough so
                # that it will print properly under traceback and the default
                # error printer, but IT IS NOT THE ORIGINAL FRAME (you
                # couldn't, e.g., execute its code with different variables
                # and expect it to work.)
                temp_path = written_sources.get(source)
                if temp_path is None:
                    # Don't delete the temporary file so the user can inspect it
                    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
                        f.write(source)
                    temp_path = written_sources[source] = f.name

                # Create a frame.  Python doesn't let you construct
                # FrameType directly, so just make one with compile
                frame = tb.tb_frame
                code = compile('__inspect_currentframe()', temp_path, 'eval')
                code = code.replace(co_name=frame.f_code.co_name)
                # Only on interpreters whose code objects expose co_linetable
                # (newer CPython; 3.11+ relies on it for co_positions()).
                if hasattr(frame.f_code, 'co_linetable'):
                    # We can't copy ALL of the metadata over, because you
                    # can cause Python to segfault this way.  What exactly
                    # do we need?  We need enough information for
                    # traceback to be able to print the exception
                    # correctly.  Code reading Lib/traceback.py reveals
                    # that traceback calls code.co_positions() in order to
                    # get the augmented line/col numbers.  Objects/codeobject.c,
                    # specifically _PyCode_InitAddressRange, reveals that
                    # this iterator is initialized from co_linetable and
                    # co_firstlineno.  So copy these we must!
                    code = code.replace(  # type: ignore[call-arg]
                        co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
                        co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
                    )
                # Evaluating the code calls inspect.currentframe() from inside
                # it, which returns the new frame itself.  That frame's f_code
                # names the temp file, its globals are the original frame's,
                # and its locals are a copy of the original frame's locals.
                fake_frame = eval(
                    code,
                    frame.f_globals,
                    {
                        **frame.f_locals,
                        '__inspect_currentframe': inspect.currentframe
                    }
                )
                # Reuse the original instruction offset and line number.  The
                # temp file holds the source verbatim, so the line matches.
                fake_tb = TracebackType(
                    None, fake_frame, tb.tb_lasti, tb.tb_lineno
                )
                stack.append(fake_tb)
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list.  Note that this also assigns tb_next
        # in place on the original (non-replaced) traceback entries.
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: B904
def shorten_filename(fn, *, base=None):
    """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    The longest common leading path of ``fn`` and ``base`` is removed,
    together with any path separators directly following it.

    Args:
        fn: the path to shorten.
        base: the directory to shorten against.  Defaults to
            ``os.path.dirname(os.path.dirname(__file__))``, i.e. the parent
            of the directory containing this module.

    Returns:
        ``fn`` with the common prefix stripped, or ``fn`` unchanged when
        ``os.path.commonpath`` rejects the pair (it raises ValueError, e.g.
        for an absolute ``base`` and a relative name such as ``"<string>"``).
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        return fn
    # Strip the separator(s) after the prefix explicitly instead of assuming
    # exactly one.  When the only shared component is the filesystem root,
    # the prefix ("/") already ends in a separator, and a blind "+ 1" would
    # chop the first character off the next path component.
    separators = os.sep + (os.altsep or "")
    return fn[len(prefix):].lstrip(separators)
def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary compactly as ``<file>:<lineno> in <name>``.

    The filename is passed through shorten_filename (with ``base``) so the
    result fits on a single line without a full absolute path.  When ``line``
    is truthy, the frame's source text is prepended, followed by ``" # "``.
    """
    short_name = shorten_filename(frame.filename, base=base)
    location = f"{short_name}:{frame.lineno} in {frame.name}"
    if not line:
        return location
    return f"{frame.line} # {location}"
def format_traceback_short(tb):
    """
    Summarize a TracebackType on one line, using only its innermost frame.

    Raises IndexError when the traceback yields no frames (e.g. ``None``).
    """
    frames = traceback.extract_tb(tb)
    innermost = frames[-1]
    return format_frame(innermost)
class CapturedTraceback:
    """
    Handle around a traceback captured by
    ``torch._C._profiler.gather_traceback``.

    Symbolization (turning the raw capture into filename/line/name entries)
    is deferred until summary(), format() or format_all() is called.
    """

    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        # Raw captured traceback, or None once elided (cleanup()) or after
        # unpickling (__getstate__ drops it).
        self.tb = tb
        # Number of entries removed from the start of the symbolized frame
        # list before it is reversed (see _extract_symbolized_tb).
        self.skip = skip

    def cleanup(self):
        """Drop the raw traceback; later summaries/formats come back empty."""
        self.tb = None

    def summary(self):
        """
        Symbolize this traceback and return it as a traceback.StackSummary.

        An elided traceback (``tb is None``) yields an empty StackSummary
        without touching torch.
        """
        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        import torch._C._profiler

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # (dict_state, slots_state) form, since the class uses __slots__.
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback wrapping the raw
        capture from torch._C._profiler.gather_traceback; format it with
        CapturedTraceback.format(), or CapturedTraceback.format_all() for many
        tracebacks at once.

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.

        ``skip`` drops that many additional entries when formatting.  Combining
        a nonzero ``skip`` with script/cpp is not implemented yet and fails an
        assertion.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames.  If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats this CapturedTraceback into a list of strings equivalent to the
        output of traceback.format_list.  Note that if you pass it a
        CapturedTraceback with C++ traces, it is better not to use this method
        and instead use the batch formatting API CapturedTraceback.format_all
        to amortize the cost of symbolization.
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.

        ``tbs`` must support indexing.  All non-elided tracebacks are
        symbolized with a single symbolize_tracebacks call.  Elided ones
        (``tb is None``) format to an empty list, and torch is only imported
        when there is something to symbolize.
        """
        rs: list[Optional[list[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                # Elided traceback: nothing to symbolize.
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        if delayed_idxs:
            import torch._C._profiler

            symbolized = torch._C._profiler.symbolize_tracebacks(
                [tbs[i].tb for i in delayed_idxs]
            )
            # Reuse the batch result instead of re-symbolizing every traceback
            # one at a time through summary().  This assumes results come back
            # in input order, as summary()'s use of [0] implies.
            for i, sym in zip(delayed_idxs, symbolized):
                rs[i] = traceback.format_list(
                    _extract_symbolized_tb(sym, tbs[i].skip)
                )

        return rs
def _extract_symbolized_tb(tb, skip):
    """
    Build a traceback.StackSummary from one symbolized traceback.

    ``tb`` is a single element of symbolize_tracebacks' result: a sliceable
    sequence of mappings with 'filename', 'line' (used as the line number)
    and 'name' keys.  The first ``skip`` entries are discarded and the
    remaining ones are emitted in reverse order.
    """
    kept = tb[skip:]
    frames = (
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in reversed(kept)
    )
    return traceback.StackSummary(frames)
|