|
|
|
from typing import Optional, Union |
|
|
|
from torch import Tensor |
|
|
|
from .adam import Adam, adam |
|
from .optimizer import ( |
|
_capturable_doc, |
|
_differentiable_doc, |
|
_foreach_doc, |
|
_fused_doc, |
|
_maximize_doc, |
|
_params_doc, |
|
ParamsT, |
|
) |
|
|
|
|
|
__all__ = ["AdamW", "adamw"] |
|
|
|
|
|
class AdamW(Adam): |
|
def __init__( |
|
self, |
|
params: ParamsT, |
|
lr: Union[float, Tensor] = 1e-3, |
|
betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), |
|
eps: float = 1e-8, |
|
weight_decay: float = 1e-2, |
|
amsgrad: bool = False, |
|
*, |
|
maximize: bool = False, |
|
foreach: Optional[bool] = None, |
|
capturable: bool = False, |
|
differentiable: bool = False, |
|
fused: Optional[bool] = None, |
|
): |
|
super().__init__( |
|
params, |
|
lr, |
|
betas, |
|
eps, |
|
weight_decay, |
|
amsgrad, |
|
foreach=foreach, |
|
maximize=maximize, |
|
capturable=capturable, |
|
differentiable=differentiable, |
|
fused=fused, |
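            # Setting decoupled_weight_decay=True is what makes the shared Adam
            # implementation behave as AdamW: weight decay is applied directly to the
            # parameters rather than being folded into the gradients.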
|
decoupled_weight_decay=True, |
|
) |
|
|
|
|
|
|
|
|
|
def __setstate__(self, state): |
|
super().__setstate__(state) |
|
for group in self.param_groups: |
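            # Param groups restored from a saved state dict may not carry this flag
            # (for example, checkpoints written before AdamW reused Adam's
            # implementation), so force decoupled weight decay for every group.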
|
group["decoupled_weight_decay"] = True |
|
|
|
|
|
AdamW.__doc__ = ( |
|
r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance. |
|
|
|
.. math:: |
|
\begin{aligned} |
|
&\rule{110mm}{0.4pt} \\ |
|
&\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 |
|
\text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, |
|
\: \epsilon \text{ (epsilon)} \\ |
|
&\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, |
|
\: \textit{maximize} \\ |
|
&\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 |
|
                \text{ (second moment)}, \: v_0^{max}\leftarrow 0 \\[-1.ex]
|
&\rule{110mm}{0.4pt} \\ |
|
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ |
|
|
|
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ |
|
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ |
|
&\hspace{5mm}\textbf{else} \\ |
|
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ |
|
&\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ |
|
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ |
|
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ |
|
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ |
|
&\hspace{5mm}\textbf{if} \: amsgrad \\ |
|
&\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\ |
|
&\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\ |
|
&\hspace{5mm}\textbf{else} \\ |
|
&\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ |
|
&\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ |
|
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ |
|
&\rule{110mm}{0.4pt} \\[-1.ex] |
|
&\bf{return} \: \theta_t \\[-1.ex] |
|
&\rule{110mm}{0.4pt} \\[-1.ex] |
|
\end{aligned} |
|
|
|
For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. |
|
""" |
|
+ rf""" |
|
Args: |
|
{_params_doc} |
|
lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR |
|
is not yet supported for all our implementations. Please use a float |
|
LR if you are not also specifying fused=True or capturable=True. |
|
betas (Tuple[float, float], optional): coefficients used for computing |
|
running averages of gradient and its square (default: (0.9, 0.999)) |
|
eps (float, optional): term added to the denominator to improve |
|
numerical stability (default: 1e-8) |
|
weight_decay (float, optional): weight decay coefficient (default: 1e-2) |
|
amsgrad (bool, optional): whether to use the AMSGrad variant of this |
|
algorithm from the paper `On the Convergence of Adam and Beyond`_ |
|
(default: False) |
|
{_maximize_doc} |
|
{_foreach_doc} |
|
{_capturable_doc} |
|
{_differentiable_doc} |
|
{_fused_doc} |
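
    A minimal usage sketch follows; ``model``, ``loss_fn``, ``input``, and ``target``
    are assumed to be defined by the surrounding training code.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()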
|
    .. note::
|
A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. |
|
.. _Decoupled Weight Decay Regularization: |
|
https://arxiv.org/abs/1711.05101 |
|
.. _On the Convergence of Adam and Beyond: |
|
https://openreview.net/forum?id=ryQu7f-RZ |
|
|
|
""" |
|
) |
|
|
|
|
|
|
|
def adamw( |
|
params: list[Tensor], |
|
grads: list[Tensor], |
|
exp_avgs: list[Tensor], |
|
exp_avg_sqs: list[Tensor], |
|
max_exp_avg_sqs: list[Tensor], |
|
state_steps: list[Tensor], |
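    # The following arguments have defaults but are kept positional (not keyword-only)
    # so that the functional API stays compatible with TorchScript compilation, which
    # does not support keyword-only arguments with defaults.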
|
|
|
|
|
foreach: Optional[bool] = None, |
|
capturable: bool = False, |
|
differentiable: bool = False, |
|
fused: Optional[bool] = None, |
|
grad_scale: Optional[Tensor] = None, |
|
found_inf: Optional[Tensor] = None, |
|
has_complex: bool = False, |
|
*, |
|
amsgrad: bool, |
|
beta1: float, |
|
beta2: float, |
|
lr: Union[float, Tensor], |
|
weight_decay: float, |
|
eps: float, |
|
maximize: bool, |
|
): |
|
r"""Functional API that performs AdamW algorithm computation. |
|
|
|
See :class:`~torch.optim.AdamW` for details. |
|
""" |
|
adam( |
|
params, |
|
grads, |
|
exp_avgs, |
|
exp_avg_sqs, |
|
max_exp_avg_sqs, |
|
state_steps, |
|
foreach=foreach, |
|
capturable=capturable, |
|
differentiable=differentiable, |
|
fused=fused, |
|
grad_scale=grad_scale, |
|
found_inf=found_inf, |
|
has_complex=has_complex, |
|
amsgrad=amsgrad, |
|
beta1=beta1, |
|
beta2=beta2, |
|
lr=lr, |
|
weight_decay=weight_decay, |
|
eps=eps, |
|
maximize=maximize, |
|
decoupled_weight_decay=True, |
|
) |
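

# A minimal sketch of driving the functional API directly, shown for illustration only.
# The helper below is hypothetical (it is not part of torch.optim); in normal use,
# ``AdamW.step()`` creates and manages all of this state.
def _example_functional_adamw_step() -> None:
    import torch

    # Hypothetical state for a single parameter. ``state_steps`` must contain
    # singleton tensors, one per parameter.
    param = torch.zeros(3)
    grad = torch.ones(3)
    exp_avg = torch.zeros(3)
    exp_avg_sq = torch.zeros(3)
    step = torch.tensor(0.0)

    adamw(
        [param],
        [grad],
        [exp_avg],
        [exp_avg_sq],
        [],  # max_exp_avg_sqs is only populated when amsgrad=True
        [step],
        amsgrad=False,
        beta1=0.9,
        beta2=0.999,
        lr=1e-3,
        weight_decay=1e-2,
        eps=1e-8,
        maximize=False,
    )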
|
|