r"""Base class used to build new callbacks.""" |
|
|
|
from typing import Any |
|
|
|
from torch import Tensor |
|
from torch.optim import Optimizer |
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning.utilities.types import STEP_OUTPUT |
|
|
|
|
|

class Callback:
    r"""Abstract base class used to build new callbacks.

    Subclass this class and override any of the relevant hooks.
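
    A minimal sketch of a custom callback (``MyPrintingCallback`` is a hypothetical name used
    here for illustration):

    .. code-block:: python

        class MyPrintingCallback(Callback):
            def on_train_start(self, trainer, pl_module):
                # runs once, right before the first training epoch
                print("Training is starting")

            def on_train_end(self, trainer, pl_module):
                # runs once, after the training loop finishes
                print("Training is ending")


        trainer = pl.Trainer(callbacks=[MyPrintingCallback()])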

    """
|
    @property
    def state_key(self) -> str:
        """Identifier for the state of the callback.

        Used to store and retrieve a callback's state from the checkpoint dictionary by
        ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to provide a unique state key
        if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that
        callback.
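
        For example, a sketch of a callback that keeps per-instance state (the ``Counter``
        class and its ``what`` argument are hypothetical, for illustration only):

        .. code-block:: python

            class Counter(Callback):
                def __init__(self, what="epochs"):
                    self.what = what
                    self.count = 0

                @property
                def state_key(self) -> str:
                    # include the constructor argument so that Counter(what="epochs") and
                    # Counter(what="batches") are checkpointed under different keys
                    return self._generate_state_key(what=self.what)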

        """
        return self.__class__.__qualname__
|
    @property
    def _legacy_state_key(self) -> type["Callback"]:
        """State key for checkpoints saved prior to version 1.5.0."""
        return type(self)
|
    def _generate_state_key(self, **kwargs: Any) -> str:
        """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful
        for defining a :attr:`state_key`.

        Args:
            **kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
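
        For example, continuing the hypothetical ``Counter`` sketch above, the returned key
        follows ``f"{qualname}{repr(kwargs)}"``:

        .. code-block:: python

            self._generate_state_key(what="epochs")
            # -> "Counter{'what': 'epochs'}"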

        """
        return f"{self.__class__.__qualname__}{repr(kwargs)}"
|
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Called when fit, validate, test, predict, or tune begins.
|
    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Called when fit, validate, test, predict, or tune ends."""
|
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when fit begins."""

    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when fit ends."""

    def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the validation sanity check starts."""

    def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the validation sanity check ends."""
|
    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Called when the train batch begins."""
|
    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when the train batch ends.

        Note:
            The value ``outputs["loss"]`` here is the loss returned from ``training_step``, normalized with
            respect to ``accumulate_grad_batches``.
|
        """

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the train epoch begins."""
|
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the train epoch ends.

        To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
        :class:`pytorch_lightning.core.LightningModule` and access them in this hook:

        .. code-block:: python

            class MyLightningModule(pl.LightningModule):
                def __init__(self):
                    super().__init__()
                    self.training_step_outputs = []

                def training_step(self, batch, batch_idx):
                    loss = ...
                    self.training_step_outputs.append(loss)
                    return loss


            class MyCallback(pl.Callback):
                def on_train_epoch_end(self, trainer, pl_module):
                    # do something with all training_step outputs, for example:
                    epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
                    pl_module.log("training_epoch_mean", epoch_mean)
                    # free up the memory
                    pl_module.training_step_outputs.clear()
|
        """

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the val epoch begins."""

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the val epoch ends."""

    def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the test epoch begins."""

    def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the test epoch ends."""

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the predict epoch begins."""

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the predict epoch ends."""
|
    def on_validation_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the validation batch begins."""

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the validation batch ends."""

    def on_test_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch begins."""

    def on_test_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch ends."""

    def on_predict_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the predict batch begins."""

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the predict batch ends."""
|
    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the train begins."""

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the train ends."""

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the validation loop begins."""

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the validation loop ends."""

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the test begins."""

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the test ends."""

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the predict begins."""

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when predict ends."""

    def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
        """Called when any trainer execution is interrupted by an exception."""
|
    def state_dict(self) -> dict[str, Any]:
        """Called when saving a checkpoint. Implement this to generate your callback's ``state_dict``.

        Returns:
            A dictionary containing callback state.
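
        A sketch of a stateful callback that round-trips its state through the checkpoint (the
        ``batches_seen`` attribute is hypothetical, for illustration only):

        .. code-block:: python

            class Counter(Callback):
                def __init__(self):
                    self.batches_seen = 0

                def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
                    self.batches_seen += 1

                def state_dict(self):
                    # stored under checkpoint["callbacks"][self.state_key] when saving
                    return {"batches_seen": self.batches_seen}

                def load_state_dict(self, state_dict):
                    # restore the counter when resuming from a checkpoint
                    self.batches_seen = state_dict["batches_seen"]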
|
        """
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement this to reload your callback's state from the given
        ``state_dict``.

        Args:
            state_dict: the callback state returned by ``state_dict``.

        """
|
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
        r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save.

        Args:
            trainer: the current :class:`~pytorch_lightning.trainer.trainer.Trainer` instance.
            pl_module: the current :class:`~pytorch_lightning.core.LightningModule` instance.
            checkpoint: the checkpoint dictionary that will be saved.
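
        For example, a sketch that stashes an extra value in the checkpoint (the
        ``"my_callback_note"`` key is hypothetical, for illustration only):

        .. code-block:: python

            def on_save_checkpoint(self, trainer, pl_module, checkpoint):
                # mutate the checkpoint dict in place; it is serialized after this hook returns
                checkpoint["my_callback_note"] = "anything picklable"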
|
        """

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
        r"""Called when loading a model checkpoint, use to reload state.

        Args:
            trainer: the current :class:`~pytorch_lightning.trainer.trainer.Trainer` instance.
            pl_module: the current :class:`~pytorch_lightning.core.LightningModule` instance.
            checkpoint: the full checkpoint dictionary that got loaded by the Trainer.

        """
|
    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
        """Called before ``loss.backward()``."""

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called after ``loss.backward()`` and before optimizers are stepped."""

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
    ) -> None:
        """Called before ``optimizer.step()``.
|
    def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
        """Called before ``optimizer.zero_grad()``."""