File size: 7,325 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
# mypy: allow-untyped-defs
import copy
import traceback
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union
from ._compatibility import compatibility
from .graph import Graph
from .node import Node
# Public API of this module.
__all__ = [
    "preserve_node_meta",
    "has_preserved_node_meta",
    "set_stack_trace",
    "set_grad_fn_seq_nr",
    "reset_grad_fn_seq_nr",
    "format_stack",
    "set_current_meta",
    "get_current_meta",
    "NodeSource",
    "NodeSourceAction",
    "get_graph_provenance_json",
]
# Module-global metadata written by the set_* helpers below while preservation
# is enabled. preserve_node_meta() snapshots/restores it and set_current_meta()
# temporarily rebinds it; it is exposed via get_current_meta(), presumably so
# node-creation code elsewhere can copy it onto new nodes (not shown here).
current_meta: dict[str, Any] = {}
# Global switch flipped to True inside preserve_node_meta(); while False the
# set_* helpers are no-ops and format_stack() falls back to the live stack.
should_preserve_node_meta: bool = False
@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):
    """Action recorded in a NodeSource describing how a pass derived a node from its source."""

    CREATE = "create"
    REPLACE = "replace"
@compatibility(is_backward_compatible=False)
class NodeSource:
    """
    NodeSource is a data structure that contains the provenance information of a node.
    If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b).

    A NodeSource stores:
      * the name of the pass that derived the new node,
      * the actions that pass recorded,
      * a snapshot of the source node's identity (its name, its stringified
        target, and the ``id()`` of its owning graph), and
      * a deep copy of the source node's own ``meta["from_node"]`` list.
    Together these let the full provenance chain be walked recursively.
    """

    class NodeInfo:
        """Snapshot of a source node's name, target string, and graph identity."""

        def __init__(self, name: str, target: str, graph_id: int):
            self.name = name
            self.target = target
            self.graph_id = graph_id

    pass_name: str
    action: list["NodeSourceAction"]
    from_node: list["NodeSource"]
    node_info: Optional["NodeInfo"]

    def __init__(
        self,
        node: Optional[Node],
        pass_name: str = "",
        action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
    ):
        """
        Args:
            node: The node this provenance record points back to, or None to
                build an empty record (no node info and no history).
            pass_name: Name of the pass that created the derived node.
            action: A single NodeSourceAction, a list of them, or None for no
                actions. A list argument is copied, so the caller may keep
                mutating it without affecting this record.
        """
        self.pass_name = pass_name

        if action is None:
            actions = []
        elif isinstance(action, list):
            # Copy instead of aliasing. Otherwise later edits to the caller's
            # list would silently rewrite this record's actions.
            actions = list(action)
        else:
            actions = [action]
        for a in actions:
            assert isinstance(a, NodeSourceAction)
        self.action = actions

        if node is not None:
            self.node_info = self.NodeInfo(
                name=node.name, target=str(node.target), graph_id=id(node.graph)
            )
            # Deep copy so this record's history stays independent of any
            # later changes to the source node's meta.
            self.from_node = (
                copy.deepcopy(node.meta["from_node"])
                if "from_node" in node.meta
                else []
            )
        else:
            self.node_info = None
            self.from_node = []

    @property
    def name(self) -> str:
        """Source node name, or "" when this record has no node."""
        return self.node_info.name if self.node_info else ""

    @property
    def target(self) -> str:
        """Stringified source node target, or "" when this record has no node."""
        return self.node_info.target if self.node_info else ""

    @property
    def graph_id(self) -> int:
        """``id()`` of the source node's graph, or -1 when this record has no node."""
        return self.node_info.graph_id if self.node_info else -1

    def __repr__(self):
        return self.print_readable()

    def _get_action_string(self):
        """Join the lower-cased action names with "+", e.g. "create+replace"."""
        return "+".join([a.name.lower() for a in self.action])

    def print_readable(self, indent=0):
        """Return an indented, multi-line rendering of this record and its history.

        Each nested ``from_node`` level is indented by four more spaces.
        Levels with ``indent > 9`` are omitted, which bounds the output for
        very long provenance chains.
        """
        if indent > 9:
            return ""
        action_string = self._get_action_string()
        parts = [
            " " * indent * 4
            + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n"
        ]
        parts.extend(item.print_readable(indent + 1) for item in self.from_node)
        return "".join(parts)

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of this record, recursing into ``from_node``."""
        action_string = self._get_action_string()
        return {
            "name": self.name,
            "target": self.target,
            "graph_id": self.graph_id,
            "pass_name": self.pass_name,
            "action": action_string,
            "from_node": [node.to_dict() for node in self.from_node],
        }
@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta(enable=True):
    """
    Context manager that turns on node-metadata preservation for its body.

    While active, ``should_preserve_node_meta`` is True, so the set_* helpers
    in this module record into ``current_meta``. On exit, even if the body
    raises, the flag is restored and ``current_meta`` is rebound to a
    snapshot taken on entry. With ``enable=False`` this is a no-op.
    """
    global should_preserve_node_meta
    global current_meta

    if not enable:
        # Disabled: behave as a plain pass-through context.
        yield
        return

    previous_flag = should_preserve_node_meta
    # A shallow snapshot suffices: entries of current_meta are replaced,
    # never mutated in place.
    previous_meta = current_meta.copy()
    try:
        should_preserve_node_meta = True
        yield
    finally:
        should_preserve_node_meta = previous_flag
        current_meta = previous_meta
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: list[str]):
    """Store ``stack``, joined into one string, as ``current_meta["stack_trace"]``.

    Does nothing unless metadata preservation is enabled and ``stack`` is
    non-empty.
    """
    if not should_preserve_node_meta or not stack:
        return
    current_meta["stack_trace"] = "".join(stack)
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    """Push ``seq_nr`` onto ``grad_fn_seq_nr`` and bump the ``in_grad_fn`` depth.

    Paired with reset_grad_fn_seq_nr(). Does nothing unless metadata
    preservation is enabled.
    """
    if not should_preserve_node_meta:
        return
    # The seq_nr is captured by eager mode in the grad_fn during forward.
    # A new list is built (not appended in place) so an earlier shallow
    # snapshot of current_meta keeps its own list.
    history = current_meta.get("grad_fn_seq_nr", [])
    current_meta["grad_fn_seq_nr"] = history + [seq_nr]
    depth = current_meta.get("in_grad_fn", 0)
    current_meta["in_grad_fn"] = depth + 1
@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    """Undo the most recent set_grad_fn_seq_nr() call.

    Pops the last seq_nr and decrements ``in_grad_fn``. When leaving the
    outermost level, both keys are removed from ``current_meta``. Does
    nothing unless metadata preservation is enabled.
    """
    # NB: reset state properly, this would be helpful towards supporting
    # reentrant autograd if we actually wanted to do that.
    if not should_preserve_node_meta:
        return
    level = current_meta.get("in_grad_fn", 0)
    assert level > 0
    if level != 1:
        current_meta["in_grad_fn"] = level - 1
        current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
        return
    # Outermost level: drop the bookkeeping keys entirely.
    del current_meta["in_grad_fn"]
    del current_meta["grad_fn_seq_nr"]
@compatibility(is_backward_compatible=False)
def format_stack() -> list[str]:
    """Return the stack trace to attach to a node, as a list of strings.

    With preservation enabled, this is the single recorded
    ``current_meta["stack_trace"]`` ("" if unset). Otherwise it is the live
    formatted stack of the caller.
    """
    if not should_preserve_node_meta:
        # fallback to traceback.format_stack(); the final frame (this
        # function itself) is dropped.
        frames = traceback.extract_stack()[:-1]
        return traceback.format_list(frames)
    return [current_meta.get("stack_trace", "")]
@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    """Return whether node-metadata preservation is currently enabled.

    This is True inside an enabled ``preserve_node_meta()`` context.
    """
    return should_preserve_node_meta
@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node, pass_name=""):
    """
    Temporarily make ``current_meta`` a copy of ``node.meta`` with fresh provenance.

    Inside the context, ``current_meta["from_node"]`` is a one-element list
    holding ``NodeSource(node, pass_name, CREATE)``. The previous
    ``current_meta`` is restored on exit. The context is a pass-through when
    preservation is disabled or ``node.meta`` is empty.
    """
    global current_meta

    if not (should_preserve_node_meta and node.meta):
        yield
        return

    outer_meta = current_meta
    try:
        current_meta = node.meta.copy()
        # Overwrite rather than append: current_meta will be assigned to the
        # new node, and NodeSource(node, ...) already carries the previous
        # "from_node" history inside it.
        current_meta["from_node"] = [
            NodeSource(node, pass_name, NodeSourceAction.CREATE)
        ]
        yield
    finally:
        current_meta = outer_meta
@compatibility(is_backward_compatible=False)
def get_current_meta() -> dict[str, Any]:
    """Return the module-level ``current_meta`` dict itself (not a copy).

    Mutating the result mutates the shared state. preserve_node_meta() and
    set_current_meta() rebind the global, so a reference obtained earlier
    may no longer be the live dict after those contexts exit.
    """
    return current_meta
@compatibility(is_backward_compatible=False)
def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
    """
    Given an fx.Graph, return a json that contains the provenance information of each node.

    Only ``call_function`` nodes are included. Keys are node names. Each value
    is the list of ``NodeSource.to_dict()`` results for the node's
    ``meta["from_node"]``, or an empty list when the node has none.
    """
    return {
        node.name: [source.to_dict() for source in node.meta.get("from_node", [])]
        for node in graph.nodes
        if node.op == "call_function"
    }
|