# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import AbstractContextManager
from typing import Any, Literal, Optional

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer
from typing_extensions import override

from lightning_fabric.plugins.precision.precision import Precision
from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning_fabric.utilities.types import Optimizable


class MixedPrecision(Precision):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.

    Args:
        precision: Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``).
        device: The device for ``torch.autocast``.
        scaler: An optional :class:`torch.amp.GradScaler` to use.
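
    Example (a minimal sketch; ``model``, ``batch``, and the CUDA device are illustrative placeholders)::

        precision = MixedPrecision(precision="16-mixed", device="cuda")
        with precision.forward_context():
            loss = model(batch).sum()
        # scales the loss when a GradScaler is active, then calls ``loss.backward()``
        precision.backward(loss, model)
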
"""

    def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional["torch.amp.GradScaler"] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(
f"Passed `{type(self).__name__}(precision={precision!r})`."
" Precision must be '16-mixed' or 'bf16-mixed'."
)
self.precision = precision
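        # float16 gradients can underflow, so a GradScaler is created by default for '16-mixed';
        # bfloat16 has (roughly) float32's dynamic range and must not be combined with a scaler.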
if scaler is None and self.precision == "16-mixed":
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
self.scaler = scaler
self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16

    @override
def forward_context(self) -> AbstractContextManager:
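        # Ops executed under this context run in the reduced-precision dtype where autocast deems it safe.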
return torch.autocast(self.device, dtype=self._desired_input_dtype)

    @override
def convert_input(self, data: Any) -> Any:
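        # Cast floating-point tensors in the (possibly nested) input collection to the autocast dtype.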
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

    @override
def convert_output(self, data: Any) -> Any:
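        # Cast floating-point outputs back to the default dtype (typically float32).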
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

    @override
def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
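        # Scale the loss before backpropagation so small float16 gradients do not underflow to zero.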
if self.scaler is not None:
tensor = self.scaler.scale(tensor)
super().backward(tensor, model, *args, **kwargs)

    @override
def optimizer_step(
self,
optimizer: Optimizable,
**kwargs: Any,
) -> Any:
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return super().optimizer_step(optimizer, **kwargs)
if isinstance(optimizer, LBFGS):
raise TypeError("AMP and the LBFGS optimizer are not compatible.")
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
step_output = self.scaler.step(optimizer, **kwargs) # type: ignore[arg-type]
self.scaler.update()
return step_output

    @override
def state_dict(self) -> dict[str, Any]:
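        # The GradScaler is the only stateful piece; persist its scale and growth tracker for checkpointing.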
if self.scaler is not None:
return self.scaler.state_dict()
return {}

    @override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)

    @override
def unscale_gradients(self, optimizer: Optimizer) -> None:
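        # Unscale gradients in place (e.g. before gradient clipping) so they reflect their true magnitudes.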
scaler = self.scaler
if scaler is not None:
if _optimizer_handles_unscaling(optimizer):
raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
scaler.unscale_(optimizer)


def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the
    :class:`torch.cuda.amp.GradScaler`.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
    value will only be reliable for built-in PyTorch optimizers.

    """
return getattr(optimizer, "_step_supports_amp_scaling", False)