"""Utilities for loggers.""" |
|
|
|
from pathlib import Path |
|
from typing import Any, Union |
|
|
|
from torch import Tensor |
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning.callbacks import Checkpoint |
|
|
|
|


def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]:
    if len(loggers) == 1:
        return loggers[0].version
    # Concatenate the versions of all loggers, removing duplicates while preserving order
    return separator.join(dict.fromkeys(str(logger.version) for logger in loggers))
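

# Illustrative sketch (not part of the library): `_version` returns a single logger's
# version unchanged, and joins de-duplicated versions otherwise. `_FakeLogger` is a
# hypothetical stand-in used only for this example.
#
#   class _FakeLogger:
#       def __init__(self, version):
#           self.version = version
#
#   _version([_FakeLogger(0)])                                    # -> 0
#   _version([_FakeLogger(0), _FakeLogger(0), _FakeLogger(1)])    # -> "0_1"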


def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> list[tuple[float, str, float, str]]:
    """Return the checkpoints to be logged.

    Args:
        checkpoint_callback: Checkpoint callback reference.
        logged_model_time: Dictionary of already-logged modification times, keyed by checkpoint path.

    """
    # Collect candidate checkpoint paths together with their score and a tag
    checkpoints = {}
    if hasattr(checkpoint_callback, "last_model_path") and hasattr(checkpoint_callback, "current_score"):
        checkpoints[checkpoint_callback.last_model_path] = (checkpoint_callback.current_score, "latest")

    if hasattr(checkpoint_callback, "best_model_path") and hasattr(checkpoint_callback, "best_model_score"):
        checkpoints[checkpoint_callback.best_model_path] = (checkpoint_callback.best_model_score, "best")

    if hasattr(checkpoint_callback, "best_k_models"):
        for key, value in checkpoint_callback.best_k_models.items():
            checkpoints[key] = (value, "best_k")

    # Sort by modification time, dropping paths that no longer exist on disk
    checkpoints = sorted(
        (Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file()
    )
    # Keep only checkpoints that were not logged yet or that changed since they were last logged
    checkpoints = [c for c in checkpoints if c[1] not in logged_model_time or logged_model_time[c[1]] < c[0]]
    return checkpoints
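

# Illustrative sketch (not part of the library): a logger would typically call
# `_scan_checkpoints` with the modification times of checkpoints it has already
# uploaded, then record the new times. `my_upload` is a hypothetical helper used
# only for this example.
#
#   logged_model_time: dict = {}
#   for mtime, path, score, tag in _scan_checkpoints(trainer.checkpoint_callback, logged_model_time):
#       my_upload(path, score=score, tag=tag)
#       logged_model_time[path] = mtime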


def _log_hyperparams(trainer: "pl.Trainer") -> None:
    if not trainer.loggers:
        return

    pl_module = trainer.lightning_module
    datamodule_log_hyperparams = trainer.datamodule._log_hyperparams if trainer.datamodule is not None else False

    hparams_initial = None
    if pl_module._log_hyperparams and datamodule_log_hyperparams:
        datamodule_hparams = trainer.datamodule.hparams_initial
        lightning_hparams = pl_module.hparams_initial
        inconsistent_keys = []
        for key in lightning_hparams.keys() & datamodule_hparams.keys():
            if key == "_class_path":
                # Skip the internal `_class_path` key (added e.g. by LightningCLI); it is allowed to differ
                continue
            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
            if (
                type(lm_val) != type(dm_val)
                or (isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val))
                or lm_val != dm_val
            ):
                inconsistent_keys.append(key)
        if inconsistent_keys:
            raise RuntimeError(
                f"Error while merging hparams: the keys {inconsistent_keys} are present "
                "in both the LightningModule's and LightningDataModule's hparams "
                "but have different values."
            )
        hparams_initial = {**lightning_hparams, **datamodule_hparams}
    elif pl_module._log_hyperparams:
        hparams_initial = pl_module.hparams_initial
    elif datamodule_log_hyperparams:
        hparams_initial = trainer.datamodule.hparams_initial

    # Drop the internal `_class_path` key before passing the hparams to the loggers
    if hparams_initial is not None:
        hparams_initial = {k: v for k, v in hparams_initial.items() if k != "_class_path"}

    for logger in trainer.loggers:
        if hparams_initial is not None:
            logger.log_hyperparams(hparams_initial)
        logger.log_graph(pl_module)
        logger.save()
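

# Illustrative sketch (not part of the library): `_log_hyperparams` is invoked by the
# Trainer internals at the start of a run; user code normally just calls
# `self.save_hyperparameters()` so that `hparams_initial` is populated before this
# function merges and forwards it to every configured logger.
#
#   class MyModel(pl.LightningModule):
#       def __init__(self, lr: float = 1e-3):
#           super().__init__()
#           self.save_hyperparameters()  # fills `hparams_initial`, later logged here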