from typing import List

import torch
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm


def patch_batchnorm(module: nn.Module) -> List:
"""Patch all batchnorm instances (1d, 2d, 3d, sync_bn, etc.) of a module |
|
so that they don't track running stats when torch.no_grad() is enabled. |
|
|
|
This is important in activation checkpointing to ensure stats are tracked |
|
correctly as if there were no activation checkpointing. The reason is |
|
that activation checkpointing runs the forward function twice, first |
|
with torch.no_grad(), then with torch.grad(). |
|
|
|
Args: |
|
module (nn.Module): |
|
The module to be patched in-place. |
|
|
|
Returns: |
|
(list): |
|
A list of hook handles, late can be freed. |
|
""" |
|
|
|
def pre_forward(module: _BatchNorm, input: Tensor) -> None: |
|
if torch.is_grad_enabled(): |
|
return |
|
module._track_running_stats_backup = module.track_running_stats |
|
module.track_running_stats = False |
|
|
|
def post_forward(module: _BatchNorm, input: Tensor, result: Tensor) -> None: |
|
if torch.is_grad_enabled(): |
|
return |
|
module.track_running_stats = module._track_running_stats_backup |
|
|
|
hooks = [] |
|
for name, child in module.named_modules(): |
|
|
|
if isinstance(child, _BatchNorm) and not hasattr(child, "disable_patch_batchnorm"): |
|
|
|
pre_handle = child.register_forward_pre_hook(pre_forward) |
|
post_handle = child.register_forward_hook(post_forward) |
|
hooks += [pre_handle, post_handle] |
|
return hooks |
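

# A minimal usage sketch (not part of the original module): it assumes the
# forward pass is checkpointed with ``torch.utils.checkpoint.checkpoint`` in its
# reentrant form, which runs the first pass under torch.no_grad() and the
# recomputation with gradients enabled. ``ToyBlock`` and the tensor shapes below
# are illustrative only.
if __name__ == "__main__":
    from torch.utils.checkpoint import checkpoint

    class ToyBlock(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x: Tensor) -> Tensor:
            return torch.relu(self.bn(self.conv(x)))

    block = ToyBlock()
    hooks = patch_batchnorm(block)

    x = torch.randn(4, 3, 16, 16, requires_grad=True)
    # The checkpointed forward runs twice: the pre-hook skips the stat update
    # in the no_grad pass, and the recomputation updates the stats once.
    out = checkpoint(block, x, use_reentrant=True)  # drop the flag on older PyTorch
    out.sum().backward()

    # Expected to print 1: the running stats were updated exactly once.
    print(int(block.bn.num_batches_tracked))

    # The patch can be undone by removing the registered hooks.
    for handle in hooks:
        handle.remove()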