|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from contextlib import AbstractContextManager |
|
from typing import TYPE_CHECKING, Any, Literal, Optional |
|
|
|
import torch |
|
from lightning_utilities import apply_to_collection |
|
from torch import Tensor |
|
from torch.nn import Module |
|
from torch.optim import Optimizer |
|
from typing_extensions import get_args, override |
|
|
|
from lightning_fabric.plugins.precision.amp import _optimizer_handles_unscaling |
|
from lightning_fabric.plugins.precision.precision import Precision |
|
from lightning_fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager |
|
from lightning_fabric.utilities.types import Optimizable |
|
|
|
if TYPE_CHECKING: |
|
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision |
|
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler |
|
|
|
# Closed set of precision flags accepted by ``FSDPPrecision``.
_PRECISION_INPUT = Literal[
    "32-true",
    "16-true",
    "bf16-true",
    "16-mixed",
    "bf16-mixed",
]
|
|
|
|
|
class FSDPPrecision(Precision):
    """Precision plugin for training with Fully Sharded Data Parallel (FSDP).

    .. warning::  This is an :ref:`experimental <versioning:Experimental API>` feature.

    Args:
        precision: Full precision (32-true), half precision (16-true, bf16-true) or
            mixed precision (16-mixed, bf16-mixed).
        scaler: An optional :class:`torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler` to use.
            Only accepted together with ``precision="16-mixed"``. When omitted for that precision, a default
            ``ShardedGradScaler()`` is created.

    Raises:
        ValueError:
            If unsupported ``precision`` is provided, or if a ``scaler`` is passed together with a precision
            other than ``"16-mixed"``.

    """

    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
        supported_precision = get_args(_PRECISION_INPUT)
        if precision not in supported_precision:
            raise ValueError(
                f"`precision={precision!r})` is not supported in FSDP."
                f" `precision` must be one of: {supported_precision}."
            )

        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

        # Validate against the *argument*: the instance attribute `self.precision` is only assigned below, so
        # reading it here would not reflect the precision the caller asked for.
        if scaler is not None and precision != "16-mixed":
            raise ValueError(f"`precision={precision!r}` does not use a scaler, found {scaler}.")

        self.precision = precision

        # Keep a user-supplied scaler; only fall back to a default one when "16-mixed" is requested without it.
        if scaler is None and precision == "16-mixed":
            scaler = ShardedGradScaler()
        self.scaler = scaler

        precision_to_type = {
            "bf16-mixed": torch.float32,
            "16-mixed": torch.float32,
            "bf16-true": torch.bfloat16,
            "16-true": torch.float16,
            "32-true": torch.float32,
        }
        self._desired_input_dtype = precision_to_type[self.precision]

    @override
    def convert_module(self, module: Module) -> Module:
        """Cast ``module`` to the desired dtype for the ``*-true`` precisions; return it untouched otherwise."""
        if "true" in self.precision:
            return module.to(dtype=self._desired_input_dtype)
        return module

    @property
    def mixed_precision_config(self) -> "TorchMixedPrecision":
        """Build the FSDP ``MixedPrecision`` config (param/reduce/buffer dtypes) matching ``self.precision``.

        Raises:
            ValueError: If ``self.precision`` is none of the supported precision strings.

        """
        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

        if self.precision == "16-mixed":
            param_dtype = torch.float32
            reduce_dtype = buffer_dtype = torch.float16
        elif self.precision == "bf16-mixed":
            param_dtype = torch.float32
            reduce_dtype = buffer_dtype = torch.bfloat16
        elif self.precision == "16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.float16
        elif self.precision == "bf16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
        elif self.precision == "32-true":
            param_dtype = torch.float32
            reduce_dtype = buffer_dtype = torch.float32
        else:
            # Reachable only if `self.precision` was reassigned after construction.
            raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")

        return TorchMixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype,
        )

    @override
    def tensor_init_context(self) -> AbstractContextManager:
        """Return a ``_DtypeContextManager`` for the desired input dtype."""
        return _DtypeContextManager(self._desired_input_dtype)

    @override
    def module_init_context(self) -> AbstractContextManager:
        """Return a ``_DtypeContextManager`` for the configured ``param_dtype`` (``torch.float32`` if unset)."""
        return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)

    @override
    def forward_context(self) -> AbstractContextManager:
        """Return the context manager to wrap the forward pass in.

        For the ``*-mixed`` precisions this is ``torch.autocast`` on ``"cuda"`` with bfloat16 (``bf16-mixed``)
        or float16 (``16-mixed``); for every other precision it is :meth:`tensor_init_context`.

        """
        if "mixed" in self.precision:
            autocast_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16
            return torch.autocast("cuda", dtype=autocast_dtype)
        return self.tensor_init_context()

    @override
    def convert_input(self, data: Any) -> Any:
        """Apply ``_convert_fp_tensor`` with the desired input dtype to every ``Tensor`` found in ``data``."""
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

    @override
    def convert_output(self, data: Any) -> Any:
        """Apply ``_convert_fp_tensor`` with torch's current default dtype to every ``Tensor`` in ``data``."""
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

    @override
    def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
        """Run the parent's backward on ``tensor``, scaling it first when a gradient scaler is configured."""
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        super().backward(tensor, model, *args, **kwargs)

    @override
    def optimizer_step(
        self,
        optimizer: Optimizable,
        **kwargs: Any,
    ) -> Any:
        """Step ``optimizer``, going through the gradient scaler when one is configured.

        Without a scaler the call is delegated to the parent implementation. With a scaler, the step is
        performed via ``scaler.step`` followed by ``scaler.update``, and the result of ``scaler.step`` is
        returned.

        """
        if self.scaler is None:
            return super().optimizer_step(optimizer, **kwargs)

        step_output = self.scaler.step(optimizer, **kwargs)
        self.scaler.update()
        return step_output

    @override
    def unscale_gradients(self, optimizer: Optimizer) -> None:
        """Unscale the gradients of ``optimizer`` through the scaler; a no-op when no scaler is configured.

        Raises:
            NotImplementedError: If a scaler is configured and ``_optimizer_handles_unscaling(optimizer)``
                is true.

        """
        scaler = self.scaler
        if scaler is not None:
            if _optimizer_handles_unscaling(optimizer):
                raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
            scaler.unscale_(optimizer)

    @override
    def state_dict(self) -> dict[str, Any]:
        """Return the scaler's state dict, or an empty dict when no scaler is configured."""
        if self.scaler is not None:
            return self.scaler.state_dict()
        return {}

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Load ``state_dict`` into the scaler; ignored when no scaler is configured."""
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)
|
|