|
|
|
from types import TracebackType |
|
from typing import Optional |
|
import tempfile |
|
import traceback |
|
import contextlib |
|
import inspect |
|
import os.path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def report_compile_source_on_error():
    """
    Context manager that re-raises exceptions with source-backed frames.

    When an ``Exception`` escapes the ``with`` body, its traceback chain is
    walked. Every entry whose code object's filename is ``"<string>"`` and
    whose frame globals define ``__compile_source__`` is replaced by a
    synthetic traceback entry: the source text is written to a temporary
    ``.py`` file and a stand-in frame is created whose code object names that
    file. All other entries are kept unchanged. The same exception object is
    then re-raised with the rebuilt chain.

    The temporary files are created with ``delete=False`` and are not removed
    by this function.
    """
    try:
        yield
    except Exception as exc:
        entries = []
        cursor = exc.__traceback__
        while cursor is not None:
            frame = cursor.tb_frame
            filename = frame.f_code.co_filename
            source = frame.f_globals.get("__compile_source__")

            if filename != "<string>" or source is None:
                # Nothing to rewrite for this entry; keep it untouched.
                entries.append(cursor)
            else:
                # Persist the source so the fake code object can name a real file.
                with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as tmp:
                    tmp.write(source)

                # Frames cannot be instantiated directly, so evaluate a tiny
                # expression that hands back its own frame. It is evaluated
                # right here so the new frame's caller is this generator.
                code = compile('__inspect_currentframe()', tmp.name, 'eval')
                code = code.replace(co_name=frame.f_code.co_name)
                if hasattr(frame.f_code, 'co_linetable'):
                    # Only the line table and first line number are carried over
                    # from the real code object.
                    code = code.replace(
                        co_linetable=frame.f_code.co_linetable,
                        co_firstlineno=frame.f_code.co_firstlineno,
                    )

                eval_locals = dict(frame.f_locals)
                eval_locals['__inspect_currentframe'] = inspect.currentframe
                fake_frame = eval(code, frame.f_globals, eval_locals)

                entries.append(
                    TracebackType(None, fake_frame, cursor.tb_lasti, cursor.tb_lineno)
                )

            cursor = cursor.tb_next

        # Relink the entries back-to-front into a fresh singly linked chain.
        head = None
        while entries:
            node = entries.pop()
            node.tb_next = head
            head = node

        raise exc.with_traceback(head)
|
|
|
def shorten_filename(fn, *, base=None):
    """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    Args:
        fn: Path of the source file to shorten.
        base: Directory whose shared leading components are stripped from
            ``fn``. Defaults to the parent of the directory containing this
            module.

    Returns:
        ``fn`` with the path components it has in common with ``base``
        removed. ``fn`` is returned unchanged when the two paths have no
        common component or cannot be compared.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))

    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        # commonpath refuses e.g. a mix of absolute and relative paths.
        return fn

    if not prefix:
        # Relative paths with nothing in common: there is nothing to strip.
        return fn

    cut = len(prefix)
    if not prefix.endswith(os.sep):
        # Also drop the separator that follows the prefix. A root prefix
        # ("/" or "C:\\") already ends with it, so skipping one more
        # character would eat the first letter of the next component.
        cut += 1
    return fn[cut:]
|
|
|
def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary compactly as ``<short path>:<lineno> in <name>``.

    The path is passed through ``shorten_filename`` (with ``base`` forwarded),
    so the absolute location is not printed. When ``line`` is true the frame's
    source line is put in front, followed by `` # ``.

    The result is meant to fit on a single line.
    """
    source_prefix = f"{frame.line} # " if line else ""
    short_path = shorten_filename(frame.filename, base=base)
    location = f"{short_path}:{frame.lineno} in {frame.name}"
    return source_prefix + location
|
|
|
def format_traceback_short(tb):
    """Return the compact ``format_frame`` rendering of the inner-most frame of a TracebackType."""
    frames = traceback.extract_tb(tb)
    innermost = frames[-1]
    return format_frame(innermost)
|
|
|
class CapturedTraceback:
    """
    Handle for a traceback captured with ``torch._C._profiler.gather_traceback``.

    The raw capture is kept in ``tb`` and is only symbolized when a summary
    or formatted output is requested. ``skip`` is the number of leading
    entries of the symbolized traceback to drop.
    """

    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        # tb: raw captured traceback (or None once released).
        # skip: number of leading symbolized entries to drop.
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Drop the reference to the captured traceback."""
        self.tb = None

    def summary(self):
        """
        Symbolize the capture and return it as a ``traceback.StackSummary``.

        Returns an empty ``StackSummary`` when the capture has been released
        (``tb`` is None).
        """
        import torch._C._profiler

        if self.tb is None:
            # Nothing left to symbolize; the result does not indicate that
            # the traceback was elided.
            return traceback.StackSummary()

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # Pickle state in (dict_state, slots_state) form. The captured
        # traceback itself is never serialized; only ``skip`` survives.
        return (None, {
            'tb': None,
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback wrapping the object
        produced by torch._C._profiler.gather_traceback; format it with
        CapturedTraceback.format or CapturedTraceback.format_all.

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.

        ``skip`` must be 0 when ``script`` or ``cpp`` is set.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Python-only captures skip one extra entry (presumably the
            # extract() frame itself -- confirm against gather_traceback).
            # With script/cpp frames no entries are skipped.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats this CapturedTraceback into a list of strings equivalent to the
        output of traceback.format_list.  Note that if you have several
        CapturedTraceback objects with C++ traces, it is better to use the batch
        API CapturedTraceback.format_all to amortize the cost of symbolization.
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.

        All live captures are symbolized with a single symbolize_tracebacks
        call; entries whose capture has been released (``tb`` is None) yield
        an empty list.
        """
        import torch._C._profiler

        # Fill in released captures right away; remember which slots still
        # need symbolization.
        rs: list[Optional[list[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        # One batched symbolization; its i-th result belongs to the i-th
        # input. Use it directly instead of re-symbolizing each traceback
        # individually through summary().
        symbolized = torch._C._profiler.symbolize_tracebacks(
            [tbs[i].tb for i in delayed_idxs]
        )
        for i, stb in zip(delayed_idxs, symbolized):
            rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs
|
|
|
|
|
def _extract_symbolized_tb(tb, skip):
    """
    Convert a symbolized traceback (as returned by symbolize_tracebacks) into
    a ``traceback.StackSummary``.

    The first ``skip`` entries of ``tb`` are dropped and the rest are emitted
    in reverse order. Each entry is read through its ``'filename'``,
    ``'line'`` and ``'name'`` keys.
    """
    summary = traceback.StackSummary()
    summary.extend(
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in reversed(tb[skip:])
    )
    return summary
|
|