File size: 16,019 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 |
# mypy: allow-untyped-defs
from typing import cast, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_params_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
# Public names exported by ``from <this module> import *``.
__all__ = [
    "ASGD",
    "asgd",
]
class ASGD(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
lambd: float = 1e-4,
alpha: float = 0.75,
t0: float = 1e6,
weight_decay: float = 0,
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
):
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
lambd=lambd,
alpha=alpha,
t0=t0,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0:
if not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = torch.tensor(
step_val, dtype=_get_scalar_dtype(), device=p.device
)
if not torch.is_tensor(p_state["eta"]):
p_state["eta"] = torch.tensor(
p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
)
if not torch.is_tensor(p_state["mu"]):
p_state["mu"] = torch.tensor(
p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
)
def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
has_complex = False
for p in group["params"]:
if p.grad is not None:
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("ASGD does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = torch.zeros(
(), device=p.device, dtype=_get_scalar_dtype()
)
state["eta"] = (
torch.as_tensor(
group["lr"], device=p.device, dtype=_get_scalar_dtype()
)
.clone()
.detach()
)
state["mu"] = torch.ones(
(), device=p.device, dtype=_get_scalar_dtype()
)
state["ax"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
mus.append(state["mu"])
axs.append(state["ax"])
etas.append(state["eta"])
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
mus: list[Tensor] = []
axs: list[Tensor] = []
etas: list[Tensor] = []
state_steps: list[Tensor] = []
has_complex = self._init_group(
group, params_with_grad, grads, mus, axs, etas, state_steps
)
asgd(
params_with_grad,
grads,
axs,
mus,
etas,
state_steps,
lambd=group["lambd"],
lr=group["lr"],
t0=group["t0"],
alpha=group["alpha"],
weight_decay=group["weight_decay"],
foreach=group["foreach"],
maximize=group["maximize"],
differentiable=group["differentiable"],
capturable=group["capturable"],
has_complex=has_complex,
)
return loss
ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.
It has been proposed in `Acceleration of stochastic approximation by
averaging`_.
Args:
{_params_doc}
lr (float, Tensor, optional): learning rate (default: 1e-2)
lambd (float, optional): decay term (default: 1e-4)
alpha (float, optional): power for eta update (default: 0.75)
t0 (float, optional): point at which to start averaging (default: 1e6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
{_foreach_doc}
{_maximize_doc}
{_differentiable_doc}
{_capturable_doc}
.. _Acceleration of stochastic approximation by averaging:
https://dl.acm.org/citation.cfm?id=131098
"""
def _single_tensor_asgd(
params: list[Tensor],
grads: list[Tensor],
axs: list[Tensor],
mus: list[Tensor],
etas: list[Tensor],
state_steps: list[Tensor],
*,
lambd: float,
lr: float,
t0: float,
alpha: float,
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
has_complex: bool,
):
for i, param in enumerate(params):
grad = grads[i]
grad = grad if not maximize else -grad
mu = mus[i]
ax = axs[i]
eta = etas[i]
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch.compiler.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type
== mu.device.type
== eta.device.type
== step_t.device.type
and param.device.type in capturable_supported_devices
), (
f"If capturable=True, params, mus, etas, and state_steps must be "
f"on supported devices: {capturable_supported_devices}."
)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
param = torch.view_as_real(param)
ax = torch.view_as_real(ax)
# update step
step_t += 1
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
if capturable:
param.mul_(1 - lambd * eta)
param.addcmul_(grad, eta, value=-1) # update parameter
else:
eta_value = _get_value(eta)
param.mul_(1 - lambd * eta_value) # decay term
param.add_(grad, alpha=-eta_value) # update parameter
# averaging
if capturable or mu.item() != 1:
ax.add_(param.sub(ax).mul_(mu))
else:
ax.copy_(param)
if capturable:
eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
else:
step = _get_value(step_t)
new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
eta.copy_(new_eta)
new_mu = torch.as_tensor(1 / max(1, step - t0))
mu.copy_(new_mu)
def _multi_tensor_asgd(
params: list[Tensor],
grads: list[Tensor],
axs: list[Tensor],
mus: list[Tensor],
etas: list[Tensor],
state_steps: list[Tensor],
*,
lambd: float,
lr: float,
t0: float,
alpha: float,
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
has_complex: bool,
):
if len(params) == 0:
return
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch.compiler.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
assert all(
p.device.type == mu.device.type == eta.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mu, eta, step in zip(params, mus, etas, state_steps)
), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, axs, mus, etas, state_steps] # type: ignore[list-item]
)
for (device, _), (
(
grouped_params_,
grouped_grads_,
grouped_axs_,
grouped_mus_,
grouped_etas_,
grouped_state_steps_,
),
_,
) in grouped_tensors.items():
grouped_params = cast(list[Tensor], grouped_params_)
grouped_grads = cast(list[Tensor], grouped_grads_)
grouped_axs = cast(list[Tensor], grouped_axs_)
grouped_mus = cast(list[Tensor], grouped_mus_)
grouped_etas = cast(list[Tensor], grouped_etas_)
grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
if has_complex:
_view_as_real(grouped_params, grouped_grads, grouped_axs)
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(grouped_state_steps, 1)
# intermediate = grad + param * lambd
intermediate: Union[tuple[Tensor, ...], list[Tensor]]
if weight_decay != 0:
if maximize:
torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
intermediate = grouped_grads
else:
intermediate = torch._foreach_add(
grouped_grads, grouped_params, alpha=weight_decay
)
torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
else:
intermediate = torch._foreach_add(
grouped_grads, grouped_params, alpha=lambd
)
# update param
# param * (1 - lambd * eta) - eta * grad
# => param - param * lambd * eta - eta * grad
# => param - eta * intermediate
torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
del intermediate
# update grouped_axs
# averaging: ax = ax + mu * (param - ax)
# Note (mlazos): We can't use lerp here since it requires weight to be float64
# and our grouping code requires dtypes to match for all tensors in a group (and it should, since
# we use the mus in other places)
# all dtypes need to match, so we could introduce a cast in a loop
# but since this only adds one additional kernel launch, this looks like the cleaner
# and faster solution
intermediate = torch._foreach_sub(grouped_params, grouped_axs)
torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
del intermediate
new_etas: Union[tuple[Tensor, ...], list[Tensor]]
new_mus: Union[tuple[Tensor, ...], list[Tensor]]
if capturable:
# update grouped_mus
new_mus = torch._foreach_sub(grouped_state_steps, t0)
torch._foreach_maximum_(new_mus, 1.0)
torch._foreach_reciprocal_(new_mus)
torch._foreach_copy_(grouped_mus, new_mus)
del new_mus
# update eta = lr / ((1 + lambd * lr * step)^alpha)
new_etas = torch._foreach_mul(grouped_state_steps, lambd)
torch._foreach_mul_(new_etas, lr)
torch._foreach_add_(new_etas, 1)
torch._foreach_pow_(new_etas, alpha)
torch._foreach_reciprocal_(new_etas)
torch._foreach_mul_(new_etas, lr)
torch._foreach_copy_(grouped_etas, new_etas)
else:
new_etas = [
torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
for step in grouped_state_steps
]
new_mus = [
torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
for step in grouped_state_steps
]
torch._foreach_copy_(grouped_etas, new_etas)
torch._foreach_copy_(grouped_mus, new_mus)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
params: list[Tensor],
grads: list[Tensor],
axs: list[Tensor],
mus: list[Tensor],
etas: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
lambd: float,
lr: float,
t0: float,
alpha: float,
weight_decay: float,
):
r"""Functional API that performs asgd algorithm computation.
See :class:`~torch.optim.ASGD` for details.
"""
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_asgd
else:
func = _single_tensor_asgd
func(
params,
grads,
axs,
mus,
etas,
state_steps,
lambd=lambd,
lr=lr,
t0=t0,
alpha=alpha,
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)
|