import json
import operator
import os
import warnings
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import torch
from lightning_utilities.core.imports import compare_version

from lightning_fabric.utilities.types import _PATH

if TYPE_CHECKING:
    from lightning_fabric.fabric import Fabric

_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")


class SpikeDetection:
    """Spike Detection Callback.

    Terminates training with a ``TrainingSpikeException`` when a loss spike is detected and
    writes the indices of the batches to skip on resume to a file.

    Both the current and the previous batch are skipped, since it is unclear whether the previous
    batch altered the weights in a way that caused the spike or whether the current batch itself is corrupted.

    Args:
        mode: Whether the tracked metric is minimized or maximized.
        window: Size of the running mean over the metric that serves as the reference value for spikes.
        warmup: Number of batches after which spike-tracking starts.
        atol: An absolute tolerance. Any difference between the running mean and the current value
            that is not an improvement and exceeds ``atol`` is considered a spike.
        rtol: A relative tolerance. Any difference between the running mean and the current value
            that exceeds ``rtol * running_mean`` is considered a spike.
        exclude_batches_path: Where to save the file containing the batches to exclude.
            Defaults to the current working directory.
        finite_only: If set to ``True``, non-finite values like NaN, inf and -inf are considered spikes as well.

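    Example:
        A minimal sketch of driving the callback manually from a ``Fabric`` training loop
        (``model``, ``optimizer``, ``dataloader`` and ``compute_loss`` are illustrative placeholders)::

            from lightning_fabric import Fabric

            fabric = Fabric(devices=1)
            spike_detection = SpikeDetection(mode="min", window=10, warmup=5)

            for batch_idx, batch in enumerate(dataloader):
                optimizer.zero_grad()
                loss = compute_loss(model, batch)
                fabric.backward(loss)
                optimizer.step()
                try:
                    spike_detection.on_train_batch_end(fabric, loss, batch, batch_idx)
                except TrainingSpikeException:
                    break  # the batch indices to skip were written to ``skip_batches.json``
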
    """

    def __init__(
        self,
        mode: Literal["min", "max"] = "min",
        window: int = 10,
        warmup: int = 1,
        atol: Optional[float] = None,
        rtol: Optional[float] = 2.0,
        exclude_batches_path: Optional[_PATH] = None,
        finite_only: bool = True,
    ):
        if _TORCHMETRICS_GREATER_EQUAL_1_0_0:
            from torchmetrics.aggregation import MeanMetric
            from torchmetrics.wrappers import Running
        else:
            raise RuntimeError("`SpikeDetection` requires `torchmetrics>=1.0.0`. Please upgrade your version.")
        super().__init__()

        self.last_val: Union[torch.Tensor, float] = 0.0
        # spike detection happens individually on each machine
        self.running_mean = Running(MeanMetric(dist_sync_on_step=False, sync_on_compute=False), window=window)
        # workaround for https://github.com/Lightning-AI/torchmetrics/issues/1899
        self.running_mean.dist_sync_on_step = False
        self.running_mean.sync_on_compute = False

        self.mode = mode
        self.warmup = warmup
        self.atol = atol
        self.rtol = rtol
        self.bad_batches: list[int] = []
        self.exclude_batches_path = exclude_batches_path
        self.finite_only = finite_only

    @torch.no_grad()
    def on_train_batch_end(self, fabric: "Fabric", loss: torch.Tensor, batch: Any, batch_idx: int) -> None:
        """Checks if we currently have a loss-spike."""
        if batch_idx == 0:
            self.running_mean.to(fabric.strategy.root_device)

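        # Resolve the output path lazily: default to the current working directory and make sure it
        # points at a concrete ``.json`` file.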
        if self.exclude_batches_path is None:
            self.exclude_batches_path = os.getcwd()

        if not str(self.exclude_batches_path).endswith(".json"):
            self.exclude_batches_path = os.path.join(self.exclude_batches_path, "skip_batches.json")

        is_spike = bool(batch_idx >= self.warmup and self._is_spike(loss))
        fabric.strategy.barrier()

        # While spike-detection happens on a per-rank level, we need to fail all ranks if any rank detected a spike
        is_spike_global = fabric.strategy.reduce_boolean_decision(is_spike, all=False)

        if is_spike_global:
            self._handle_spike(fabric, batch_idx)
        else:
            is_finite_all = self.finite_only or fabric.strategy.reduce_boolean_decision(
                bool(torch.isfinite(loss).all()), all=True
            )
            if is_finite_all:
                self._update_stats(loss)

    def _is_spike(self, loss: torch.Tensor) -> bool:
        # We might call ``compute`` more often than ``update``, which is fine as long as
        # the metric has at least one internal value.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            running_val = self.running_mean.compute()
        curr_diff = loss - self.last_val

        if self.finite_only and not torch.isfinite(loss):
            return True

        if self._is_better(curr_diff):
            return False

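        # Illustrative numbers: with mode="min", last_val=1.0, a running mean of 1.0, atol=None and
        # rtol=2.0, a loss of 3.5 is not an improvement and |3.5 - 1.0| = 2.5 >= 2.0 * 1.0,
        # so it is flagged as a spike.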
        return self._check_atol(loss, running_val) and self._check_rtol(loss, running_val)

    def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None:
        # Exclude current and last batch
        # Current batch is excluded since it could be that the data of this batch produces a high loss
        # Last batch is excluded since the previous batch could have "corrupted" the weights
        self.bad_batches.extend([batch_idx - 1, batch_idx])

        if fabric.global_rank == 0:
            assert self.exclude_batches_path is not None
            os.makedirs(os.path.dirname(self.exclude_batches_path), exist_ok=True)

            with open(self.exclude_batches_path, "w") as f:
                json.dump(self.bad_batches, f, indent=4)

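        # On resume, the written file can be consumed to skip the listed batches, for example
        # (a sketch assuming a plain iterable dataloader):
        #     with open(exclude_batches_path) as f:
        #         bad_batches = set(json.load(f))
        #     batches = (b for idx, b in enumerate(dataloader) if idx not in bad_batches)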
        raise TrainingSpikeException(batch_idx=batch_idx)

    def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
        return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol))

    def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
        return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b))

    def _is_better(self, diff_val: torch.Tensor) -> bool:
        if self.mode == "min":
            return bool((diff_val <= 0.0).all())
        if self.mode == "max":
            return bool((diff_val >= 0.0).all())

        raise ValueError(f"Invalid mode. Has to be min or max, found {self.mode}")

    def _update_stats(self, val: torch.Tensor) -> None:
        # only update if finite
        self.running_mean.update(val)
        self.last_val = val

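    # The two methods below let the callback state be saved to and restored from a checkpoint,
    # e.g. (a sketch): ``fabric.save(path, {"spike_detection": callback.state_dict()})`` when saving,
    # and ``callback.load_state_dict(checkpoint["spike_detection"])`` when resuming.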
    def state_dict(self) -> dict[str, Any]:
        return {
            "last_val": self.last_val.item() if isinstance(self.last_val, torch.Tensor) else self.last_val,
            "mode": self.mode,
            "warmup": self.warmup,
            "atol": self.atol,
            "rtol": self.rtol,
            "bad_batches": self.bad_batches,
            "bad_batches_path": self.exclude_batches_path,
            "running": self.running_mean.state_dict(),
            "mean": self.running_mean.base_metric.state_dict(),
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.last_val = state_dict.pop("last_val")
        self.mode = state_dict.pop("mode")
        self.warmup = state_dict.pop("warmup")
        self.atol = state_dict.pop("atol")
        self.rtol = state_dict.pop("rtol")
        self.bad_batches = state_dict.pop("bad_batches")
        self.exclude_batches_path = state_dict.pop("bad_batches_path")
        self.running_mean.load_state_dict(state_dict.pop("running"))
        self.running_mean.base_metric.load_state_dict(state_dict.pop("mean"))


class TrainingSpikeException(RuntimeError):
    """Exception to be raised with Training Spikes."""

    def __init__(self, batch_idx: int, *args: Any, **kwargs: Any):
        super().__init__(f"Training spike detected in batch {batch_idx}", *args, **kwargs)