import json
import operator
import os
import warnings
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import torch
from lightning_utilities.core.imports import compare_version

from lightning.fabric.utilities.types import _PATH

if TYPE_CHECKING:
    from lightning.fabric.fabric import Fabric

_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")


class SpikeDetection:
    """Spike Detection Callback.

    Terminates training with a ``TrainingSpikeException`` when a loss spike is detected and saves
    the indices of the batches to skip when resuming to a file.

    Both the current and the previous batch are skipped, since it is unclear whether the previous
    batch altered the weights in a way that causes the spike, or whether the current batch itself
    is corrupted somehow.

    Args:
        mode: Whether the tracked metric is minimized or maximized.
        window: Size of the running mean of the metric that serves as the reference value for spikes.
        warmup: Number of batches after which spike-tracking starts.
        atol: An absolute tolerance. Every difference between the running mean and the current value
            that is not an improvement and exceeds ``atol`` is considered a spike.
        rtol: A relative tolerance. Every difference between the running mean and the current value
            that is higher than ``rtol * running_mean`` is considered a spike.
        exclude_batches_path: Where to save the file that contains the batches to exclude.
            Defaults to the current working directory.
        finite_only: If set to ``False``, non-finite values like NaN, inf, and -inf are considered
            a spike as well.

    """

    def __init__(
        self,
        mode: Literal["min", "max"] = "min",
        window: int = 10,
        warmup: int = 1,
        atol: Optional[float] = None,
        rtol: Optional[float] = 2.0,
        exclude_batches_path: Optional[_PATH] = None,
        finite_only: bool = True,
    ):
        if _TORCHMETRICS_GREATER_EQUAL_1_0_0:
            from torchmetrics.aggregation import MeanMetric
            from torchmetrics.wrappers import Running
        else:
            raise RuntimeError("SpikeDetection requires `torchmetrics>=1.0.0`. Please upgrade your version.")

        super().__init__()
        self.last_val: Union[torch.Tensor, float] = 0.0

        # spike detection happens individually on each machine
        self.running_mean = Running(MeanMetric(dist_sync_on_step=False, sync_on_compute=False), window=window)
        # workaround for https://github.com/Lightning-AI/torchmetrics/issues/1899
        self.running_mean.dist_sync_on_step = False
        self.running_mean.sync_on_compute = False

        self.mode = mode
        self.warmup = warmup
        self.atol = atol
        self.rtol = rtol
        self.bad_batches: list[int] = []
        self.exclude_batches_path = exclude_batches_path
        self.finite_only = finite_only

    @torch.no_grad()
    def on_train_batch_end(self, fabric: "Fabric", loss: torch.Tensor, batch: Any, batch_idx: int) -> None:
        """Checks whether the current batch produced a loss spike."""
        if batch_idx == 0:
            self.running_mean.to(fabric.strategy.root_device)

            if self.exclude_batches_path is None:
                self.exclude_batches_path = os.getcwd()

            if not str(self.exclude_batches_path).endswith(".json"):
                self.exclude_batches_path = os.path.join(self.exclude_batches_path, "skip_batches.json")

        is_spike = bool(batch_idx >= self.warmup and self._is_spike(loss))
        fabric.strategy.barrier()

        # While spike detection happens on a per-rank level, all ranks need to fail if any rank detected a spike
        is_spike_global = fabric.strategy.reduce_boolean_decision(is_spike, all=False)
        if is_spike_global:
            self._handle_spike(fabric, batch_idx)
        else:
            is_finite_all = self.finite_only or fabric.strategy.reduce_boolean_decision(
                bool(torch.isfinite(loss).all()), all=True
            )
            if is_finite_all:
                self._update_stats(loss)
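    # Worked example of the spike criterion implemented below (hypothetical numbers,
    # assuming mode="min", atol=None, rtol=2.0): with a running mean of 0.5, a loss
    # of 2.0 that is not an improvement over ``last_val`` yields
    # |2.0 - 0.5| = 1.5 >= |2.0 * 0.5| = 1.0, so the batch is flagged as a spike,
    # while a loss of 1.2 yields |1.2 - 0.5| = 0.7 < 1.0 and passes.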
    def _is_spike(self, loss: torch.Tensor) -> bool:
        # we might call compute more often than update, which is fine
        # as long as the metric has at least one internal value
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            running_val = self.running_mean.compute()

        curr_diff = loss - self.last_val

        if self.finite_only and not torch.isfinite(loss):
            return True

        if self._is_better(curr_diff):
            return False

        return self._check_atol(loss, running_val) and self._check_rtol(loss, running_val)

    def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None:
        # Exclude the current and the previous batch:
        # the current batch is excluded since its data could be what produces the high loss,
        # the previous batch is excluded since it could have "corrupted" the weights
        self.bad_batches.extend([batch_idx - 1, batch_idx])

        if fabric.global_rank == 0:
            assert self.exclude_batches_path is not None
            os.makedirs(os.path.dirname(self.exclude_batches_path), exist_ok=True)

            with open(self.exclude_batches_path, "w") as f:
                json.dump(self.bad_batches, f, indent=4)

        raise TrainingSpikeException(batch_idx=batch_idx)

    def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
        return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol))

    def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
        return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b))

    def _is_better(self, diff_val: torch.Tensor) -> bool:
        if self.mode == "min":
            return bool((diff_val <= 0.0).all())

        if self.mode == "max":
            return bool((diff_val >= 0.0).all())

        raise ValueError(f"Invalid mode. Has to be 'min' or 'max', found {self.mode}")

    def _update_stats(self, val: torch.Tensor) -> None:
        # only called with finite values
        self.running_mean.update(val)
        self.last_val = val

    def state_dict(self) -> dict[str, Any]:
        return {
            "last_val": self.last_val.item() if isinstance(self.last_val, torch.Tensor) else self.last_val,
            "mode": self.mode,
            "warmup": self.warmup,
            "atol": self.atol,
            "rtol": self.rtol,
            "bad_batches": self.bad_batches,
            "bad_batches_path": self.exclude_batches_path,
            "running": self.running_mean.state_dict(),
            "mean": self.running_mean.base_metric.state_dict(),
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.last_val = state_dict.pop("last_val")
        self.mode = state_dict.pop("mode")
        self.warmup = state_dict.pop("warmup")
        self.atol = state_dict.pop("atol")
        self.rtol = state_dict.pop("rtol")
        self.bad_batches = state_dict.pop("bad_batches")
        self.exclude_batches_path = state_dict.pop("bad_batches_path")
        self.running_mean.load_state_dict(state_dict.pop("running"))
        self.running_mean.base_metric.load_state_dict(state_dict.pop("mean"))


class TrainingSpikeException(RuntimeError):
    """Exception raised when a training loss spike is detected."""

    def __init__(self, batch_idx: int, *args: Any, **kwargs: Any):
        super().__init__(f"Training spike detected in batch {batch_idx}", *args, **kwargs)
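
# Usage sketch (illustrative only, not part of this module): one way SpikeDetection
# could be wired into a manual Fabric training loop via ``fabric.call``. The
# ``model``, ``optimizer``, and ``dataloader`` objects below are hypothetical
# placeholders.
#
#     from lightning.fabric import Fabric
#
#     fabric = Fabric(callbacks=[SpikeDetection(window=20, rtol=4.0)])
#     fabric.launch()
#     model, optimizer = fabric.setup(model, optimizer)
#
#     for batch_idx, batch in enumerate(dataloader):
#         optimizer.zero_grad()
#         loss = model(batch).mean()
#         fabric.backward(loss)
#         optimizer.step()
#         # dispatches to SpikeDetection.on_train_batch_end, which raises a
#         # TrainingSpikeException and writes skip_batches.json on a spike
#         fabric.call("on_train_batch_end", fabric=fabric, loss=loss, batch=batch, batch_idx=batch_idx)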