# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Base class used to build new callbacks."""

from typing import Any

from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import STEP_OUTPUT


class Callback:
    r"""Abstract base class used to build new callbacks.

    Subclass this class and override any of the relevant hooks.

    """

    @property
    def state_key(self) -> str:
        """Identifier for the state of the callback.

        Used to store and retrieve a callback's state from the checkpoint dictionary by
        ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to provide a unique state key if
        1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.

        """
        return self.__class__.__qualname__

    @property
    def _legacy_state_key(self) -> type["Callback"]:
        """State key for checkpoints saved prior to version 1.5.0."""
        return type(self)

    def _generate_state_key(self, **kwargs: Any) -> str:
        """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful
        for defining a :attr:`state_key`.

        Args:
            **kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
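
        A minimal sketch of how a subclass might use this helper to define a unique :attr:`state_key` (the class
        name and ``monitor`` attribute are illustrative, not part of this base class):

        .. code-block:: python

            class MyMonitorCallback(Callback):
                def __init__(self, monitor: str):
                    self.monitor = monitor

                @property
                def state_key(self) -> str:
                    # two instances monitoring different metrics get
                    # distinct entries in checkpoint["callbacks"]
                    return self._generate_state_key(monitor=self.monitor)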

        """
        return f"{self.__class__.__qualname__}{repr(kwargs)}"

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Called when fit, validate, test, predict, or tune begins."""

    def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Called when fit, validate, test, predict, or tune ends."""

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when fit begins."""

    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when fit ends."""

    def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the validation sanity check starts."""

    def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the validation sanity check ends."""

    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Called when the train batch begins."""

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when the train batch ends.

        Note:
            The value ``outputs["loss"]`` here will be the normalized value w.r.t. ``accumulate_grad_batches`` of
            the loss returned from ``training_step``.

        """

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the train epoch begins."""

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the train epoch ends.

        To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
        :class:`pytorch_lightning.core.LightningModule` and access them in this hook:

        .. code-block:: python

            class MyLightningModule(L.LightningModule):
                def __init__(self):
                    super().__init__()
                    self.training_step_outputs = []

                def training_step(self):
                    loss = ...
                    self.training_step_outputs.append(loss)
                    return loss


            class MyCallback(L.Callback):
                def on_train_epoch_end(self, trainer, pl_module):
                    # do something with all training_step outputs, for example:
                    epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
                    pl_module.log("training_epoch_mean", epoch_mean)
                    # free up the memory
                    pl_module.training_step_outputs.clear()

        """

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the val epoch begins."""

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the val epoch ends."""

    def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the test epoch begins."""

    def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the test epoch ends."""

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the predict epoch begins."""

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the predict epoch ends."""

    def on_validation_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the validation batch begins."""

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the validation batch ends."""

    def on_test_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch begins."""

    def on_test_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch ends."""

    def on_predict_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the predict batch begins."""

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the predict batch ends."""

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the train begins."""

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the train ends."""

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the validation loop begins."""

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the validation loop ends."""

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the test begins."""

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the test ends."""

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the predict begins."""

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when predict ends."""

    def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
        """Called when any trainer execution is interrupted by an exception."""

    def state_dict(self) -> dict[str, Any]:
        """Called when saving a checkpoint. Implement this to generate and return the callback's ``state_dict``.

        Returns:
            A dictionary containing callback state.
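
        A minimal sketch of a stateful callback that round-trips its state through ``state_dict`` and
        ``load_state_dict`` (the ``Counter`` class and its attribute are illustrative):

        .. code-block:: python

            class Counter(Callback):
                def __init__(self):
                    self.batches_seen = 0

                def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
                    self.batches_seen += 1

                def state_dict(self):
                    # stored under checkpoint["callbacks"][self.state_key]
                    return {"batches_seen": self.batches_seen}

                def load_state_dict(self, state_dict):
                    self.batches_seen = state_dict["batches_seen"]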
ends.""" def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when the predict begins.""" def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when predict ends.""" def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: """Called when any trainer execution is interrupted by an exception.""" def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate callback's ``state_dict``. Returns: A dictionary containing callback state. """ return {} def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``. Args: state_dict: the callback state returned by ``state_dict``. """ pass def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] ) -> None: r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save. Args: trainer: the current :class:`~pytorch_lightning.trainer.trainer.Trainer` instance. pl_module: the current :class:`~pytorch_lightning.core.LightningModule` instance. checkpoint: the checkpoint dictionary that will be saved. """ def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] ) -> None: r"""Called when loading a model checkpoint, use to reload state. Args: trainer: the current :class:`~pytorch_lightning.trainer.trainer.Trainer` instance. pl_module: the current :class:`~pytorch_lightning.core.LightningModule` instance. checkpoint: the full checkpoint dictionary that got loaded by the Trainer. """ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None: """Called before ``loss.backward()``.""" def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called after ``loss.backward()`` and before optimizers are stepped.""" def on_before_optimizer_step( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer ) -> None: """Called before ``optimizer.step()``.""" def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None: """Called before ``optimizer.zero_grad()``."""