|
|
|
r"""Implementation for Stochastic Weight Averaging implementation.""" |
|
import itertools |
|
import math |
|
import warnings |
|
from collections.abc import Iterable |
|
from copy import deepcopy |
|
from typing import Any, Callable, Literal, Optional, Union |
|
|
|
import torch |
|
from torch import Tensor |
|
from torch.nn import Module |
|
from torch.optim.lr_scheduler import _format_param, LRScheduler |
|
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _group_tensors_by_device_and_dtype,
)
|
|
|
from .optimizer import Optimizer |
|
|
|
|
|
__all__ = [ |
|
"AveragedModel", |
|
"update_bn", |
|
"SWALR", |
|
"get_ema_multi_avg_fn", |
|
"get_swa_multi_avg_fn", |
|
"get_ema_avg_fn", |
|
"get_swa_avg_fn", |
|
] |
|
|
|
|
|
|
|
|
PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]] |
|
|
|
|
|
def get_ema_multi_avg_fn(decay=0.999): |
|
"""Get the function applying exponential moving average (EMA) across multiple params.""" |
|
|
|
if decay < 0.0 or decay > 1.0: |
|
raise ValueError( |
|
f"Invalid decay value {decay} provided. Please provide a value in [0,1] range." |
|
) |
|
|
|
@torch.no_grad() |
|
def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): |
|
|
|
if torch.is_floating_point(ema_param_list[0]) or torch.is_complex( |
|
ema_param_list[0] |
|
): |
|
torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay) |
|
else: |
|
for p_ema, p_model in zip(ema_param_list, current_param_list): |
|
p_ema.copy_(p_ema * decay + p_model * (1 - decay)) |
|
|
|
return ema_update |
|
|
|
|
|
def get_swa_multi_avg_fn(): |
|
"""Get the function applying stochastic weight average (SWA) across multiple params.""" |
|
|
|
@torch.no_grad() |
|
def swa_update( |
|
averaged_param_list: PARAM_LIST, |
|
current_param_list: PARAM_LIST, |
|
num_averaged: Union[Tensor, int], |
|
): |
|
|
|
if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex( |
|
averaged_param_list[0] |
|
): |
|
torch._foreach_lerp_( |
|
averaged_param_list, current_param_list, 1 / (num_averaged + 1) |
|
) |
|
else: |
|
diffs = torch._foreach_sub(current_param_list, averaged_param_list) |
|
if isinstance(num_averaged, Tensor): |
|
torch._foreach_addcdiv_( |
|
averaged_param_list, |
|
diffs, |
|
[num_averaged + 1] * len(averaged_param_list), |
|
) |
|
else: |
|
torch._foreach_add_( |
|
averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1) |
|
) |
|
|
|
return swa_update |
|
|
|
|
|
def get_ema_avg_fn(decay=0.999): |
|
"""Get the function applying exponential moving average (EMA) across a single param.""" |
|
|
|
if decay < 0.0 or decay > 1.0: |
|
raise ValueError( |
|
f"Invalid decay value {decay} provided. Please provide a value in [0,1] range." |
|
) |
|
|
|
@torch.no_grad() |
|
def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged): |
|
return decay * ema_param + (1 - decay) * current_param |
|
|
|
return ema_update |
|
|
|
|
|
def get_swa_avg_fn(): |
|
"""Get the function applying stochastic weight average (SWA) across a single param.""" |
|
|
|
@torch.no_grad() |
|
def swa_update( |
|
averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int] |
|
): |
|
return averaged_param + (current_param - averaged_param) / (num_averaged + 1) |
|
|
|
return swa_update |
|
|
|
|
|
class AveragedModel(Module): |
|
r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). |
|
|
|
Stochastic Weight Averaging was proposed in `Averaging Weights Leads to |
|
Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii |
|
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson |
|
(UAI 2018). |
|
|
|
Exponential Moving Average is a variation of `Polyak averaging`_, |
|
but using exponential weights instead of equal weights across iterations. |
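    With a decay value ``d``, this corresponds to the parameter update
    ``ema = d * ema + (1 - d) * param`` (see :func:`get_ema_avg_fn` and
    :func:`get_ema_multi_avg_fn`).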
|
|
|
    The AveragedModel class creates a copy of the provided module :attr:`model`
    on the device :attr:`device` and allows computing running averages of the
    parameters of the :attr:`model`.
|
|
|
Args: |
|
model (torch.nn.Module): model to use with SWA/EMA |
|
device (torch.device, optional): if provided, the averaged model will be |
|
stored on the :attr:`device` |
|
avg_fn (function, optional): the averaging function used to update |
|
parameters; the function must take in the current value of the |
|
:class:`AveragedModel` parameter, the current value of :attr:`model` |
|
parameter, and the number of models already averaged; if None, |
|
an equally weighted average is used (default: None) |
|
multi_avg_fn (function, optional): the averaging function used to update |
|
parameters inplace; the function must take in the current values of the |
|
:class:`AveragedModel` parameters as a list, the current values of :attr:`model` |
|
parameters as a list, and the number of models already averaged; if None, |
|
an equally weighted average is used (default: None) |
|
        use_buffers (bool): if ``True``, compute running averages for both the
            parameters and the buffers of the model (default: ``False``)
|
|
|
Example: |
|
>>> # xdoctest: +SKIP("undefined variables") |
|
>>> loader, optimizer, model, loss_fn = ... |
|
>>> swa_model = torch.optim.swa_utils.AveragedModel(model) |
|
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, |
|
>>> T_max=300) |
|
>>> swa_start = 160 |
|
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05) |
|
>>> for i in range(300): |
|
>>> for input, target in loader: |
|
>>> optimizer.zero_grad() |
|
>>> loss_fn(model(input), target).backward() |
|
>>> optimizer.step() |
|
>>> if i > swa_start: |
|
>>> swa_model.update_parameters(model) |
|
>>> swa_scheduler.step() |
|
>>> else: |
|
>>> scheduler.step() |
|
>>> |
|
>>> # Update bn statistics for the swa_model at the end |
|
>>> torch.optim.swa_utils.update_bn(loader, swa_model) |
|
|
|
    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
    If no averaging function is provided, the default is to compute an
    equally weighted average of the weights (SWA).
|
|
|
Example: |
|
>>> # xdoctest: +SKIP("undefined variables") |
|
>>> # Compute exponential moving averages of the weights and buffers |
|
        >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
        >>>     multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9),
        >>>     use_buffers=True)
|
|
|
.. note:: |
|
        When using SWA/EMA with models containing Batch Normalization you may
        need to update the activation statistics for Batch Normalization.
        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
        utility or by setting :attr:`use_buffers` to ``True``. The first approach
        updates the statistics in a post-training step by passing data through
        the model. The second does it during the parameter update phase by
        averaging all buffers. Empirical evidence has shown that updating the
        statistics in normalization layers increases accuracy, but you may wish
        to test empirically which approach yields the best results for your
        problem.
|
|
|
.. note:: |
|
        :attr:`avg_fn` and :attr:`multi_avg_fn` are not saved in the :meth:`state_dict` of the model.
|
|
|
.. note:: |
|
        When :meth:`update_parameters` is called for the first time (i.e.
        :attr:`n_averaged` is `0`) the parameters of `model` are copied
        to the parameters of :class:`AveragedModel`. For every subsequent
        call of :meth:`update_parameters` the configured averaging function
        (`avg_fn`, `multi_avg_fn`, or the default equally weighted average)
        is used to update the parameters.
|
|
|
.. _Averaging Weights Leads to Wider Optima and Better Generalization: |
|
https://arxiv.org/abs/1803.05407 |
|
.. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should |
|
Average: |
|
https://arxiv.org/abs/1806.05594 |
|
.. _SWALP: Stochastic Weight Averaging in Low-Precision Training: |
|
https://arxiv.org/abs/1904.11943 |
|
.. _Stochastic Weight Averaging in Parallel: Large-Batch Training That |
|
Generalizes Well: |
|
https://arxiv.org/abs/2001.02312 |
|
.. _Polyak averaging: |
|
https://paperswithcode.com/method/polyak-averaging |
|
""" |
|
|
|
n_averaged: Tensor |
|
|
|
def __init__( |
|
self, |
|
model: Module, |
|
device: Optional[Union[int, torch.device]] = None, |
|
avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None, |
|
multi_avg_fn: Optional[ |
|
Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None] |
|
] = None, |
|
use_buffers=False, |
|
): |
|
super().__init__() |
|
        assert (
            avg_fn is None or multi_avg_fn is None
        ), "At most one of avg_fn and multi_avg_fn may be provided"
|
self.module = deepcopy(model) |
|
if device is not None: |
|
self.module = self.module.to(device) |
|
self.register_buffer( |
|
"n_averaged", torch.tensor(0, dtype=torch.long, device=device) |
|
) |
|
self.avg_fn = avg_fn |
|
self.multi_avg_fn = multi_avg_fn |
|
self.use_buffers = use_buffers |
|
|
|
def forward(self, *args, **kwargs): |
|
"""Forward pass.""" |
|
return self.module(*args, **kwargs) |
|
|
|
def update_parameters(self, model: Module): |
|
"""Update model parameters.""" |
|
self_param = ( |
|
itertools.chain(self.module.parameters(), self.module.buffers()) |
|
if self.use_buffers |
|
else self.parameters() |
|
) |
|
model_param = ( |
|
itertools.chain(model.parameters(), model.buffers()) |
|
if self.use_buffers |
|
else model.parameters() |
|
) |
|
self_param_detached: list[Optional[Tensor]] = [] |
|
model_param_detached: list[Optional[Tensor]] = [] |
|
for p_averaged, p_model in zip(self_param, model_param): |
|
p_model_ = p_model.detach().to(p_averaged.device) |
|
self_param_detached.append(p_averaged.detach()) |
|
model_param_detached.append(p_model_) |
|
if self.n_averaged == 0: |
|
p_averaged.detach().copy_(p_model_) |
|
|
|
if self.n_averaged > 0: |
|
if self.multi_avg_fn is not None or self.avg_fn is None: |
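                # Multi-tensor (foreach) path: group the detached parameters by
                # device and dtype so each group can be updated with fused
                # _foreach_* kernels.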
|
grouped_tensors = _group_tensors_by_device_and_dtype( |
|
[self_param_detached, model_param_detached] |
|
) |
|
for (device, _), ( |
|
[self_params, model_params], |
|
_, |
|
) in grouped_tensors.items(): |
|
if self.multi_avg_fn: |
|
self.multi_avg_fn( |
|
self_params, model_params, self.n_averaged.to(device) |
|
) |
|
elif ( |
|
device is not None |
|
and device.type in _get_foreach_kernels_supported_devices() |
|
): |
|
multi_avg_fn = get_swa_multi_avg_fn() |
|
multi_avg_fn( |
|
self_params, model_params, self.n_averaged.to(device) |
|
) |
|
else: |
|
avg_fn = get_swa_avg_fn() |
|
n_averaged = self.n_averaged.to(device) |
|
for p_averaged, p_model in zip(self_params, model_params): |
|
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged)) |
|
else: |
|
for p_averaged, p_model in zip( |
|
self_param_detached, model_param_detached |
|
): |
|
n_averaged = self.n_averaged.to(p_averaged.device) |
|
p_averaged.detach().copy_( |
|
self.avg_fn(p_averaged.detach(), p_model, n_averaged) |
|
) |
|
|
|
if not self.use_buffers: |
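            # If running averages are not applied to the buffers, keep the
            # buffers in sync with the source model instead.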
|
|
|
|
|
for b_swa, b_model in zip(self.module.buffers(), model.buffers()): |
|
b_swa.detach().copy_(b_model.detach().to(b_swa.device)) |
|
self.n_averaged += 1 |
|
|
|
|
|
@torch.no_grad() |
|
def update_bn( |
|
loader: Iterable[Any], |
|
model: Module, |
|
device: Optional[Union[int, torch.device]] = None, |
|
): |
|
r"""Update BatchNorm running_mean, running_var buffers in the model. |
|
|
|
It performs one pass over data in `loader` to estimate the activation |
|
statistics for BatchNorm layers in the model. |
|
|
|
Args: |
|
loader (torch.utils.data.DataLoader): dataset loader to compute the |
|
activation statistics on. Each data batch should be either a |
|
tensor, or a list/tuple whose first element is a tensor |
|
containing data. |
|
model (torch.nn.Module): model for which we seek to update BatchNorm |
|
statistics. |
|
device (torch.device, optional): If set, data will be transferred to |
|
:attr:`device` before being passed into :attr:`model`. |
|
|
|
Example: |
|
>>> # xdoctest: +SKIP("Undefined variables") |
|
>>> loader, model = ... |
|
>>> torch.optim.swa_utils.update_bn(loader, model) |
|
|
|
.. note:: |
|
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list/tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple.
|
""" |
|
momenta = {} |
|
for module in model.modules(): |
|
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): |
|
module.reset_running_stats() |
|
momenta[module] = module.momentum |
|
|
|
if not momenta: |
|
return |
|
|
|
was_training = model.training |
|
model.train() |
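    # Setting momentum to None makes each BatchNorm layer accumulate a
    # cumulative moving average of the batch statistics over the pass,
    # rather than an exponential moving average.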
|
for module in momenta.keys(): |
|
module.momentum = None |
|
|
|
for input in loader: |
|
if isinstance(input, (list, tuple)): |
|
input = input[0] |
|
if device is not None: |
|
input = input.to(device) |
|
|
|
model(input) |
|
|
|
for bn_module in momenta.keys(): |
|
bn_module.momentum = momenta[bn_module] |
|
model.train(was_training) |
|
|
|
|
|
class SWALR(LRScheduler): |
|
r"""Anneals the learning rate in each parameter group to a fixed value. |
|
|
|
    This learning rate scheduler is meant to be used with the Stochastic Weight
    Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).
|
|
|
Args: |
|
optimizer (torch.optim.Optimizer): wrapped optimizer |
|
        swa_lr (float or list): the learning rate value for all param groups
            together or separately for each group.
        anneal_epochs (int): number of epochs in the annealing phase
            (default: 10)
        anneal_strategy (str): "cos" or "linear"; specifies the annealing
            strategy: "cos" for cosine annealing, "linear" for linear annealing
            (default: "cos")
|
last_epoch (int): the index of the last epoch (default: -1) |
|
|
|
    The :class:`SWALR` scheduler can be used together with other
    schedulers to switch to a constant learning rate late in training,
    as in the example below.
|
|
|
Example: |
|
>>> # xdoctest: +SKIP("Undefined variables") |
|
>>> loader, optimizer, model = ... |
|
>>> lr_lambda = lambda epoch: 0.9 |
|
>>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, |
|
>>> lr_lambda=lr_lambda) |
|
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, |
|
>>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05) |
|
>>> swa_start = 160 |
|
>>> for i in range(300): |
|
>>> for input, target in loader: |
|
>>> optimizer.zero_grad() |
|
>>> loss_fn(model(input), target).backward() |
|
>>> optimizer.step() |
|
>>> if i > swa_start: |
|
>>> swa_scheduler.step() |
|
>>> else: |
|
>>> scheduler.step() |
|
|
|
.. _Averaging Weights Leads to Wider Optima and Better Generalization: |
|
https://arxiv.org/abs/1803.05407 |
|
""" |
|
|
|
def __init__( |
|
self, |
|
optimizer: Optimizer, |
|
swa_lr: float, |
|
anneal_epochs=10, |
|
anneal_strategy: Literal["cos", "linear"] = "cos", |
|
last_epoch=-1, |
|
): |
|
swa_lrs = _format_param("swa_lr", optimizer, swa_lr) |
|
for swa_lr, group in zip(swa_lrs, optimizer.param_groups): |
|
group["swa_lr"] = swa_lr |
|
if anneal_strategy not in ["cos", "linear"]: |
|
raise ValueError( |
|
"anneal_strategy must by one of 'cos' or 'linear', " |
|
f"instead got {anneal_strategy}" |
|
) |
|
elif anneal_strategy == "cos": |
|
self.anneal_func = self._cosine_anneal |
|
elif anneal_strategy == "linear": |
|
self.anneal_func = self._linear_anneal |
|
if not isinstance(anneal_epochs, int) or anneal_epochs < 0: |
|
raise ValueError( |
|
f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}" |
|
) |
|
self.anneal_epochs = anneal_epochs |
|
super().__init__(optimizer, last_epoch) |
|
|
|
@staticmethod |
|
def _linear_anneal(t): |
|
return t |
|
|
|
@staticmethod |
|
def _cosine_anneal(t): |
|
return (1 - math.cos(math.pi * t)) / 2 |
|
|
|
@staticmethod |
|
def _get_initial_lr(lr, swa_lr, alpha): |
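        # Invert lr = alpha * swa_lr + (1 - alpha) * initial_lr to recover the
        # pre-annealing learning rate from the current one.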
|
if alpha == 1: |
|
return swa_lr |
|
return (lr - alpha * swa_lr) / (1 - alpha) |
|
|
|
def get_lr(self): |
|
"""Get learning rate.""" |
|
|
|
|
|
if not self._get_lr_called_within_step: |
|
warnings.warn( |
|
"To get the last learning rate computed by the scheduler, " |
|
"please use `get_last_lr()`.", |
|
UserWarning, |
|
) |
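
        # Anneal each group from the lr it had when SWALR took over toward
        # group["swa_lr"]:  lr_t = alpha * swa_lr + (1 - alpha) * lr_0, where
        # alpha = anneal_func(min(1, step / anneal_epochs)) and lr_0 is
        # reconstructed from the current lr via _get_initial_lr.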
|
|
|
step = self._step_count - 1 |
|
if self.anneal_epochs == 0: |
|
step = max(1, step) |
|
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs))) |
|
prev_alpha = self.anneal_func(prev_t) |
|
prev_lrs = [ |
|
self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha) |
|
for group in self.optimizer.param_groups |
|
] |
|
t = max(0, min(1, step / max(1, self.anneal_epochs))) |
|
alpha = self.anneal_func(t) |
|
return [ |
|
group["swa_lr"] * alpha + lr * (1 - alpha) |
|
for group, lr in zip(self.optimizer.param_groups, prev_lrs) |
|
] |
|
|