import json
import operator
import os
import warnings
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch
from lightning_utilities.core.imports import compare_version
from lightning_fabric.utilities.types import _PATH
if TYPE_CHECKING:
from lightning_fabric.fabric import Fabric
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
class SpikeDetection:
"""Spike Detection Callback.
    Terminates training with a ``TrainingSpikeException`` when a loss spike is detected and
    saves the indices of the batches to skip when resuming to a file.

    We skip both the current and the previous batch, since it is unclear whether the previous
    batch altered the weights in a way that caused the spike or whether the current batch
    itself is corrupted.
    Args:
        mode: Whether the tracked metric is being minimized or maximized.
        window: Size of the running mean over the metric, which serves as the reference
            value for spike detection.
        warmup: Number of batches to observe before spike-tracking starts.
        atol: Absolute tolerance. Any difference between the running mean and the current value
            that is not an improvement and is at least ``atol`` is considered a spike.
        rtol: Relative tolerance. Any difference between the running mean and the current value
            that is at least ``rtol * running_mean`` is considered a spike. If both ``atol`` and
            ``rtol`` are given, a value must exceed both tolerances to be flagged.
        exclude_batches_path: Where to save the file containing the batches to exclude.
            Defaults to the current working directory.
        finite_only: If ``True`` (the default), non-finite values such as NaN, inf and -inf are
            also considered spikes. If ``False``, non-finite losses are ignored by the spike
            check and excluded from the running statistics.
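
    Example:
        A minimal usage sketch (``train_dataloader`` and the loss computation are
        placeholders); the hook is invoked explicitly through ``fabric.call`` after
        every training batch::

            fabric = Fabric(callbacks=[SpikeDetection(mode="min", window=10)])
            for batch_idx, batch in enumerate(train_dataloader):
                loss = ...  # forward / backward / optimizer step as usual
                fabric.call(
                    "on_train_batch_end", fabric=fabric, loss=loss, batch=batch, batch_idx=batch_idx
                )

        On a detected spike, a ``TrainingSpikeException`` is raised and the indices of the
        batches to skip are written to ``skip_batches.json`` under ``exclude_batches_path``.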
"""
def __init__(
self,
mode: Literal["min", "max"] = "min",
window: int = 10,
warmup: int = 1,
atol: Optional[float] = None,
rtol: Optional[float] = 2.0,
exclude_batches_path: Optional[_PATH] = None,
finite_only: bool = True,
):
        if not _TORCHMETRICS_GREATER_EQUAL_1_0_0:
            raise RuntimeError("SpikeDetection requires `torchmetrics>=1.0.0`. Please upgrade your version.")

        from torchmetrics.aggregation import MeanMetric
        from torchmetrics.wrappers import Running
super().__init__()
self.last_val: Union[torch.Tensor, float] = 0.0
# spike detection happens individually on each machine
self.running_mean = Running(MeanMetric(dist_sync_on_step=False, sync_on_compute=False), window=window)
# workaround for https://github.com/Lightning-AI/torchmetrics/issues/1899
self.running_mean.dist_sync_on_step = False
self.running_mean.sync_on_compute = False
self.mode = mode
self.warmup = warmup
self.atol = atol
self.rtol = rtol
self.bad_batches: list[int] = []
self.exclude_batches_path = exclude_batches_path
self.finite_only = finite_only
@torch.no_grad()
def on_train_batch_end(self, fabric: "Fabric", loss: torch.Tensor, batch: Any, batch_idx: int) -> None:
"""Checks if we currently have a loss-spike."""
if batch_idx == 0:
self.running_mean.to(fabric.strategy.root_device)
if self.exclude_batches_path is None:
self.exclude_batches_path = os.getcwd()
if not str(self.exclude_batches_path).endswith(".json"):
self.exclude_batches_path = os.path.join(self.exclude_batches_path, "skip_batches.json")
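        # Spikes are only flagged once the warmup period is over.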
is_spike = bool(batch_idx >= self.warmup and self._is_spike(loss))
fabric.strategy.barrier()
# While spike-detection happens on a per-rank level, we need to fail all ranks if any rank detected a spike
is_spike_global = fabric.strategy.reduce_boolean_decision(is_spike, all=False)
if is_spike_global:
self._handle_spike(fabric, batch_idx)
else:
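            # With ``finite_only`` the non-finite case is already treated as a spike above,
            # so no extra check is needed; otherwise we verify finiteness across all ranks
            # before updating the running statistics.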
is_finite_all = self.finite_only or fabric.strategy.reduce_boolean_decision(
bool(torch.isfinite(loss).all()), all=True
)
if is_finite_all:
self._update_stats(loss)
def _is_spike(self, loss: torch.Tensor) -> bool:
        # We might call ``compute`` more often than ``update``, which is fine as long
        # as the metric has at least one internal value.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
running_val = self.running_mean.compute()
curr_diff = loss - self.last_val
if self.finite_only and not torch.isfinite(loss):
return True
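        # A change in the preferred direction is never a spike; anything else must exceed
        # every configured tolerance (see ``_check_atol`` / ``_check_rtol``).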
if self._is_better(curr_diff):
return False
return self._check_atol(loss, running_val) and self._check_rtol(loss, running_val)
def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None:
        # Exclude the current and the previous batch:
        # the current batch, since its data might itself produce the high loss,
        # and the previous batch, since it could have "corrupted" the weights.
self.bad_batches.extend([batch_idx - 1, batch_idx])
if fabric.global_rank == 0:
assert self.exclude_batches_path is not None
os.makedirs(os.path.dirname(self.exclude_batches_path), exist_ok=True)
with open(self.exclude_batches_path, "w") as f:
json.dump(self.bad_batches, f, indent=4)
raise TrainingSpikeException(batch_idx=batch_idx)
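
    # Each tolerance check passes (returns ``True``) when its tolerance is unset or the
    # deviation is at least as large as it, so a spike requires violating all configured
    # tolerances at once.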
def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol))
def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b))
def _is_better(self, diff_val: torch.Tensor) -> bool:
if self.mode == "min":
return bool((diff_val <= 0.0).all())
if self.mode == "max":
            return bool((diff_val >= 0.0).all())
        raise ValueError(f"Invalid mode. Has to be 'min' or 'max', found {self.mode}")
def _update_stats(self, val: torch.Tensor) -> None:
# only update if finite
self.running_mean.update(val)
self.last_val = val
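
    # The ``Running`` wrapper and the wrapped ``MeanMetric`` are (de)serialized separately
    # so that both the window buffer and the underlying aggregate state are restored on resume.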
def state_dict(self) -> dict[str, Any]:
return {
"last_val": self.last_val.item() if isinstance(self.last_val, torch.Tensor) else self.last_val,
"mode": self.mode,
"warmup": self.warmup,
"atol": self.atol,
"rtol": self.rtol,
"bad_batches": self.bad_batches,
"bad_batches_path": self.exclude_batches_path,
"running": self.running_mean.state_dict(),
"mean": self.running_mean.base_metric.state_dict(),
}
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.last_val = state_dict.pop("last_val")
self.mode = state_dict.pop("mode")
self.warmup = state_dict.pop("warmup")
self.atol = state_dict.pop("atol")
self.rtol = state_dict.pop("rtol")
self.bad_batches = state_dict.pop("bad_batches")
self.exclude_batches_path = state_dict.pop("bad_batches_path")
        self.running_mean.load_state_dict(state_dict.pop("running"))
self.running_mean.base_metric.load_state_dict(state_dict.pop("mean"))
class TrainingSpikeException(RuntimeError):
"""Exception to be raised with Training Spikes."""
def __init__(self, batch_idx: int, *args: Any, **kwargs: Any):
super().__init__(f"Training spike detected in batch {batch_idx}", *args, **kwargs)