import logging
import weakref
from typing import TYPE_CHECKING

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle


logger = logging.getLogger(__name__)

__all__ = ["ModuleTracker"]


class ModuleTracker:
    """
    ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward
    pass is being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed via their fqn (fully qualified name, also used as the key
    within the state_dict).
    You can access the ``is_bw`` attribute to know if you are currently running in backward or not.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module
    instance is possible but not done yet, please submit an issue requesting this if you need it.

    Example usage:

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias

            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))
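
    As a rough sketch, the ``is_bw`` flag can be observed from any hook that runs inside the
    autograd engine, for example a tensor hook registered with ``Tensor.register_hook``:

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            loss = mod(torch.rand(2, 2, requires_grad=True)).sum()
            print(tracker.is_bw)  # False: no backward pass is running here

            # The hook below fires while autograd runs, so is_bw is True inside it.
            loss.register_hook(lambda grad: print(tracker.is_bw))
            loss.backward()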

    """

    parents: set[str]
    """
    A set containing the fqn of each module currently running its forward pass.
    """

    def __init__(self) -> None:
        self.parents = {"Global"}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False
        self._hooks: list[RemovableHandle] = []

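    # Queue a callback on the autograd engine so that ``parents`` is reset to just
    # "Global" once the current backward pass finishes; ``_has_callback`` makes sure
    # only one such callback is queued per backward pass.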
    def _maybe_set_engine_callback(self):
        if self._has_callback:
            return

        def callback():
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True

    @property
    def is_bw(self):
        """
        A boolean indicating whether execution is currently happening during the backward pass.
        """
        return torch._C._current_graph_task_id() != -1

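    # Resolve a module to its fully qualified name, caching the result and recursively
    # pre-registering the names of its children (e.g. "Sequential" for a top-level
    # nn.Sequential and "Sequential.0" for its first child).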
    def _get_mod_name(self, mod):
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name

    def _get_append_fn(self, name, is_bw):
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents:
                logger.info(
                    "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
                    name,
                    "backward" if is_bw else "forward",
                )
            self.parents.add(name)

        return fn

    def _get_pop_fn(self, name, is_bw):
        def fn(*args):
            if name in self.parents:
                self.parents.remove(name)
            else:
                logger.info(
                    "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
                    name,
                    "backward" if is_bw else "forward",
                )

        return fn

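    # The forward pre-hook pushes the module name and arranges for it to be popped in
    # backward (once the gradients of all the forward inputs are ready); the forward
    # post-hook mirrors this: it pops the name and arranges for it to be pushed again
    # in backward (once the gradients of the outputs are ready).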
    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        self._get_append_fn(name, False)()

        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            self._hooks.append(
                register_multi_grad_hook(tensors, self._get_pop_fn(name, True))
            )

    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        self._get_pop_fn(name, False)()

        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            self._hooks.append(
                register_multi_grad_hook(tensors, self._get_append_fn(name, True))
            )

    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()