|
import copy |
|
import inspect |
|
import math |
|
import re |
|
import warnings |
|
from collections import OrderedDict |
|
from copy import deepcopy |
|
from itertools import chain |
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torchvision |
|
from torch import fx, nn |
|
from torch.fx.graph_module import _CodeOnlyModule, _copy_attr, _USER_PRESERVED_ATTRIBUTES_KEY |
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["create_feature_extractor", "get_graph_node_names"]
|
|
|
|
|
class LeafModuleAwareTracer(fx.Tracer):
    """
    An ``fx.Tracer`` that accepts a user-supplied collection of leaf module
    types, i.e. module types that must not be traced through. Instances of
    those types are reported as leaves by :meth:`is_leaf_module`.
    """

    def __init__(self, *args, **kwargs):
        # ``leaf_modules`` is consumed here and therefore never forwarded to
        # ``fx.Tracer.__init__``; it defaults to an empty collection.
        self.leaf_modules = kwargs.pop("leaf_modules", {})
        super().__init__(*args, **kwargs)

    def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
        """
        Return ``True`` when ``m`` is an instance of one of the configured
        leaf module types; otherwise defer to the parent tracer's decision.
        """
        leaf_types = tuple(self.leaf_modules)
        return True if isinstance(m, leaf_types) else super().is_leaf_module(m, module_qualname)
|
|
|
|
|
class NodePathTracer(LeafModuleAwareTracer):
    """
    NodePathTracer is an FX tracer that, for each operation, also records the
    name of the Node from which the operation originated. A node name here is
    a `.` separated path walking the hierarchy from top level module down to
    leaf operation or leaf module. The name of the top level module is not
    included as part of the node name. For example, if we trace a module whose
    forward method applies a ReLU module, the name for that node will simply
    be 'relu'.

    Some notes on the specifics:
        - Nodes are recorded to `self.node_to_qualname` which is a dictionary
          mapping a given Node object to its node name.
        - Nodes are recorded in the order which they are executed during
          tracing.
        - When a duplicate node name is encountered, a suffix of the form
          _{int} is added. The counter starts from 1.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Qualified name of the module whose forward is currently being
        # traced; "" while at the top level. Maintained by `call_module`.
        self.current_module_qualname = ""
        # Maps every recorded fx Node to its node name (see class docstring).
        # Insertion ordered, so iterating yields names in recording order.
        self.node_to_qualname = OrderedDict()

    def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
        """
        Override of `fx.Tracer.call_module`
        This override:
        1) Stores away the qualified name of the caller for restoration later
        2) Adds the qualified name of the caller to
           `current_module_qualname` for retrieval by `create_proxy`
        3) Once a leaf module is reached, calls `create_proxy`
        4) Restores the caller's qualified name into current_module_qualname
        """
        old_qualname = self.current_module_qualname
        try:
            module_qualname = self.path_of_module(m)
            self.current_module_qualname = module_qualname
            if not self.is_leaf_module(m, module_qualname):
                # Not a leaf: keep tracing through the module's forward
                out = forward(*args, **kwargs)
                return out
            return self.create_proxy("call_module", module_qualname, args, kwargs)
        finally:
            # Restore the caller's name even if tracing the module raised
            self.current_module_qualname = old_qualname

    def create_proxy(
        self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None, *_
    ) -> fx.proxy.Proxy:
        """
        Override of `Tracer.create_proxy`. This override intercepts the recording
        of every operation and stores away the current traced module's qualified
        name in `node_to_qualname`
        """
        # Any extra positional arguments (captured by `*_`) are not forwarded.
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
        self.node_to_qualname[proxy.node] = self._get_node_qualname(self.current_module_qualname, proxy.node)
        return proxy

    def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str:
        """
        Build the node name to record for ``node``.

        Args:
            module_qualname (str): qualified name of the module being traced
                when ``node`` was created ("" at the top level).
            node (fx.node.Node): the node that was just created.

        Returns:
            str: ``module_qualname`` for ``call_module`` nodes, otherwise
            ``module_qualname`` joined with ``str(node)`` by a ``.``; any
            trailing ``_{int}`` is stripped and, if the resulting name (with
            or without an ``_{int}`` suffix) has already been recorded, a new
            ``_{int}`` suffix one past the most recent match is appended.
        """
        node_qualname = module_qualname

        if node.op != "call_module":
            # For anything that isn't a module call, the module path alone
            # doesn't identify the operation, so append the node's own name.
            if len(node_qualname) > 0:
                # Only add the separator when below the top-level module
                node_qualname += "."
            node_qualname += str(node)

        # Drop any trailing `_{int}` already present on the name, so that the
        # counter appended below is computed only against names recorded in
        # `self.node_to_qualname`.
        if re.match(r".+_[0-9]+$", node_qualname) is not None:
            node_qualname = node_qualname.rsplit("_", 1)[0]

        # The name is escaped so that regex metacharacters in it (notably the
        # `.` path separator) only match themselves. Unescaped, `a.b` would
        # also match an unrelated name such as `a_b`, and the integer parsing
        # below would then fail.
        existing_pattern = re.compile(rf"{re.escape(node_qualname)}(_[0-9]+)?$")

        # Walk previously recorded names from newest to oldest, looking for
        # `{node_qualname}` or `{node_qualname}_{int}`.
        for existing_qualname in reversed(self.node_to_qualname.values()):
            found = existing_pattern.match(existing_qualname)
            if found is None:
                continue
            postfix = found.group(1)
            if postfix:
                # Most recent match is `{node_qualname}_{int}`: increment it
                next_index = int(postfix[1:]) + 1
            else:
                # Most recent match is the bare name: start counting at 1
                next_index = 1
            node_qualname += f"_{next_index}"
            break

        return node_qualname
|
|
|
|
|
def _is_subseq(x, y): |
|
"""Check if y is a subsequence of x |
|
https://stackoverflow.com/a/24017747/4391249 |
|
""" |
|
iter_x = iter(x) |
|
return all(any(x_item == y_item for x_item in iter_x) for y_item in y) |
|
|
|
|
|
def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
    """
    Emit a warning when the node names recorded while tracing in train mode
    differ from those recorded while tracing in eval mode. Nothing happens
    when both tracers recorded the same names in the same order.
    """
    names_in_train = list(train_tracer.node_to_qualname.values())
    names_in_eval = list(eval_tracer.node_to_qualname.values())

    # Same names in the same order: nothing to report
    if names_in_train == names_in_eval:
        return

    # Describe how the two sequences relate to each other
    if _is_subseq(names_in_train, names_in_eval):
        relation = (
            "NOTE: The nodes obtained by tracing the model in eval mode "
            "are a subsequence of those obtained in train mode. "
        )
    elif _is_subseq(names_in_eval, names_in_train):
        relation = (
            "NOTE: The nodes obtained by tracing the model in train mode "
            "are a subsequence of those obtained in eval mode. "
        )
    else:
        relation = "The nodes obtained by tracing the model in train mode are different to those obtained in eval mode. "

    hint = (
        "When choosing nodes for feature extraction, you may need to specify "
        "output nodes for train and eval mode separately."
    )
    warnings.warn(relation + hint)
|
|
|
|
|
def _get_leaf_modules_for_ops() -> List[type]:
    """
    Collect every class that is a member of ``torchvision.ops`` and subclasses
    ``torch.nn.Module``, in the order reported by ``inspect.getmembers``.
    """
    return [
        member
        for _member_name, member in inspect.getmembers(torchvision.ops)
        if inspect.isclass(member) and issubclass(member, torch.nn.Module)
    ]
|
|
|
|
|
def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Merge the default tracer keyword arguments into the user-provided ones.

    The defaults are ``autowrap_modules=(math, torchvision.ops)`` and
    ``leaf_modules=_get_leaf_modules_for_ops()``. When the user already
    supplied one of these keys, the user's entries are kept (first) and the
    defaults are added after them, with duplicates removed.

    Args:
        original_tr_kwargs (dict, optional): user-provided tracer kwargs, or
            ``None``. The dictionary is not modified.

    Returns:
        dict: a new dictionary with ``autowrap_modules`` (a tuple) and
        ``leaf_modules`` (a list) always present.
    """
    default_autowrap_modules = (math, torchvision.ops)
    default_leaf_modules = _get_leaf_modules_for_ops()

    # Shallow copy so the caller's dictionary is never mutated in place.
    result_tracer_kwargs = {} if original_tr_kwargs is None else dict(original_tr_kwargs)

    if "autowrap_modules" in result_tracer_kwargs:
        # Coerce to tuple first so a user-provided list (or other iterable)
        # can be merged; `dict.fromkeys` de-duplicates with a stable order.
        user_autowrap = tuple(result_tracer_kwargs["autowrap_modules"])
        result_tracer_kwargs["autowrap_modules"] = tuple(dict.fromkeys(user_autowrap + default_autowrap_modules))
    else:
        result_tracer_kwargs["autowrap_modules"] = default_autowrap_modules

    if "leaf_modules" in result_tracer_kwargs:
        # Same treatment for leaf modules, which may arrive as a tuple or set.
        user_leaf = list(result_tracer_kwargs["leaf_modules"])
        result_tracer_kwargs["leaf_modules"] = list(dict.fromkeys(user_leaf + default_leaf_modules))
    else:
        result_tracer_kwargs["leaf_modules"] = default_leaf_modules

    return result_tracer_kwargs
|
|
|
|
|
def get_graph_node_names(
    model: nn.Module,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warning: bool = False,
    concrete_args: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
    """
    Dev utility to return node names in order of execution. See note on node
    names under :func:`create_feature_extractor`. Useful for seeing which node
    names are available for feature extraction. There are two reasons that
    node names can't easily be read directly from the code for a model:

        1. Not all submodules are traced through. Modules from ``torch.nn`` all
           fall within this category.
        2. Nodes representing the repeated application of the same operation
           or leaf module get a ``_{counter}`` postfix.

    The model is traced twice: once in train mode, and once in eval mode. Both
    sets of node names are returned. The model's original training mode is
    restored afterwards, including when tracing raises.

    For more details on the node naming conventions used here, please see the
    :ref:`relevant subheading <about-node-names>` in the
    `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.

    Args:
        model (nn.Module): model for which we'd like to print node names
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            ``NodePathTracer`` (they are eventually passed onto
            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
            By default, it will be set to wrap and make leaf nodes all torchvision ops:
            {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
            WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
            provided dictionary.
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.
        concrete_args (Optional[Dict[str, Any]]): Concrete arguments that should
            not be treated as Proxies. According to the `PyTorch docs
            <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace>`_,
            this parameter's API may not be guaranteed.

    Returns:
        tuple(list, list): a list of node names from tracing the model in
        train mode, and another from tracing the model in eval mode.

    Examples::

        >>> model = torchvision.models.resnet18()
        >>> train_nodes, eval_nodes = get_graph_node_names(model)
    """
    tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
    is_training = model.training
    try:
        train_tracer = NodePathTracer(**tracer_kwargs)
        train_tracer.trace(model.train(), concrete_args=concrete_args)
        eval_tracer = NodePathTracer(**tracer_kwargs)
        eval_tracer.trace(model.eval(), concrete_args=concrete_args)
    finally:
        # Tracing toggles the model's mode; put it back even if a trace
        # failed, so callers never get their model back in a different mode.
        model.train(is_training)

    train_nodes = list(train_tracer.node_to_qualname.values())
    eval_nodes = list(eval_tracer.node_to_qualname.values())
    if not suppress_diff_warning:
        _warn_graph_differences(train_tracer, eval_tracer)

    return train_nodes, eval_nodes
|
|
|
|
|
class DualGraphModule(fx.GraphModule):
    """
    A derivative of `fx.GraphModule`. Differs in the following ways:
    - Requires a train and eval version of the underlying graph
    - Copies submodules according to the nodes of both train and eval graphs.
    - Calling train(mode) switches between train graph and eval graph.

    NOTE(review): this class relies on private ``torch.fx.graph_module`` names
    (``_CodeOnlyModule``, ``_copy_attr``, ``_USER_PRESERVED_ATTRIBUTES_KEY``),
    so it is sensitive to changes in torch internals.
    """

    def __init__(
        self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
    ):
        """
        Args:
            root (nn.Module): module from which the copied module hierarchy is
                built
            train_graph (fx.Graph): the graph that should be used in train mode
            eval_graph (fx.Graph): the graph that should be used in eval mode
            class_name (str): name assigned to this instance's class.
                Defaults to "GraphModule".

        Raises:
            TypeError: if a ``get_attr``/``call_module`` node has a non-string
                target, or if the two graphs report different tracer classes.
        """
        # Start the ``__init__`` lookup *after* ``fx.GraphModule`` in the MRO,
        # i.e. ``fx.GraphModule.__init__`` itself is deliberately not run.
        super(fx.GraphModule, self).__init__()

        self.__class__.__name__ = class_name

        self.train_graph = train_graph
        self.eval_graph = eval_graph

        # Copy everything referenced by ``get_attr`` and ``call_module`` nodes
        # of BOTH graphs from ``root`` onto this module.
        # NOTE(review): presumably ``_copy_attr`` copies the attribute at the
        # dotted path ``node.target`` - confirm against torch.fx.
        for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
            if node.op in ["get_attr", "call_module"]:
                if not isinstance(node.target, str):
                    raise TypeError(f"node.target should be of type str instead of {type(node.target)}")
                _copy_attr(root, self, node.target)

        # Put the module in train mode and select the train graph explicitly.
        # ``train()`` below only assigns ``self.graph`` when the mode actually
        # changes, so the assignment here is what sets the initial graph.
        self.train()
        self.graph = train_graph

        # Both graphs must come from the same tracer class. That class is
        # stored on the module unless its qualified name contains
        # "<locals>" (i.e. it was defined inside a function), in which case
        # ``None`` is stored instead.
        # NOTE(review): presumably mirrors fx.GraphModule's own handling for
        # serialization purposes - confirm.
        if self.eval_graph._tracer_cls != self.train_graph._tracer_cls:
            raise TypeError(
                f"Train mode and eval mode should use the same tracer class. Instead got {self.eval_graph._tracer_cls} for eval vs {self.train_graph._tracer_cls} for train"
            )
        self._tracer_cls = None
        if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
            self._tracer_cls = self.graph._tracer_cls

    def train(self, mode=True):
        """
        Swap out the graph depending on the selected training mode.
        NOTE this should be safe when calling model.eval() because that just
        calls this with mode == False.
        """
        # Only assign ``self.graph`` when the requested mode differs from the
        # current one.
        # NOTE(review): presumably avoids an unnecessary recompile triggered
        # by the ``graph`` setter - confirm.
        if mode and not self.training:
            self.graph = self.train_graph
        elif not mode and self.training:
            self.graph = self.eval_graph
        return super().train(mode=mode)

    def _deepcopy_init(self):
        # Initializer used by ``__deepcopy__`` below to set up the copy.
        return DualGraphModule.__init__

    def __deepcopy__(self, memo):
        """
        Deep-copy this module, carrying over both ``train_graph`` and
        ``eval_graph``.

        NOTE(review): this appears to follow the base class' ``__deepcopy__``
        and uses private torch.fx helpers; revisit if torch changes them.
        """
        res = type(self).__new__(type(self))
        # Register the copy before recursing so references back to ``self``
        # resolve to ``res``.
        memo[id(self)] = res
        # Stand-in root built from a deep copy of this instance's ``__dict__``
        fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
        # Initialize the copy from the stand-in root and the copied graphs.
        # ``class_name`` is not passed, so the default is used.
        self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["train_graph"], fake_mod.__dict__["eval_graph"])

        # Attributes carried over explicitly when present on this instance
        extra_preserved_attrs = [
            "_state_dict_hooks",
            "_load_state_dict_pre_hooks",
            "_load_state_dict_post_hooks",
            "_replace_hook",
            "_create_node_hooks",
            "_erase_node_hooks",
        ]
        for attr in extra_preserved_attrs:
            if attr in self.__dict__:
                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
        res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
        # Re-attach any attributes recorded under the preserved-attributes key
        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
                setattr(res, attr_name, attr)
        return res
|
|
|
|
|
def create_feature_extractor(
    model: nn.Module,
    return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warning: bool = False,
    concrete_args: Optional[Dict[str, Any]] = None,
) -> fx.GraphModule:
    """
    Creates a new graph module that returns intermediate nodes from a given
    model as dictionary with user specified keys as strings, and the requested
    outputs as values. This is achieved by re-writing the computation graph of
    the model via FX to return the desired nodes as outputs. All unused nodes
    are removed, together with their corresponding parameters.

    Desired output nodes must be specified as a ``.`` separated
    path walking the module hierarchy from top level module down to leaf
    operation or leaf module. For more details on the node naming conventions
    used here, please see the :ref:`relevant subheading <about-node-names>`
    in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.

    Not all models will be FX traceable, although with some massaging they can
    be made to cooperate. Here's a (not exhaustive) list of tips:

        - If you don't need to trace through a particular, problematic
          sub-module, turn it into a "leaf module" by passing a list of
          ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
          It will not be traced through, but rather, the resulting graph will
          hold a reference to that module's forward method.
        - Likewise, you may turn functions into leaf functions by passing a
          list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
          example below).
        - Some inbuilt Python functions can be problematic. For instance,
          ``int`` will raise an error during tracing. You may wrap them in your
          own function and then pass that in ``autowrap_functions`` as one of
          the ``tracer_kwargs``.

    For further information on FX see the
    `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.

    Args:
        model (nn.Module): model on which we will extract the features
        return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
            containing the names (or partial names - see note above)
            of the nodes for which the activations will be returned. If it is
            a ``Dict``, the keys are the node names, and the values
            are the user-specified keys for the graph module's returned
            dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
            node specification strings directly to output names. In the case
            that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
            this should not be specified.
        train_return_nodes (list or dict, optional): similar to
            ``return_nodes``. This can be used if the return nodes
            for train mode are different than those from eval mode.
            If this is specified, ``eval_return_nodes`` must also be specified,
            and ``return_nodes`` should not be specified.
        eval_return_nodes (list or dict, optional): similar to
            ``return_nodes``. This can be used if the return nodes
            for train mode are different than those from eval mode.
            If this is specified, ``train_return_nodes`` must also be specified,
            and ``return_nodes`` should not be specified.
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            ``NodePathTracer`` (which passes them onto its parent class
            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
            By default, it will be set to wrap and make leaf nodes all torchvision ops:
            {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
            WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
            provided dictionary.
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.
        concrete_args (Optional[Dict[str, Any]]): Concrete arguments that should
            not be treated as Proxies. According to the `PyTorch docs
            <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace>`_,
            this parameter's API may not be guaranteed.

    Returns:
        fx.GraphModule: a ``DualGraphModule`` holding the rewritten train and
        eval graphs, set to the same training mode the model had on entry.

    Raises:
        ValueError: if the ``*return_nodes`` arguments are combined
            incorrectly, if tracing yields duplicate node names, if a requested
            node is not present in the traced model, or if the traced graph
            has no output node.

    Examples::

        >>> # Feature extraction with resnet
        >>> model = torchvision.models.resnet18()
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> model = create_feature_extractor(
        >>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = model(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]

        >>> # Specifying leaf modules and leaf functions
        >>> def leaf_function(x):
        >>>     # This would raise a TypeError if traced through
        >>>     return int(x)
        >>>
        >>> class LeafModule(torch.nn.Module):
        >>>     def forward(self, x):
        >>>         # This would raise a TypeError if traced through
        >>>         int(x.shape[0])
        >>>         return torch.nn.functional.relu(x + 4)
        >>>
        >>> class MyModule(torch.nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv = torch.nn.Conv2d(3, 1, 3)
        >>>         self.leaf_module = LeafModule()
        >>>
        >>>     def forward(self, x):
        >>>         leaf_function(x.shape[0])
        >>>         x = self.conv(x)
        >>>         return self.leaf_module(x)
        >>>
        >>> model = create_feature_extractor(
        >>>     MyModule(), return_nodes=['leaf_module'],
        >>>     tracer_kwargs={'leaf_modules': [LeafModule],
        >>>                    'autowrap_functions': [leaf_function]})

    """
    tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
    is_training = model.training

    if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
        raise ValueError(
            "Either `return_nodes` or `train_return_nodes` and `eval_return_nodes` together, should be specified"
        )

    if (train_return_nodes is None) ^ (eval_return_nodes is None):
        raise ValueError(
            "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
        )

    # Given the two checks above, this can only trigger when `return_nodes`
    # was passed together with `train_return_nodes`/`eval_return_nodes`.
    if not ((return_nodes is None) ^ (train_return_nodes is None)):
        raise ValueError(
            "If `train_return_nodes` and `eval_return_nodes` are specified, then `return_nodes` should not be specified"
        )

    # Normalize a list or dict of node specs into Dict[str, str]
    def to_strdict(n) -> Dict[str, str]:
        if isinstance(n, list):
            return {str(i): str(i) for i in n}
        return {str(k): str(v) for k, v in n.items()}

    if train_return_nodes is None:
        return_nodes = to_strdict(return_nodes)
        # Independent copies: entries are popped per mode further below
        train_return_nodes = deepcopy(return_nodes)
        eval_return_nodes = deepcopy(return_nodes)
    else:
        train_return_nodes = to_strdict(train_return_nodes)
        eval_return_nodes = to_strdict(eval_return_nodes)

    name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__

    # Trace and rewrite the graph once per mode
    tracers = {}
    graphs = {}
    mode_return_nodes: Dict[str, Dict[str, str]] = {"train": train_return_nodes, "eval": eval_return_nodes}
    try:
        for mode in ["train", "eval"]:
            if mode == "train":
                model.train()
            elif mode == "eval":
                model.eval()

            # Trace the model with a NodePathTracer so node names are recorded
            tracer = NodePathTracer(**tracer_kwargs)
            graph = tracer.trace(model, concrete_args=concrete_args)

            graph_module = fx.GraphModule(tracer.root, graph, name)

            available_nodes = list(tracer.node_to_qualname.values())
            if len(set(available_nodes)) != len(available_nodes):
                raise ValueError(
                    "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
                )

            # Every requested node must be present: a query is available when
            # some recorded name equals it or extends it past a `.` boundary.
            # Plain string comparison is used so that characters in the query
            # (such as `.`) are never interpreted as regex syntax.
            for query in mode_return_nodes[mode]:
                query_prefix = f"{query}."
                if not any(n == query or n.startswith(query_prefix) for n in available_nodes):
                    raise ValueError(
                        f"node: '{query}' is not present in model. Hint: use "
                        "`get_graph_node_names` to make sure the "
                        "`return_nodes` you specified are present. It may even "
                        "be that you need to specify `train_return_nodes` and "
                        "`eval_return_nodes` separately."
                    )

            # Remove the existing output nodes
            orig_output_nodes = [n for n in reversed(graph_module.graph.nodes) if n.op == "output"]
            if not orig_output_nodes:
                raise ValueError("No output nodes found in graph_module.graph.nodes")

            for n in orig_output_nodes:
                graph_module.graph.erase_node(n)

            # Walk the nodes from last to first; for each query, the first
            # match in that order (i.e. the last executed node under that
            # name) becomes the output for the query.
            nodes = list(graph_module.graph.nodes)
            output_nodes = OrderedDict()
            for n in reversed(nodes):
                module_qualname = tracer.node_to_qualname.get(n)
                if module_qualname is None:
                    # Node was not recorded by the tracer; it can't be requested
                    continue
                for query in mode_return_nodes[mode]:
                    depth = query.count(".")
                    if ".".join(module_qualname.split(".")[: depth + 1]) == query:
                        output_nodes[mode_return_nodes[mode][query]] = n
                        # Safe to pop during iteration: we break right after
                        mode_return_nodes[mode].pop(query)
                        break
            # Undo the reversed collection order
            output_nodes = OrderedDict(reversed(list(output_nodes.items())))

            # Add the new output at the end of the graph
            with graph_module.graph.inserting_after(nodes[-1]):
                graph_module.graph.output(output_nodes)

            # Drop everything that no longer contributes to the outputs
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()

            # Keep the tracer and graph of this mode for the final module
            tracers[mode] = tracer
            graphs[mode] = graph

        # Warn about discrepancies between the train and eval graphs
        if not suppress_diff_warning:
            _warn_graph_differences(tracers["train"], tracers["eval"])

        # Build the final graph module
        graph_module = DualGraphModule(model, graphs["train"], graphs["eval"], class_name=name)
    finally:
        # Restore the model's original training mode, also when tracing or
        # validation raised part-way through.
        model.train(is_training)

    graph_module.train(is_training)

    return graph_module
|
|