File size: 1,736 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# mypy: allow-untyped-defs
import sys
from typing import Optional
import torch
from torch._logging import LazyString
def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
    """
    Build a LazyString that renders the readable code of ``gm`` on demand.

    ``name`` (followed by ``maybe_id`` when that is not None) becomes the
    banner of the rendered text. Remaining keyword arguments are forwarded
    to ``gm.print_readable``; ``print_output`` is set to False unless the
    caller supplied it, and a caller-supplied ``colored`` is forced to
    False when ``sys.stdout`` is not a TTY or has no ``isatty`` attribute.
    """
    kwargs.setdefault("print_output", False)

    if "colored" in kwargs:
        # sys.stdout may be an object without isatty(); treat that the same
        # as a non-interactive stream.
        try:
            not_a_tty = not sys.stdout.isatty()
        except AttributeError:
            not_a_tty = True
        if not_a_tty:
            kwargs["colored"] = False

    def render():
        # Everything here runs only when the LazyString is evaluated.
        title = name if maybe_id is None else f"{name} {maybe_id}"
        return _format_graph_code(
            f"===== {title} =====\n",
            gm.forward.__code__.co_filename,
            gm.print_readable(**kwargs),
        )

    return LazyString(render)
def _format_graph_code(name, filename, graph_str):
"""
Returns a string that formats the graph code.
"""
return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[dict]:
    """
    Find the ``nn_module_stack`` metadata of the earliest qualifying node.

    Nodes are scanned in the order ``graph.nodes`` yields them. The first
    node whose ``op`` is ``"call_function"`` and whose ``meta`` contains an
    ``"nn_module_stack"`` entry supplies the return value. Returns None
    when no node qualifies.
    """
    matching_stacks = (
        candidate.meta["nn_module_stack"]
        for candidate in graph.nodes
        if candidate.op == "call_function" and "nn_module_stack" in candidate.meta
    )
    # The generator is lazy, so scanning stops at the first match.
    return next(matching_stacks, None)
def get_node_context(node, num_nodes=2) -> str:
    """
    Format ``node`` and the nodes preceding it as newline-separated text.

    Starting at ``node`` and following ``prev`` links, up to ``num_nodes``
    nodes are formatted with ``format_node()``. The walk stops early once a
    node whose ``op`` is ``"root"`` has been formatted. Lines appear in
    graph order, so ``node`` itself is the last line.
    """
    walked_back = []
    current = node
    for _step in range(num_nodes):
        walked_back.append(current.format_node())
        at_root = current.op == "root"
        if at_root:
            break
        current = current.prev
    # Entries were gathered walking backwards; emit them oldest-first.
    return "\n".join(reversed(walked_back))
|