File size: 30,567 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 |
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import shutil
from collections.abc import Generator, Mapping
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Union,
)
import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import override
import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning_fabric.strategies import _StrategyRegistry
from lightning_fabric.strategies.fsdp import (
_METADATA_FILENAME,
_activation_checkpointing_kwargs,
_auto_wrap_policy_kwargs,
_distributed_checkpoint_load,
_distributed_checkpoint_save,
_get_full_state_dict_context,
_get_sharded_state_dict_context,
_init_cpu_offload,
_init_sharding_strategy,
_is_full_checkpoint,
_is_sharded_checkpoint,
_move_torchmetrics_to_device,
_optimizer_has_flat_params,
_setup_activation_checkpointing,
)
from lightning_fabric.strategies.model_parallel import _load_raw_module_state
from lightning_fabric.utilities.distributed import (
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning_fabric.utilities.distributed import group as _group
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning_fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning_fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import _PATH, ReduceOp
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.precision import Precision
from pytorch_lightning.plugins.precision.fsdp import FSDPPrecision
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
if TYPE_CHECKING:
    # Imported only for static type checking; the strategy's signatures refer to these names as string
    # forward references (e.g. "CPUOffload", "MixedPrecision", "DeviceMesh"), so they are never needed at runtime.
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    # Accepted forms of a wrapping policy: a set of layer classes, a callable policy `(module, recurse, numel)`,
    # or a `ModuleWrapPolicy` instance.
    _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
    # Either a `ShardingStrategy` enum member or the name of one of the supported strategies as a string.
    _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]

# Module-level logger used for the strategy's debug messages.
log = logging.getLogger(__name__)
class FSDPStrategy(ParallelStrategy):
    r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.

    Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
    size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
    at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
    to ZeRO-Stage 3.

    For more information check out
    `this blogpost <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api>`__.

    Defaults have been set and options have been exposed, but may require configuration
    based on your level of memory/speed efficiency. We suggest having a look at
    `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

    Arguments:
        cpu_offload: See ``cpu_offload`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
        mixed_precision: See ``mixed_precision`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
        auto_wrap_policy: Same as ``auto_wrap_policy`` parameter in
            :class:`torch.distributed.fsdp.FullyShardedDataParallel`. For convenience, this also accepts a set of the
            layer classes to wrap.
        activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``.
        activation_checkpointing_policy: Same as ``auto_wrap_policy`` parameter in
            :class:`torch.distributed.fsdp.FullyShardedDataParallel` but used when selecting the modules for which you
            want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
            cost of speed since activations in these layers need to be recomputed during backpropagation. For
            convenience, this also accepts a set of the layer classes to wrap.
        sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination of
            them. Available values are:

            - ``"FULL_SHARD"``: Shards model parameters, gradients, and optimizer states (default).
            - ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
            - ``"NO_SHARD"``: No sharding (identical to regular DDP).
            - ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
              replicates across machines. See also the `device_mesh` parameter below.

            Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

        device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
            replicate the model. The product of the two numbers must equal the world size. Only valid in combination
            with the `HYBRID_SHARD` sharding strategy. An already constructed
            :class:`torch.distributed.device_mesh.DeviceMesh` is also accepted. Requires torch >= 2.2.

        state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

            - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
            - ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is
              a folder with as many files as the world size.

        \**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.

    """

    strategy_name = "fsdp"
    # Names added by `register_strategies`; this list is shared at class level.
    _registered_strategies: list[str] = []

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
        cpu_offload: Union[bool, "CPUOffload", None] = None,
        mixed_precision: Optional["MixedPrecision"] = None,
        auto_wrap_policy: Optional["_POLICY"] = None,
        activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None,
        activation_checkpointing_policy: Optional["_POLICY"] = None,
        sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
        state_dict_type: Literal["full", "sharded"] = "full",
        device_mesh: Optional[Union[tuple[int, int], "DeviceMesh"]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self.num_nodes = 1
        self._process_group_backend = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self.cpu_offload = _init_cpu_offload(cpu_offload)
        self.mixed_precision = mixed_precision
        # Extra keyword arguments forwarded to `FullyShardedDataParallel`; may contain an ``auto_wrap_policy`` entry.
        self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)

        if device_mesh is not None:
            if not _TORCH_GREATER_EQUAL_2_2:
                raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.")
            # A tuple is turned into a `DeviceMesh` in `setup_environment`, after the process group is initialized.
            self.kwargs["device_mesh"] = device_mesh

        self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs)

        # Avoids the need for user to reference params in `configure_optimizers` via
        # `self.trainer.model.parameters()` and enables support for multiple parameter groups.
        self.kwargs.setdefault("use_orig_params", True)

        self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
            activation_checkpointing, activation_checkpointing_policy
        )
        self._state_dict_type = state_dict_type

    @property
    @override
    def root_device(self) -> torch.device:
        """The device of this process, picked from ``parallel_devices`` by the local rank."""
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

    @property
    def num_processes(self) -> int:
        """Number of parallel devices on this node, or 0 when ``parallel_devices`` is not set."""
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    def process_group_backend(self) -> Optional[str]:
        """The process-group backend; ``None`` until given explicitly or resolved in ``setup_environment``."""
        return self._process_group_backend

    @property
    def mixed_precision_config(self) -> Optional["MixedPrecision"]:
        """The mixed-precision config handed to FSDP.

        An explicit ``mixed_precision`` argument takes priority; otherwise the config of the ``FSDPPrecision``
        plugin is used.

        """
        if self.mixed_precision:
            return self.mixed_precision
        plugin = self.precision_plugin
        if isinstance(plugin, FSDPPrecision):
            return plugin.mixed_precision_config
        return None

    @property
    @override
    def precision_plugin(self) -> FSDPPrecision:
        """The configured ``FSDPPrecision`` plugin.

        Falls back to a new ``FSDPPrecision("32-true")`` instance on every access when no plugin was set.

        """
        plugin = self._precision_plugin
        if plugin is not None:
            assert isinstance(plugin, FSDPPrecision)
            return plugin
        return FSDPPrecision("32-true")

    @precision_plugin.setter
    @override
    def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
        """Set the precision plugin.

        Raises:
            TypeError: If the plugin is not ``None`` and not an ``FSDPPrecision`` instance.

        """
        if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecision):
            raise TypeError(
                f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision_plugin}"
            )
        self._precision_plugin = precision_plugin

    @property
    @override
    def distributed_sampler_kwargs(self) -> dict:
        """Sampler arguments: one replica per process across all nodes, indexed by the global rank."""
        return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}

    @property
    @override
    def restore_checkpoint_after_setup(self) -> bool:
        """Restore checkpoints after ``setup`` because ``load_checkpoint`` loads directly into the wrapped model."""
        return True

    @property
    @override
    def lightning_restore_optimizer(self) -> bool:
        """Optimizer states are restored by ``load_checkpoint`` itself, not by Lightning."""
        return False

    @override
    def setup_environment(self) -> None:
        """Initialize the distributed process group and, if given as a tuple, the device mesh."""
        super().setup_environment()
        log.debug(f"{self.__class__.__name__}: setting up distributed...")
        reset_seed()

        # determine which process we are and world size
        self.set_world_ranks()

        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

        # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
        # NOTE(review): the mesh is always created for the "cuda" device type; confirm this is intended.
        if isinstance(self.kwargs.get("device_mesh"), tuple):
            from torch.distributed.device_mesh import init_device_mesh

            self.kwargs["device_mesh"] = init_device_mesh("cuda", self.kwargs["device_mesh"])

    def _get_process_group_backend(self) -> str:
        """Return the user-provided backend, or the default backend for ``root_device``."""
        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

    def set_world_ranks(self) -> None:
        """Compute this process's global rank and the world size, and publish them for rank-zero utilities."""
        if self.cluster_environment is not None:
            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

    @override
    def _configure_launcher(self) -> None:
        """Use a subprocess launcher unless the cluster environment spawns the processes itself."""
        assert self.cluster_environment is not None
        if not self.cluster_environment.creates_processes_externally:
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)

    @override
    def _setup_model(self, model: Module) -> Module:
        """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
        module."""
        from torch.distributed.fsdp import FullyShardedDataParallel

        if any(isinstance(mod, FullyShardedDataParallel) for mod in model.modules()):
            if _has_meta_device_parameters_or_buffers(model):
                rank_zero_warn(
                    "The model is already wrapped in `FSDP` but there are still parameters on the meta device."
                )
            if "auto_wrap_policy" in self.kwargs:
                # The user has wrapped their submodules manually, don't apply the auto wrap policy.
                rank_zero_warn(
                    "A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
                )
                del self.kwargs["auto_wrap_policy"]
        else:
            log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
            model = FullyShardedDataParallel(
                module=model,
                cpu_offload=self.cpu_offload,
                mixed_precision=self.mixed_precision_config,
                sharding_strategy=self.sharding_strategy,
                device_id=self.root_device.index,
                **self.kwargs,
            )

        _move_torchmetrics_to_device(model, self.root_device)

        # activation checkpointing needs to be set up after wrapping the model
        _setup_activation_checkpointing(model, self._activation_checkpointing_kwargs)

        return model

    @override
    def setup(self, trainer: "pl.Trainer") -> None:
        """Prepare the model (layer sync, precision conversion, FSDP wrapping) and, when fitting, the optimizers."""
        assert self.accelerator is not None
        self.accelerator.setup(trainer)

        assert self.model is not None
        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        self.model = self.precision_plugin.convert_module(self.model)

        if is_overridden("configure_sharded_model", self.lightning_module):
            # legacy: we don't skip setup with the `configure_model` alternative
            rank_zero_info(
                "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
                " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
            )
        else:
            self.model = self._setup_model(self.model)
        self.barrier()

        if trainer.state.fn == TrainerFn.FITTING:
            self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        if trainer.state.fn == TrainerFn.FITTING:
            _optimizers_to_device(self.optimizers, self.root_device)

    @override
    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Create the optimizers and check that they reference the FSDP-wrapped parameters.

        Raises:
            ValueError: If ``use_orig_params`` is disabled and an optimizer does not reference FSDP's flat parameters.

        """
        # If we're setting up for evaluation after fitting, we need to discard the optimizers
        # since we're rewrapping the model, otherwise optimizer param references are no longer valid
        # and subsequent checkpoint saving can fail
        self._reset_optimizers_and_schedulers()

        if self.kwargs.get("use_orig_params"):
            return super().setup_optimizers(trainer)

        invalid_params_error = False
        try:
            # If `use_orig_params=False` the user needs to do access `self.trainer.model.parameters()` in
            # `configure_optimizers()`
            super().setup_optimizers(trainer)
        except ValueError as ex:
            if "optimizer got an empty parameter list" not in str(ex):
                raise
            invalid_params_error = True

        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
            # We avoid this limitation by setting `use_orig_params=True`
            raise ValueError(
                "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
                " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
                " `configure_optimizers()` hook."
            )
        return None

    @override
    def model_to_device(self) -> None:
        # FSDP takes care of moving the model to device
        pass

    @contextmanager
    @override
    def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
        """Create tensors on the meta device when ``empty_init`` is truthy, under the precision plugin's context."""
        # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
        # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
        # These operations are applied to each submodule 'bottom up' in the module hierarchy.
        empty_init_context = torch.device("meta") if empty_init else nullcontext()
        with empty_init_context, self.precision_plugin.tensor_init_context():
            yield

    @contextmanager
    @override
    def model_sharded_context(self) -> Generator[None, None, None]:
        """Enable FSDP's ``wrap`` helper with this strategy's settings for the duration of the context."""
        log.debug(f"{self.__class__.__name__}: entered model_sharded_context.")
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
        from torch.distributed.fsdp.wrap import enable_wrap

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            cpu_offload=self.cpu_offload,
            mixed_precision=self.mixed_precision_config,
            sharding_strategy=self.sharding_strategy,
            device_id=self.root_device.index,
            **self.kwargs,
        ):
            yield

    @override
    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes; a no-op when distributed is not initialized. ``name`` is unused."""
        if not _distributed_is_initialized():
            return
        if torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self._determine_device_ids())
        else:
            torch.distributed.barrier()

    @override
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcast a picklable object from rank ``src`` to all ranks; returns ``obj`` unchanged if not distributed."""
        if not _distributed_is_initialized():
            return obj

        obj = [obj]
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    @override
    def reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged

        """
        if isinstance(tensor, Tensor):
            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def _determine_device_ids(self) -> list[int]:
        """Device indices passed to NCCL barriers: just this process's device."""
        return [self.root_device.index]

    @override
    def teardown(self) -> None:
        """Revert layer sync (if applied for fitting) and tear down the environment, precision plugin and accelerator."""
        log.debug(f"{self.__class__.__name__}: tearing down strategy...")

        pl_module = self.lightning_module
        if (
            pl_module is not None
            # `self.lightning_module._trainer` can be None if teardown gets called on an exception before
            # the trainer gets set on the LightningModule
            and pl_module._trainer is not None
            and pl_module._trainer.state.fn == TrainerFn.FITTING
            and self._layer_sync
        ):
            assert self.model is not None
            self.model = self._layer_sync.revert(self.model)

        assert self.cluster_environment is not None
        assert self.accelerator is not None
        self.cluster_environment.teardown()
        self.precision_plugin.teardown()
        self.accelerator.teardown()

    @classmethod
    def get_registered_strategies(cls) -> list[str]:
        """Names under which this class has been registered by ``register_strategies``."""
        return cls._registered_strategies

    @classmethod
    @override
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        """Register the ``fsdp`` and ``fsdp_cpu_offload`` aliases; skipped if torch.distributed is unavailable."""
        if not torch.distributed.is_available():
            return
        strategy_registry.register(
            "fsdp",
            cls,
            description="Fully Sharded Data Parallel (FSDP) training",
        )
        cls._registered_strategies.append("fsdp")

        strategy_registry.register(
            "fsdp_cpu_offload",
            cls,
            description="Fully Sharded Data Parallel (FSDP) training with Full Sharding and CPU Offloading",
            cpu_offload=True,
        )
        cls._registered_strategies.append("fsdp_cpu_offload")

    @override
    def lightning_module_state_dict(self) -> dict[str, Any]:
        """Return the model state dict in the configured ``state_dict_type`` format.

        Raises:
            ValueError: If ``state_dict_type`` is neither ``"full"`` nor ``"sharded"``.

        """
        assert self.model is not None
        if self._state_dict_type == "sharded":
            state_dict_ctx = _get_sharded_state_dict_context(self.model)
        elif self._state_dict_type == "full":
            state_dict_ctx = _get_full_state_dict_context(self.model, world_size=self.world_size)
        else:
            raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
        with state_dict_ctx:
            return self.model.state_dict()

    @override
    def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
        # Override to do nothing, FSDP already loaded the states in `load_checkpoint()`
        pass

    @override
    def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]:
        """Return the optimizer's state dict in the configured ``state_dict_type`` format.

        For ``"full"``, rank 0 re-keys the state by parameter ID (the standard ``torch.optim`` layout).

        Raises:
            ValueError: If ``state_dict_type`` is neither ``"full"`` nor ``"sharded"``.

        """
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp import OptimStateKeyType

        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer

        assert self.model is not None
        if self._state_dict_type == "sharded":
            with _get_sharded_state_dict_context(self.model):
                return FSDP.optim_state_dict(self.model, optimizer)

        elif self._state_dict_type == "full":
            with _get_full_state_dict_context(self.model, world_size=self.world_size):
                state_dict = FSDP.optim_state_dict(self.model, optimizer)
                if self.global_rank == 0:
                    # Store the optimizer state dict in standard format
                    state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
                return state_dict

        raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

    @override
    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # Override to do nothing, the FSDP already loaded the states in `load_checkpoint()`
        pass

    @override
    def save_checkpoint(
        self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save ``checkpoint`` as a sharded directory or a single full file, depending on ``state_dict_type``.

        For the sharded format, the ``"state_dict"`` and ``"optimizer_states"`` entries are popped from
        ``checkpoint`` and saved with distributed checkpointing; the remaining entries are saved by rank 0 as metadata.

        Raises:
            TypeError: If ``storage_options`` is given.
            IsADirectoryError: If saving a full checkpoint onto an existing directory that is not a sharded checkpoint.
            ValueError: If ``state_dict_type`` is neither ``"full"`` nor ``"sharded"``.

        """
        if storage_options is not None:
            raise TypeError(
                "`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
                " `FSDPStrategy` does not use the `CheckpointIO`."
            )

        path = Path(self.broadcast(filepath))
        if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
            raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

        if self._state_dict_type == "sharded":
            if path.is_file():
                # Every rank runs this; on a shared filesystem another rank may have removed the file already.
                path.unlink(missing_ok=True)
            path.mkdir(parents=True, exist_ok=True)

            converted_state = {"model": checkpoint.pop("state_dict")}
            converted_state.update({
                f"optimizer_{idx}": optim_state
                for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
            })

            _distributed_checkpoint_save(converted_state, path)

            if self.global_rank == 0:
                torch.save(checkpoint, path / _METADATA_FILENAME)
        elif self._state_dict_type == "full":
            if _is_sharded_checkpoint(path):
                shutil.rmtree(path)
            return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
        else:
            raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

    @override
    def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
        """Load model (and, when fitting, optimizer) states from a sharded or full checkpoint.

        Model and optimizer states are loaded directly into ``self.model`` and ``self.optimizers``; the remaining
        checkpoint contents are returned.

        Raises:
            RuntimeError: If the number of saved optimizer states differs from the configured optimizers.
            ValueError: If ``checkpoint_path`` is neither a sharded nor a full checkpoint.

        """
        # broadcast the path from rank 0 to ensure all the states are loaded from a common path
        path = Path(self.broadcast(checkpoint_path))

        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        assert self.model is not None
        assert self.lightning_module is not None

        if _is_sharded_checkpoint(path):
            from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

            state_dict_ctx = _get_sharded_state_dict_context(self.model)

            with state_dict_ctx:
                module_state = {"model": self.model.state_dict()}
                _distributed_checkpoint_load(module_state, path)
                self.model.load_state_dict(module_state["model"], strict=self.lightning_module.strict_loading)

                if self.lightning_module.trainer.state.fn == TrainerFn.FITTING and self.optimizers:
                    from torch.distributed.checkpoint import FileSystemReader

                    # TODO: replace with newer APIs
                    # https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271
                    reader = FileSystemReader(path=path)
                    # the optimizer states must be loaded separately
                    for idx, optim in enumerate(self.optimizers):
                        optim_key = f"optimizer_{idx}"
                        optim_state = load_sharded_optimizer_state_dict(
                            model_state_dict=module_state["model"],
                            optimizer_key=optim_key,
                            storage_reader=reader,
                        )
                        flattened_osd = FSDP.optim_state_dict_to_load(
                            optim_state_dict=optim_state[optim_key],
                            model=self.model,
                            optim=optim,
                        )
                        optim.load_state_dict(flattened_osd)

            # Load metadata (anything not a module or optimizer)
            metadata = torch.load(path / _METADATA_FILENAME)
            return metadata

        if _is_full_checkpoint(path):
            checkpoint = _lazy_load(path)
            _load_raw_module_state(
                checkpoint.pop("state_dict"),
                module=self.model,
                world_size=self.world_size,
                strict=self.lightning_module.strict_loading,
            )

            # Materialize lazy tensors if there are any left in the checkpoint
            # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
            checkpoint = _materialize_tensors(checkpoint)

            from torch.distributed.fsdp import OptimStateKeyType

            optimizer_states = checkpoint.get("optimizer_states")
            if optimizer_states is None or self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
                # If the optimizer states are not present, we don't need to do anything (backward compatibility)
                return checkpoint
            if len(self.optimizers) != len(optimizer_states):
                raise RuntimeError(
                    f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
                    f" {len(optimizer_states)} optimizers to load. Please resume training with the same number"
                    " of optimizers or edit the checkpoint manually to remove states."
                )

            # rank0_only should be false because we need to load the optimizer state on all ranks
            with _get_full_state_dict_context(self.model, world_size=self.world_size, rank0_only=False):
                for optimizer, opt_state in zip(self.optimizers, optimizer_states):
                    # The per-parameter state can be empty (e.g. saved before the first optimizer step), so peek at
                    # the first key without indexing into a possibly empty sequence.
                    first_state_key = next(iter(opt_state["state"]), None)
                    if isinstance(first_state_key, int):
                        # Handling the case where the optimizer state is saved from a normal optimizer
                        opt_state = FSDP.rekey_optim_state_dict(opt_state, OptimStateKeyType.PARAM_NAME, self.model)

                    opt_state = FSDP.optim_state_dict_to_load(
                        optim_state_dict=opt_state,
                        model=self.model,
                        optim=optimizer,
                    )
                    optimizer.load_state_dict(opt_state)

            return checkpoint

        raise ValueError(
            f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a"
            " directory with FSDP checkpoint shards, or a single file with a full checkpoint."
        )
|