|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tracks the running statistics per mini-batch instead of micro-batch.""" |
|
from typing import Optional, TypeVar, cast |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
import torch.nn.functional as F |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
|
from .checkpoint import is_recomputing |
|
|
|
__all__ = ["DeferredBatchNorm"]


# Bound TypeVar so convert_deferred_batch_norm can return the same concrete
# module type it receives (e.g. nn.Sequential in -> nn.Sequential out).
TModule = TypeVar("TModule", bound=nn.Module)
|
|
|
|
|
class DeferredBatchNorm(_BatchNorm):
    """A BatchNorm layer that tracks multiple micro-batches to update running
    statistics per mini-batch.

    During training, normalization uses the current micro-batch's own
    statistics (the running buffers are not consumed), while per-channel
    sums are accumulated on the side; after ``chunks`` micro-batches the
    accumulated statistics are folded into ``running_mean``/``running_var``
    in one step, as if the whole mini-batch had been seen at once.
    """

    # Per-channel accumulators for the current mini-batch. Registered as
    # buffers in __init__ so they follow the module across devices/dtypes.
    sum: Tensor
    sum_squares: Tensor

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        chunks: int = 1,
    ) -> None:
        """
        Args:
            num_features: number of channels (size of dimension 1).
            eps: small constant added to the variance for numerical stability.
            momentum: running-stats momentum; ``None`` selects cumulative
                (1/num_batches_tracked) averaging, as in :class:`_BatchNorm`.
            affine: whether to learn per-channel scale and shift.
            chunks: number of micro-batches that make up one mini-batch.
        """
        # Deferring running statistics is the point of this subclass, so
        # track_running_stats is unconditionally enabled.
        super().__init__(num_features, eps, momentum, affine, track_running_stats=True)

        self.register_buffer("sum", torch.zeros_like(self.running_mean))
        self.register_buffer("sum_squares", torch.zeros_like(self.running_var))

        self.counter = 0  # elements accumulated per channel so far
        self.tracked = 0  # micro-batches accumulated so far
        self.chunks = chunks

    def _check_input_dim(self, input: Tensor) -> None:
        """Raises ``ValueError`` unless the input has batch, channel, and at
        least one further dimension (i.e. is at least 3D)."""
        if input.dim() <= 2:
            raise ValueError("expected at least 3D input (got %dD input)" % input.dim())

    def _track(self, input: Tensor) -> bool:
        """Accumulates statistics of a micro-batch.

        Returns:
            ``True`` once ``chunks`` micro-batches have been tracked,
            signalling that :meth:`_commit` should be called.
        """
        # Reduce over every dimension except the channel dimension (1).
        dim = [0]
        dim.extend(range(2, input.dim()))

        with torch.no_grad():
            self.sum += input.sum(dim)
            self.sum_squares += (input**2).sum(dim)

        # Number of elements contributing to each channel's statistics.
        size = input.size().numel() // input.size(1)
        self.counter += size
        self.tracked += 1

        return self.tracked == self.chunks

    def _commit(self) -> None:
        """Updates the running statistics of a mini-batch and resets the
        accumulators for the next one."""
        exponential_average_factor = 0.0
        self.num_batches_tracked += 1
        if self.momentum is None:  # cumulative moving average
            exponential_average_factor = 1.0 / float(self.num_batches_tracked)
        else:  # exponential moving average
            exponential_average_factor = self.momentum

        mean = self.sum / self.counter
        # Biased variance via E[x^2] - E[x]^2 over the whole mini-batch.
        var = self.sum_squares / self.counter - mean**2

        m = exponential_average_factor

        self.running_mean *= 1 - m
        self.running_mean += mean * m

        self.running_var *= 1 - m
        self.running_var += var * m

        self.sum.zero_()
        self.sum_squares.zero_()
        self.counter = 0
        self.tracked = 0

    def forward(self, input: Tensor) -> Tensor:
        if not self.training:
            # Evaluation: plain batch-norm against the frozen running stats.
            return F.batch_norm(
                input,
                running_mean=self.running_mean,
                running_var=self.running_var,
                weight=self.weight,
                bias=self.bias,
                training=False,
                momentum=0.0,
                eps=self.eps,
            )

        # Only track on the first pass; checkpoint recomputation replays the
        # same micro-batch and must not double-count it.
        if not is_recomputing():
            tracked_enough = self._track(input)
            if tracked_enough:
                self._commit()

        # Normalize with the micro-batch's own statistics. Passing None for
        # the running buffers leaves them untouched here — they are only
        # updated, once per mini-batch, inside _commit().
        return F.batch_norm(
            input,
            running_mean=None,
            running_var=None,
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=0.0,
            eps=self.eps,
        )

    @classmethod
    def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
        """Converts a :class:`nn.BatchNorm` or underlying
        :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`::

            from torchvision.models.resnet import resnet101
            from torchpipe.batchnorm import DeferredBatchNorm
            model = resnet101()
            model = DeferredBatchNorm.convert_deferred_batch_norm(model)

        """
        # Compare chunk counts by value: the previous identity check
        # (`module.chunks is chunks`) only worked via CPython's small-int
        # cache and silently re-wrapped already-converted modules otherwise.
        if isinstance(module, DeferredBatchNorm) and module.chunks == chunks:
            return cast(TModule, module)

        module_output: nn.Module = module

        if isinstance(module, _BatchNorm) and module.track_running_stats:
            module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks)
            # Share (not copy) parameters and buffers with the source module.
            if module.affine:
                module_output.register_parameter("weight", module.weight)
                module_output.register_parameter("bias", module.bias)
            module_output.register_buffer("running_mean", module.running_mean)
            module_output.register_buffer("running_var", module.running_var)
            module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)

        # Recurse into children so nested BatchNorms are converted too.
        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks))

        return cast(TModule, module_output)
|
|