File size: 1,154 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
from collections.abc import Mapping
from typing import Any, Union

import torch

import pytorch_lightning as pl
from lightning_fabric.utilities.spike import SpikeDetection as FabricSpikeDetection
from pytorch_lightning.callbacks.callback import Callback


class SpikeDetection(FabricSpikeDetection, Callback):
    """Lightning ``Callback`` adapter for Fabric's ``SpikeDetection``.

    All detection logic and constructor arguments come from
    ``lightning_fabric.utilities.spike.SpikeDetection``. This class only does two things:

    * It adapts the ``Trainer`` hook signature by extracting the loss tensor from the
      ``training_step`` outputs.
    * It supplies a default ``exclude_batches_path`` of
      ``<trainer.default_root_dir>/skip_batches.json`` when none was set.

    """

    @torch.no_grad()
    def on_train_batch_end(  # type: ignore
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Union[torch.Tensor, Mapping[str, torch.Tensor]],
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Extract the loss from ``outputs`` and forward it to Fabric's spike-detection hook.

        Args:
            trainer: The running trainer. Its ``default_root_dir`` is used to build the default
                exclude-batches file path.
            pl_module: The LightningModule being trained (not used by this hook).
            outputs: Either the loss tensor itself, or a mapping holding it under the ``"loss"`` key.
                A mapping without a (non-``None``) ``"loss"`` entry means there is no loss for this
                batch, and detection is skipped.
            batch: The current batch, passed through unchanged.
            batch_idx: Index of the current batch, passed through unchanged.

        Raises:
            TypeError: If ``outputs`` is neither a tensor nor a mapping.

        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs
        elif isinstance(outputs, Mapping):
            loss = outputs.get("loss")
            if loss is None:
                # No loss for this batch. Presumably this is a skipped batch (training_step returned
                # None) or a manual-optimization step that returned no loss. There is nothing to check
                # for spikes, so skip instead of failing with a bare KeyError/AttributeError.
                return None
        else:
            raise TypeError(f"outputs have to be of type torch.Tensor or Mapping, got {type(outputs).__qualname__}")

        # Detach so the spike detector never holds on to the autograd graph.
        loss = loss.detach()

        # Lazily default the exclude-batches file to the trainer's root dir; the trainer is only
        # available once the hook runs, not at construction time.
        if self.exclude_batches_path is None:
            self.exclude_batches_path = os.path.join(trainer.default_root_dir, "skip_batches.json")

        # Call the Fabric implementation explicitly (not via super()) so the MRO does not route
        # this call back into Lightning's Callback hook.
        return FabricSpikeDetection.on_train_batch_end(self, trainer, loss, batch, batch_idx)  # type: ignore