import os
from collections.abc import Mapping
from typing import Any, Union

import torch

import pytorch_lightning as pl
from lightning_fabric.utilities.spike import SpikeDetection as FabricSpikeDetection
from pytorch_lightning.callbacks.callback import Callback


class SpikeDetection(FabricSpikeDetection, Callback):
    """Callback for detecting loss spikes during training."""

    @torch.no_grad()
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Union[torch.Tensor, Mapping[str, torch.Tensor]],
        batch: Any,
        batch_idx: int,
    ) -> None:
        # The step output may be a bare loss tensor or a mapping containing a "loss" key
        # (the default Lightning convention); anything else is rejected.
        if isinstance(outputs, torch.Tensor):
            loss = outputs.detach()
        elif isinstance(outputs, Mapping):
            loss = outputs["loss"].detach()
        else:
            raise TypeError(f"outputs have to be of type torch.Tensor or Mapping, got {type(outputs).__qualname__}")

        # Default the skip-list location to the trainer's root directory if the user did not set one.
        if self.exclude_batches_path is None:
            self.exclude_batches_path = os.path.join(trainer.default_root_dir, "skip_batches.json")

        # Delegate the actual spike detection to the Fabric implementation.
        return FabricSpikeDetection.on_train_batch_end(self, trainer, loss, batch, batch_idx)
|
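
# A minimal usage sketch (assuming default ``SpikeDetection`` settings and a user-defined
# ``LightningModule`` named ``MyModel``; both ``MyModel`` and ``train_loader`` are placeholders):
#
#     import pytorch_lightning as pl
#
#     trainer = pl.Trainer(callbacks=[SpikeDetection()])
#     trainer.fit(MyModel(), train_dataloaders=train_loader)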