from torch import optim

from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD

# Map a torch.optim optimizer class to its functional counterpart defined in
# this package.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}


def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.

    ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key
    need not be of type :class:`torch.optim.Optimizer` (e.g. for custom optimizers).
    Example::
        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    # Registration is a no-op if the key is already present in the map.
    if key not in functional_optim_map:
        functional_optim_map[key] = optim


def as_functional_optim(optim_cls: type, *args, **kwargs):
    """
    Create the functional counterpart of ``optim_cls`` registered in
    ``functional_optim_map``, forwarding ``*args`` and ``**kwargs`` to its
    constructor. Raises ``ValueError`` if ``optim_cls`` has no registered
    functional counterpart.
    """
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e

    return _create_functional_optim(functional_cls, *args, **kwargs)


def _create_functional_optim(functional_optim_cls: type, *args, **kwargs):
    # Functional optimizers are constructed with an empty parameter list, so
    # ``_allow_empty_param_list=True`` is passed to bypass the empty parameter
    # list check in their constructors.
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )
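

# Minimal usage sketch (illustrative only, not part of this module): given the
# map above, ``as_functional_optim`` turns a ``torch.optim`` class plus its
# hyperparameters into an instance of the corresponding functional optimizer.
# The ``lr`` value below is an arbitrary example hyperparameter.
#
#     from torch import optim
#     from torch.distributed.optim.utils import as_functional_optim
#
#     # Equivalent to _FunctionalSGD([], lr=0.01, _allow_empty_param_list=True)
#     fn_sgd = as_functional_optim(optim.SGD, lr=0.01)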