# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm


def patch_batchnorm(module: nn.Module) -> List:
"""Patch all batchnorm instances (1d, 2d, 3d, sync_bn, etc.) of a module
so that they don't track running stats when torch.no_grad() is enabled.
This is important in activation checkpointing to ensure stats are tracked
correctly as if there were no activation checkpointing. The reason is
that activation checkpointing runs the forward function twice, first
with torch.no_grad(), then with torch.grad().
Args:
module (nn.Module):
The module to be patched in-place.
Returns:
(list):
A list of hook handles, late can be freed.
"""

    def pre_forward(module: _BatchNorm, input: Tensor) -> None:
        if torch.is_grad_enabled():
            return
        # No-grad forward (the first pass under checkpointing): back up the
        # flag and temporarily disable running-stat tracking.
        module._track_running_stats_backup = module.track_running_stats
        module.track_running_stats = False

    def post_forward(module: _BatchNorm, input: Tensor, result: Tensor) -> None:
        if torch.is_grad_enabled():
            return
        # Restore the original flag once the no-grad forward has finished.
        module.track_running_stats = module._track_running_stats_backup

    hooks = []
    for name, child in module.named_modules():
        # _BatchNorm is the base for bn1d, bn2d, bn3d, sync_bn, apex_sync_bn, etc.
        # Modules can opt out by defining a `disable_patch_batchnorm` attribute.
        if isinstance(child, _BatchNorm) and not hasattr(child, "disable_patch_batchnorm"):
            # Register the pre/post hooks.
            pre_handle = child.register_forward_pre_hook(pre_forward)
            post_handle = child.register_forward_hook(post_forward)
            hooks += [pre_handle, post_handle]
    return hooks
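

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the original module): a
    # minimal example assuming torch.utils.checkpoint.checkpoint for the
    # activation checkpointing; the toy Sequential model, tensor shapes, and
    # layer sizes below are arbitrary choices for demonstration.
    from torch.utils.checkpoint import checkpoint

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    handles = patch_batchnorm(net)  # patch before checkpointing the module

    # Reentrant checkpointing runs the forward under no_grad and recomputes it
    # with gradients enabled during backward; the hooks keep BatchNorm running
    # stats from being updated on the no_grad pass.
    x = torch.randn(2, 3, 16, 16, requires_grad=True)
    y = checkpoint(net, x, use_reentrant=True)
    y.sum().backward()

    for handle in handles:
        handle.remove()  # undo the patch once it is no longer needed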