import logging
import signal
from copy import deepcopy
from typing import Any, Callable, Optional, Union

from packaging.version import Version

import pytorch_lightning as pl
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from pytorch_lightning.callbacks import Checkpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies.launchers import _SubprocessScriptLauncher
from pytorch_lightning.trainer.connectors.signal_connector import _get_sigkill_signal
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.utilities.exceptions import _TunerExitException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn

log = logging.getLogger(__name__)


def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    r"""Error handling, intended to be used only for the main trainer function entry points (fit, validate, test,
    predict), as all errors should funnel through them.

    Args:
        trainer_fn: one of (fit, validate, test, predict)
        *args: positional arguments to be passed to the `trainer_fn`
        **kwargs: keyword arguments to be passed to `trainer_fn`

    """
    try:
        if trainer.strategy.launcher is not None:
            return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
        return trainer_fn(*args, **kwargs)

    except _TunerExitException:
        _call_teardown_hook(trainer)
        trainer._teardown()
        trainer.state.status = TrainerStatus.FINISHED
        trainer.state.stage = None

    except KeyboardInterrupt as exception:
        rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
        # the user could press Ctrl+C many times; ignore further SIGINTs so the shutdown runs only once
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        _interrupt(trainer, exception)
        trainer._teardown()
        launcher = trainer.strategy.launcher
        if isinstance(launcher, _SubprocessScriptLauncher):
            launcher.kill(_get_sigkill_signal())
        exit(1)

    except BaseException as exception:
        _interrupt(trainer, exception)
        trainer._teardown()
        # teardown might access the stage, so reset it only afterwards
        trainer.state.stage = None
        raise


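# Illustrative sketch (not part of this module): how a `Trainer` entry point funnels through
# `_call_and_handle_interrupt`; `_fit_impl` stands for the private implementation method that `fit` wraps,
# shown here only to clarify the call flow:
#
#     def fit(self, model, train_dataloaders=None, ...):
#         return _call_and_handle_interrupt(self, self._fit_impl, model, train_dataloaders, ...)

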
def _interrupt(trainer: "pl.Trainer", exception: BaseException) -> None:
    trainer.state.status = TrainerStatus.INTERRUPTED
    _call_callback_hooks(trainer, "on_exception", exception)
    if trainer.datamodule is not None:
        _call_lightning_datamodule_hook(trainer, "on_exception", exception)
    trainer.strategy.on_exception(exception)
    for logger in trainer.loggers:
        logger.finalize("failed")


def _call_setup_hook(trainer: "pl.Trainer") -> None:
    assert trainer.state.fn is not None
    fn = trainer.state.fn

    # it is too early to move the model to the device, but we fix the device attributes so that
    # they already point to the root device inside the `setup` hooks
    for module in trainer.lightning_module.modules():
        if isinstance(module, _DeviceDtypeModuleMixin):
            module._device = trainer.strategy.root_device

    # wandb experiments need to be initialized before any TensorBoard writers are created so that
    # TensorBoard logs can be synced to wandb; the sort puts `WandbLogger` instances first
    loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))

    # trigger lazy creation of the experiment in loggers so their metadata is available for the `setup` hook
    for logger in loggers:
        if hasattr(logger, "experiment"):
            _ = logger.experiment

    trainer.strategy.barrier("pre_setup")

    if trainer.datamodule is not None:
        _call_lightning_datamodule_hook(trainer, "setup", stage=fn)
    _call_callback_hooks(trainer, "setup", stage=fn)
    _call_lightning_module_hook(trainer, "setup", stage=fn)

    trainer.strategy.barrier("post_setup")


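# Illustrative resulting hook order during `trainer.fit`, assuming a datamodule is attached
# (`stage` is the running `TrainerFn` value, a string enum equal to "fit" here):
#
#     datamodule.setup(stage="fit")                  # via _call_lightning_datamodule_hook
#     callback.setup(trainer, model, stage="fit")    # for each callback, via _call_callback_hooks
#     model.setup(stage="fit")                       # via _call_lightning_module_hook

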
def _call_configure_model(trainer: "pl.Trainer") -> None:
    # legacy hook, kept for backward compatibility
    if is_overridden("configure_sharded_model", trainer.lightning_module):
        with trainer.strategy.model_sharded_context():
            _call_lightning_module_hook(trainer, "configure_sharded_model")

    # we don't normally check for this before calling the hook; it is done here to avoid
    # instantiating the context managers when the hook is not overridden
    if is_overridden("configure_model", trainer.lightning_module):
        with (
            trainer.strategy.tensor_init_context(),
            trainer.strategy.model_sharded_context(),
            trainer.precision_plugin.module_init_context(),
        ):
            _call_lightning_module_hook(trainer, "configure_model")


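# Hedged example of the hook this dispatches to: a user `LightningModule` can delay instantiating
# large layers until `configure_model` so they are created inside the strategy/precision contexts
# above. The guard matters because the hook may run more than once; `build_transformer` is a
# hypothetical helper:
#
#     class MyModel(pl.LightningModule):
#         def configure_model(self):
#             if self.model is None:
#                 self.model = build_transformer()

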
def _call_teardown_hook(trainer: "pl.Trainer") -> None:
    assert trainer.state.fn is not None
    fn = trainer.state.fn

    if trainer.datamodule is not None:
        _call_lightning_datamodule_hook(trainer, "teardown", stage=fn)

    _call_callback_hooks(trainer, "teardown", stage=fn)
    _call_lightning_module_hook(trainer, "teardown", stage=fn)

    trainer.lightning_module._current_fx_name = None
    # these could have become stale if metrics are defined in `setup`
    trainer.lightning_module._metric_attributes = None

    # finalize all loggers: this flushes pending writes and marks the run as successful
    # (failed runs are finalized in `_interrupt` instead)
    for logger in trainer.loggers:
        logger.finalize("success")

    # summarize profile results
    trainer.profiler.describe()


def _call_lightning_module_hook(
    trainer: "pl.Trainer",
    hook_name: str,
    *args: Any,
    pl_module: Optional["pl.LightningModule"] = None,
    **kwargs: Any,
) -> Any:
    log.debug(f"{trainer.__class__.__name__}: calling lightning module hook: {hook_name}")

    pl_module = pl_module or trainer.lightning_module

    if pl_module is None:
        raise TypeError("No `LightningModule` is available to call hooks on.")

    fn = getattr(pl_module, hook_name)
    if not callable(fn):
        return None

    prev_fx_name = pl_module._current_fx_name
    pl_module._current_fx_name = hook_name

    with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
        output = fn(*args, **kwargs)

    # restore current_fx when nested context
    pl_module._current_fx_name = prev_fx_name

    return output


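# Illustrative usage from inside the trainer (not a public API): loops invoke module hooks through
# this helper so that profiling and the `_current_fx_name` bookkeeping stay consistent, e.g.:
#
#     _call_lightning_module_hook(trainer, "on_train_epoch_start")

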
def _call_lightning_datamodule_hook(
    trainer: "pl.Trainer",
    hook_name: str,
    *args: Any,
    **kwargs: Any,
) -> Any:
    log.debug(f"{trainer.__class__.__name__}: calling lightning datamodule hook: {hook_name}")

    if trainer.datamodule is None:
        raise TypeError("No `LightningDataModule` is available to call hooks on.")

    fn = getattr(trainer.datamodule, hook_name)
    if callable(fn):
        with trainer.profiler.profile(f"[LightningDataModule]{trainer.datamodule.__class__.__name__}.{hook_name}"):
            return fn(*args, **kwargs)
    return None


def _call_callback_hooks(
    trainer: "pl.Trainer",
    hook_name: str,
    *args: Any,
    monitoring_callbacks: Optional[bool] = None,
    **kwargs: Any,
) -> None:
    log.debug(f"{trainer.__class__.__name__}: calling callback hook: {hook_name}")

    pl_module = trainer.lightning_module
    if pl_module:
        prev_fx_name = pl_module._current_fx_name
        pl_module._current_fx_name = hook_name

    callbacks = trainer.callbacks
    if monitoring_callbacks is True:
        # the set of "monitoring callbacks" is hard-coded to these two types
        callbacks = [cb for cb in callbacks if isinstance(cb, (EarlyStopping, Checkpoint))]
    elif monitoring_callbacks is False:
        callbacks = [cb for cb in callbacks if not isinstance(cb, (EarlyStopping, Checkpoint))]

    for callback in callbacks:
        fn = getattr(callback, hook_name)
        if callable(fn):
            with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
                fn(trainer, trainer.lightning_module, *args, **kwargs)

    if pl_module:
        # restore current_fx when nested context
        pl_module._current_fx_name = prev_fx_name


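# Illustrative use of the `monitoring_callbacks` filter (a sketch, not the loops' exact code): a loop
# can run the non-monitoring callbacks first and the monitoring ones (EarlyStopping, Checkpoint
# subclasses) afterwards, once the metrics they monitor have been logged:
#
#     _call_callback_hooks(trainer, "on_validation_end", monitoring_callbacks=False)
#     _call_callback_hooks(trainer, "on_validation_end", monitoring_callbacks=True)

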
def _call_callbacks_state_dict(trainer: "pl.Trainer") -> dict[str, dict]:
    """Called when saving a model checkpoint; calls and returns every callback's `state_dict`, keyed by
    `Callback.state_key`."""
    callback_state_dicts = {}
    for callback in trainer.callbacks:
        state_dict = callback.state_dict()
        if state_dict:
            callback_state_dicts[callback.state_key] = state_dict
    return callback_state_dicts


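# Illustrative shape of the result as stored in a checkpoint (keys come from `Callback.state_key`,
# which may embed init kwargs to disambiguate multiple instances of the same callback class):
#
#     checkpoint["callbacks"] = _call_callbacks_state_dict(trainer)
#     # e.g. {"EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}": {"wait_count": 0, ...}}

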
def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None:
    """Called when saving a model checkpoint; calls every callback's `on_save_checkpoint` hook."""
    pl_module = trainer.lightning_module
    if pl_module:
        prev_fx_name = pl_module._current_fx_name
        pl_module._current_fx_name = "on_save_checkpoint"

    for callback in trainer.callbacks:
        with trainer.profiler.profile(f"[Callback]{callback.state_key}.on_save_checkpoint"):
            callback.on_save_checkpoint(trainer, trainer.lightning_module, checkpoint)

    if pl_module:
        # restore current_fx when nested context
        pl_module._current_fx_name = prev_fx_name


def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None:
    """Called when loading a model checkpoint.

    Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using
    `_call_callback_hooks` because we have special logic for getting the callback states.

    """
    pl_module = trainer.lightning_module
    if pl_module:
        prev_fx_name = pl_module._current_fx_name
        pl_module._current_fx_name = "on_load_checkpoint"

    callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks")

    if callback_states is None:
        return

    is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
    current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in trainer.callbacks}
    difference = callback_states.keys() - current_callbacks_keys
    if difference:
        rank_zero_warn(
            "Be aware that when using `ckpt_path`,"
            " callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
            f" Please add the following callbacks: {list(difference)}.",
        )

    for callback in trainer.callbacks:
        with trainer.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
            callback.on_load_checkpoint(trainer, trainer.lightning_module, checkpoint)

    if pl_module:
        # restore current_fx when nested context
        pl_module._current_fx_name = prev_fx_name


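# Note on key matching (illustrative): checkpoints written before v1.5 stored callback states keyed by
# the callback's class (`_legacy_state_key`, i.e. `type(callback)`), while newer checkpoints use the
# string `state_key`:
#
#     legacy:  {EarlyStopping: {...}}
#     current: {"EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}": {...}}

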
def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None:
    """Called when loading a model checkpoint; calls every callback's `load_state_dict`."""
    callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks")

    if callback_states is None:
        return

    for callback in trainer.callbacks:
        # match on the new `state_key` first, fall back to the legacy (type-based) key
        state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
        if state:
            # deepcopy so the callback cannot mutate the state still stored in the checkpoint dict
            state = deepcopy(state)
            callback.load_state_dict(state)


def _call_strategy_hook(
    trainer: "pl.Trainer",
    hook_name: str,
    *args: Any,
    **kwargs: Any,
) -> Any:
    log.debug(f"{trainer.__class__.__name__}: calling strategy hook: {hook_name}")

    pl_module = trainer.lightning_module
    prev_fx_name = pl_module._current_fx_name
    pl_module._current_fx_name = hook_name

    fn = getattr(trainer.strategy, hook_name)
    if not callable(fn):
        return None

    with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
        output = fn(*args, **kwargs)

    # restore current_fx when nested context
    pl_module._current_fx_name = prev_fx_name

    return output
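

# Illustrative usage (not a public API): the training loop routes `training_step` through the strategy
# rather than calling the module directly, so that strategies such as DDP can wrap the call, e.g.:
#
#     output = _call_strategy_hook(trainer, "training_step", *step_kwargs.values())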