|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Abstract base class used to build new loggers.""" |
|
|
|
import functools |
|
import operator |
|
import statistics |
|
from abc import ABC |
|
from collections import defaultdict |
|
from collections.abc import Mapping, Sequence |
|
from typing import Any, Callable, Optional |
|
|
|
from typing_extensions import override |
|
|
|
from lightning_fabric.loggers import Logger as FabricLogger |
|
from lightning_fabric.loggers.logger import _DummyExperiment as DummyExperiment |
|
from lightning_fabric.loggers.logger import rank_zero_experiment |
|
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint |
|
|
|
|
|
class Logger(FabricLogger, ABC):
    """Base class for experiment loggers.

    Builds on the Fabric ``Logger`` interface and adds two optional extension points, both of which are
    no-ops by default: the ``after_save_checkpoint`` hook and the ``save_dir`` property.
    """

    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Hook invoked once the model checkpoint callback has written a new checkpoint.

        The base implementation intentionally does nothing; subclasses override it when they need to react.

        Args:
            checkpoint_callback: the model checkpoint callback instance

        """
        return None

    @property
    def save_dir(self) -> Optional[str]:
        """Root directory where experiment logs get saved, or `None` if the logger does not save data locally.

        The base implementation always reports `None`.
        """
        return None
|
|
|
|
|
class DummyLogger(Logger):
    """Dummy logger for internal use.

    It is useful if we want to disable user's logger for a feature, but still ensure that user code can run.

    Every logging call is a no-op. Any non-dunder attribute not defined on the class resolves to a callable
    that accepts arbitrary arguments and returns ``None``.

    """

    def __init__(self) -> None:
        super().__init__()
        self._experiment = DummyExperiment()

    @property
    def experiment(self) -> DummyExperiment:
        """Return the experiment object associated with this logger."""
        return self._experiment

    @override
    def log_metrics(self, *args: Any, **kwargs: Any) -> None:
        pass

    @override
    def log_hyperparams(self, *args: Any, **kwargs: Any) -> None:
        pass

    @property
    @override
    def name(self) -> str:
        """Return the experiment name."""
        return ""

    @property
    @override
    def version(self) -> str:
        """Return the experiment version."""
        return ""

    def __getitem__(self, idx: int) -> "DummyLogger":
        """Return this same logger for any index.

        Presumably this lets code that indexes into a collection of loggers keep working when the dummy is
        used instead (TODO: confirm against callers).
        """
        return self

    def __getattr__(self, name: str) -> Callable:
        """Allows the DummyLogger to be called with arbitrary methods, to avoid AttributeErrors.

        Dunder names are excluded and raise ``AttributeError`` as usual. Protocol machinery probes for
        optional hooks with ``getattr(obj, "__hook__", None)``, and a no-op stand-in would hijack it:
        ``copy.deepcopy`` would call a fake ``__deepcopy__`` and return ``None``, and unpickling would call a
        fake ``__setstate__`` and silently drop the instance state.

        Raises:
            AttributeError: if ``name`` is a dunder name.

        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

        def method(*args: Any, **kwargs: Any) -> None:
            return None

        return method
|
|
|
|
|
|
|
def merge_dicts(
    dicts: Sequence[Mapping],
    agg_key_funcs: Optional[Mapping] = None,
    default_func: Callable[[Sequence[float]], float] = statistics.mean,
) -> dict:
    """Merge a sequence of dictionaries into one dictionary by aggregating the same keys with some given function.

    Args:
        dicts:
            Sequence of dictionaries to be merged.
        agg_key_funcs:
            Mapping from key name to function. This function will aggregate a
            list of values, obtained from the same key of all dictionaries.
            If some key has no specified aggregation function, the default one
            will be used.

            For a key whose values are themselves mappings, the entry may be:

            - a nested mapping, which is applied recursively; or
            - a single function, which is used as the default for every nested key.

            Default is: ``None`` (all keys will be aggregated by the
            default function).
        default_func:
            Default function to aggregate keys which are not present in the
            `agg_key_funcs` map.

    Returns:
        Dictionary with merged values.

        - Keys appear in the order they are first seen across ``dicts``.
        - ``None`` values are ignored, and a key whose values are all ``None`` is omitted.
        - An empty ``dicts`` yields ``{}``.

    Examples:
        >>> import pprint
        >>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1, 'd': {'d1': 1, 'd3': 3}}
        >>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1, 'd': {'d1': 2, 'd2': 3}}
        >>> d3 = {'a': 1.1, 'v': 2.3, 'd': {'d3': 3, 'd4': {'d5': 1}}}
        >>> dflt_func = min
        >>> agg_funcs = {'a': statistics.mean, 'v': max, 'd': {'d1': sum}}
        >>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
        {'a': 1.3,
         'b': 2.0,
         'c': 1,
         'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}},
         'v': 2.3}

    """
    agg_key_funcs = agg_key_funcs or {}
    # dict.fromkeys keeps first-seen order, so the output is deterministic. A set union would make the
    # order depend on string hash randomization. It also handles an empty `dicts`, where reduce() raised.
    keys = dict.fromkeys(k for d in dicts for k in d)
    d_out: dict = {}
    for k in keys:
        fn = agg_key_funcs.get(k)
        # Missing keys and explicit `None` values are both excluded from aggregation.
        values_to_agg = [v for v in (d_in.get(k) for d_in in dicts) if v is not None]
        if not values_to_agg:
            # Nothing to aggregate. Indexing [0] below would raise IndexError.
            continue

        if isinstance(values_to_agg[0], Mapping):
            if fn is None or isinstance(fn, Mapping):
                d_out[k] = merge_dicts(values_to_agg, fn, default_func)
            else:
                # A plain function for a nested key becomes the default for all of its sub-keys.
                d_out[k] = merge_dicts(values_to_agg, None, fn)
        else:
            d_out[k] = (fn or default_func)(values_to_agg)

    return d_out
|
|