import inspect
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Callable, Optional

import torch
import torch.distributed as dist
from torch import Tensor

import pytorch_lightning as pl
from lightning_fabric.utilities.distributed import _distributed_is_initialized
from lightning_fabric.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators.xla import XLAAccelerator
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.loops import _Loop
from pytorch_lightning.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher
from pytorch_lightning.loops.progress import _BaseProgress
from pytorch_lightning.strategies import FSDPStrategy
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

def check_finite_loss(loss: Optional[Tensor]) -> None:
    """Checks that the loss is finite and raises a ``ValueError`` if it is not.

    Args:
        loss: the loss value to check for finiteness

    """
    if loss is not None and not torch.isfinite(loss).all():
        raise ValueError(f"The loss returned in `training_step` is {loss}.")


def _parse_loop_limits(
    min_steps: Optional[int],
    max_steps: int,
    min_epochs: Optional[int],
    max_epochs: Optional[int],
    trainer: "pl.Trainer",
) -> tuple[int, int]:
    """Computes the default values for the minimum and maximum number of steps and epochs, given the values the user
    has selected.

    Args:
        min_steps: Minimum number of steps.
        max_steps: Maximum number of steps.
        min_epochs: Minimum number of epochs.
        max_epochs: Maximum number of epochs.
        trainer: Trainer instance.

    Returns:
        The parsed limits, with default values being set for the ones that the user did not specify.

    """
    if max_epochs is None:
        if max_steps == -1 and not any(isinstance(cb, Timer) for cb in trainer.callbacks):
            rank_zero_warn(
                "`max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit,"
                " set `max_epochs=-1`.",
                category=PossibleUserWarning,
            )
            max_epochs = 1000
        else:
            max_epochs = -1

    if min_epochs is None and min_steps is not None:
        # a minimum number of steps implies training for at least one epoch
        min_epochs = 1

    if min_epochs is None:
        # default to 0 so that training can stop as soon as a stopping condition is met
        min_epochs = 0

    return min_epochs, max_epochs
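# Illustrative defaults (a sketch; assumes a `trainer` with no `Timer` callback configured):
#
#     _parse_loop_limits(None, -1, None, None, trainer)   # -> (0, 1000), with a warning about `max_epochs`
#     _parse_loop_limits(None, 100, None, None, trainer)  # -> (0, -1): a step limit is set, so no epoch cap
#     _parse_loop_limits(50, -1, None, 3, trainer)        # -> (1, 3): `min_steps` implies at least one epoch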


@contextmanager
def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Generator[None, None, None]:
    """Blocks synchronization in :class:`~pytorch_lightning.strategies.parallel.ParallelStrategy`. This is useful,
    for example, when accumulating gradients to reduce communication when it is not needed.

    Args:
        strategy: the strategy instance to use.
        block: whether the context manager is enabled or not

    Returns:
        context manager with sync behaviour off

    """
    if isinstance(strategy, ParallelStrategy) and block:
        with strategy.block_backward_sync():
            yield None
    else:
        yield None
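# Sketch of intended use during gradient accumulation (names such as `is_last_micro_batch` are
# illustrative and not part of this module):
#
#     with _block_parallel_sync_behavior(trainer.strategy, block=not is_last_micro_batch):
#         loss.backward()  # with a parallel strategy, gradients accumulate locally and are only
#                          # all-reduced on the final micro-batch, when syncing is not blocked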


def _is_max_limit_reached(current: int, maximum: int = -1) -> bool:
    """Check if the limit has been reached (if enabled).

    Args:
        current: the current value
        maximum: the maximum value (or -1 to disable the limit)

    Returns:
        bool: whether the limit has been reached

    """
    return maximum != -1 and current >= maximum
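# For example:
#
#     _is_max_limit_reached(current=10, maximum=10)  # True: limit reached
#     _is_max_limit_reached(current=10, maximum=20)  # False: below the limit
#     _is_max_limit_reached(current=10)              # False: maximum=-1 disables the check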


def _reset_progress(loop: _Loop) -> None:
    """Recursively resets all ``_BaseProgress`` trackers on the given loop and its child loops."""
    for v in vars(loop).values():
        if isinstance(v, _BaseProgress):
            v.reset()
        elif isinstance(v, _Loop):
            _reset_progress(v)


def _select_data_fetcher(trainer: "pl.Trainer", stage: RunningStage) -> _DataFetcher:
    lightning_module = trainer.lightning_module
    if stage == RunningStage.TESTING:
        step_fx_name = "test_step"
    elif stage == RunningStage.TRAINING:
        step_fx_name = "training_step"
    elif stage in (RunningStage.VALIDATING, RunningStage.SANITY_CHECKING):
        step_fx_name = "validation_step"
    elif stage == RunningStage.PREDICTING:
        step_fx_name = "predict_step"
    else:
        raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}")
    step_fx = getattr(lightning_module, step_fx_name)
    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return _DataLoaderIterDataFetcher()
    return _PrefetchDataFetcher()
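# A step hook that declares `dataloader_iter` explicitly (illustrative sketch) opts into the
# experimental manual-iteration path and gets a `_DataLoaderIterDataFetcher`; otherwise the
# default `_PrefetchDataFetcher` is used:
#
#     class MyModel(pl.LightningModule):
#         def training_step(self, dataloader_iter):
#             ...  # consume batches manually via `next(dataloader_iter)`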


def _no_grad_context(loop_run: Callable) -> Callable:
    def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any:
        if not isinstance(self, _Loop):
            raise TypeError(f"`{type(self).__name__}` needs to be a Loop.")
        if not hasattr(self, "inference_mode"):
            raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined")
        context_manager: type[AbstractContextManager]
        if _distributed_is_initialized() and dist.get_backend() == "gloo":
            # the gloo backend has known issues with tensors created under `inference_mode`,
            # so fall back to `no_grad`
            context_manager = torch.no_grad
        elif isinstance(self.trainer.accelerator, XLAAccelerator):
            # `inference_mode` has known issues on XLA, so fall back to `no_grad`
            context_manager = torch.no_grad
        elif isinstance(self.trainer.strategy, FSDPStrategy):
            # `inference_mode` has known issues with FSDP, so fall back to `no_grad`
            context_manager = torch.no_grad
        elif self.inference_mode:
            context_manager = torch.inference_mode
        else:
            context_manager = torch.no_grad
        with context_manager():
            return loop_run(self, *args, **kwargs)

    return _decorator
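# Intended usage (a sketch; `_MyEvalLoop` is hypothetical): decorate a loop's entry point so it runs
# under `torch.inference_mode`, or `torch.no_grad` where inference mode is not supported:
#
#     class _MyEvalLoop(_Loop):
#         inference_mode = True
#
#         @_no_grad_context
#         def run(self) -> None:
#             ...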


def _verify_dataloader_idx_requirement(
    hooks: tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: "pl.LightningModule"
) -> None:
    for hook in hooks:
        fx = getattr(pl_module, hook)

        param_present = is_param_in_hook_signature(fx, "dataloader_idx")
        if not is_expected:
            if param_present:
                params = inspect.signature(fx).parameters
                if "dataloader_idx" in params and params["dataloader_idx"].default is inspect.Parameter.empty:
                    raise RuntimeError(
                        f"You provided only a single `{stage.dataloader_prefix}_dataloader`, but have included "
                        f"`dataloader_idx` in `{type(pl_module).__name__}.{hook}()`. Either remove the"
                        " argument or give it a default value i.e. `dataloader_idx=0`."
                    )
        elif not param_present:
            raise RuntimeError(
                f"You provided multiple `{stage.dataloader_prefix}_dataloader`, but no `dataloader_idx`"
                f" argument in `{type(pl_module).__name__}.{hook}()`. Try adding `dataloader_idx=0` to its"
                " signature."
            )
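# Example of the hook signatures this check accepts (illustrative LightningModule hooks):
#
#     def validation_step(self, batch, batch_idx, dataloader_idx=0): ...  # OK with one or many dataloaders
#     def validation_step(self, batch, batch_idx, dataloader_idx): ...    # only OK with multiple dataloaders
#     def validation_step(self, batch, batch_idx): ...                    # only OK with a single dataloader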