# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
from argparse import Namespace
from collections.abc import Mapping, MutableMapping
from dataclasses import asdict, is_dataclass
from typing import Any, Optional, Union

from torch import Tensor

from lightning_fabric.utilities.imports import _NUMPY_AVAILABLE


def _convert_params(params: Optional[Union[dict[str, Any], Namespace]]) -> dict[str, Any]:
    """Ensure parameters are a dict or convert to dict if necessary.

    Args:
        params: Target to be converted to a dictionary. A :class:`~argparse.Namespace` is converted with ``vars()``
            and ``None`` becomes an empty dict; any other value is returned unchanged.

    Returns:
        params as a dictionary

    """
    # in case converting from namespace
    if isinstance(params, Namespace):
        params = vars(params)

    if params is None:
        params = {}

    return params


def _sanitize_callable_params(params: dict[str, Any]) -> dict[str, Any]:
    """Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.

    Classes are replaced by their ``__name__``. Other callables are *called* with no arguments: if the call returns a
    non-callable value, that value is used; if it returns another callable, the original's ``__name__`` is used; if
    the call raises, the original's ``__name__`` (or ``None`` when it has none) is used. Non-callable values are kept.

    Args:
        params: Dictionary containing the hyperparameters

    Returns:
        dictionary with all callables sanitized

    """

    def _sanitize_callable(val: Any) -> Any:
        if inspect.isclass(val):
            # If it's a class, don't try to instantiate it, just return the name
            return val.__name__
        if callable(val):
            # Callables get a chance to return a name
            try:
                _val = val()
                if callable(_val):
                    return val.__name__
                return _val
            # todo: specify the possible exception
            except Exception:
                return getattr(val, "__name__", None)
        return val

    return {key: _sanitize_callable(val) for key, val in params.items()}


def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> dict[str, Any]:
    """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.

    Dataclass instances and :class:`~argparse.Namespace` objects are expanded like mappings. A non-empty list whose
    items are all mappings is flattened with each item's index as an extra key level; any other list is kept as a
    value.

    Args:
        params: Dictionary containing the hyperparameters
        delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
        parent_key: Key prefix joined (with ``delimiter``) in front of every produced key. Used by the recursion;
            callers normally leave it empty.

    Returns:
        Flattened dict.

    Examples:
        >>> _flatten_dict({'a': {'b': 'c'}})
        {'a/b': 'c'}
        >>> _flatten_dict({'a': {'b': 123}})
        {'a/b': 123}
        >>> _flatten_dict({5: {'a': 123}})
        {'5/a': 123}
        >>> _flatten_dict({"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]})
        {'dl/0/a': 1, 'dl/0/c': 3, 'dl/1/b': 2, 'dl/1/d': 5, 'l': [1, 2, 3, 4]}
        >>> _flatten_dict({"dl": [{"a": 1}]}, delimiter=".")
        {'dl.0.a': 1}
        >>> _flatten_dict({"l": []})
        {'l': []}

    """
    result: dict[str, Any] = {}
    for k, v in params.items():
        new_key = parent_key + delimiter + str(k) if parent_key else str(k)
        if is_dataclass(v) and not isinstance(v, type):
            # dataclass *instances* only; a dataclass type itself is kept as a plain value
            v = asdict(v)
        elif isinstance(v, Namespace):
            v = vars(v)

        if isinstance(v, MutableMapping):
            result.update(_flatten_dict(v, parent_key=new_key, delimiter=delimiter))
        # Also handle the case where v is a non-empty list of dictionaries. The ``v`` check matters: ``all()`` is
        # vacuously true for an empty list, which would otherwise make the key disappear from the output.
        elif isinstance(v, list) and v and all(isinstance(item, MutableMapping) for item in v):
            for i, item in enumerate(v):
                # join the index with the same delimiter as every other level of the hierarchy
                result.update(_flatten_dict(item, parent_key=f"{new_key}{delimiter}{i}", delimiter=delimiter))
        else:
            result[new_key] = v
    return result


def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
    """Returns params with non-primitives converted to strings for logging.

    NumPy scalars (``np.bool_``, ``np.integer``, ``np.floating``) are first converted to the matching Python scalar
    with ``.item()``. Every value whose exact type is not ``bool``, ``int``, ``float``, ``str`` or
    :class:`torch.Tensor` is then replaced by its ``str()``. The input dict is modified in place and also returned.

    >>> import torch
    >>> params = {"float": 0.3,
    ...           "int": 1,
    ...           "string": "abc",
    ...           "bool": True,
    ...           "list": [1, 2, 3],
    ...           "namespace": Namespace(foo=3),
    ...           "layer": torch.nn.BatchNorm1d}
    >>> import pprint
    >>> pprint.pprint(_sanitize_params(params))  # doctest: +NORMALIZE_WHITESPACE
    {'bool': True,
     'float': 0.3,
     'int': 1,
     'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
     'list': '[1, 2, 3]',
     'namespace': 'Namespace(foo=3)',
     'string': 'abc'}

    """
    # Resolve the numpy scalar types once instead of re-importing numpy for every key.
    if _NUMPY_AVAILABLE:
        import numpy as np

        numpy_scalar_types: tuple[type, ...] = (np.bool_, np.integer, np.floating)
    else:
        numpy_scalar_types = ()  # isinstance(x, ()) is always False

    primitive_types = (bool, int, float, str, Tensor)
    for k in params:
        if isinstance(params[k], numpy_scalar_types):
            params[k] = params[k].item()
        # exact type check (not isinstance): subclasses such as enums or ``torch.nn.Parameter`` get stringified too
        if type(params[k]) not in primitive_types:
            params[k] = str(params[k])
    return params


def _convert_json_serializable(params: dict[str, Any]) -> dict[str, Any]:
    """Convert non-serializable objects in params to string.

    Returns a new dict; values accepted by :func:`_is_json_serializable` are kept, all others are replaced by ``str()``.

    """
    return {k: str(v) if not _is_json_serializable(v) else v for k, v in params.items()}


def _is_json_serializable(value: Any) -> bool:
    """Test whether a variable can be encoded as json.

    Scalars (``None``, ``bool``, ``int``, ``float``, ``str``) take a fast path. Everything else, including lists and
    dicts, is checked by actually encoding it, because a container may hold non-serializable items (``[object()]``)
    or non-string-compatible keys (``{(1, 2): 3}``).

    """
    if value is None or isinstance(value, (bool, int, float, str)):  # fast path
        return True
    try:
        json.dumps(value)
        return True
    except (TypeError, OverflowError, ValueError):
        # OverflowError is raised if number is too large to encode
        # ValueError is raised e.g. for circular references inside containers
        return False


def _add_prefix(
    metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str
) -> Mapping[str, Union[Tensor, float]]:
    """Insert prefix before each key in a dict, separated by the separator.

    Args:
        metrics: Dictionary with metric names as keys and measured quantities as values
        prefix: Prefix to insert before each key
        separator: Separates prefix and original key name

    Returns:
        Dictionary with prefix and separator inserted before each key. If ``prefix`` is empty, ``metrics`` itself is
        returned unchanged (not a copy).

    """
    if not prefix:
        return metrics

    return {f"{prefix}{separator}{k}": v for k, v in metrics.items()}