File size: 6,711 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import json
from typing import Any, Dict, NewType, Optional, Sequence
from wandb.proto import wandb_internal_pb2
from wandb.sdk.lib import proto_util, telemetry
BackendConfigDict = NewType("BackendConfigDict", Dict[str, Any])
"""Run config dictionary in the format used by the backend."""
_WANDB_INTERNAL_KEY = "_wandb"
class ConfigState:
    """The configuration of a run."""

    def __init__(self, tree: Optional[Dict[str, Any]] = None) -> None:
        # A tree with string-valued nodes and JSON leaves.
        #
        # Leaves are Python objects that are valid JSON values:
        #   * Primitives like strings and numbers
        #   * Dictionaries from strings to JSON objects
        #   * Lists of JSON objects
        #
        # A non-empty `tree` argument is stored as-is (not copied), so later
        # updates through this object mutate the caller's dict.
        self._tree: Dict[str, Any] = tree or {}

    def non_internal_config(self) -> Dict[str, Any]:
        """Returns the config settings minus "_wandb"."""
        return {k: v for k, v in self._tree.items() if k != _WANDB_INTERNAL_KEY}

    def update_from_proto(
        self,
        config_record: wandb_internal_pb2.ConfigRecord,
    ) -> None:
        """Applies update and remove commands.

        All updates are applied first (each value is decoded from its
        ``value_json`` field), then all removals.

        Args:
            config_record: The record whose ``update`` and ``remove`` items
                are applied to this config.
        """
        for config_item in config_record.update:
            self._update_at_path(
                _key_path(config_item),
                json.loads(config_item.value_json),
            )

        for config_item in config_record.remove:
            self._delete_at_path(_key_path(config_item))

    def merge_resumed_config(self, old_config_tree: Dict[str, Any]) -> None:
        """Merges the config from a run that's being resumed.

        Values already set in this config take precedence over the old ones.

        Args:
            old_config_tree: The config tree of the run being resumed.
        """
        # Add any top-level keys that aren't already set.
        self._add_unset_keys_from_subtree(old_config_tree, [])

        # When resuming a run, we want to ensure that some of the old config's
        # internal keys are maintained, so add back any keys under these
        # "_wandb" subtrees that were in the old config but not the new one.
        for key in ["viz", "visualize", "mask/class_labels"]:
            self._add_unset_keys_from_subtree(
                old_config_tree,
                [_WANDB_INTERNAL_KEY, key],
            )

    def _add_unset_keys_from_subtree(
        self,
        old_config_tree: Dict[str, Any],
        path: Sequence[str],
    ) -> None:
        """Uses the given subtree for keys that aren't already set.

        Args:
            old_config_tree: Tree to copy missing keys from.
            path: Path of the subtree (in both trees) to merge.
        """
        old_subtree = _subtree(old_config_tree, path, create=False)
        # The old tree comes from outside this object; skip anything that is
        # not a non-empty mapping rather than crashing on `.items()`.
        if not old_subtree or not isinstance(old_subtree, dict):
            return

        new_subtree = _subtree(self._tree, path, create=True)
        assert new_subtree is not None

        for key, value in old_subtree.items():
            # Existing values win over the resumed run's values.
            new_subtree.setdefault(key, value)

    def to_backend_dict(
        self,
        telemetry_record: telemetry.TelemetryRecord,
        framework: Optional[str],
        start_time_millis: int,
        metric_pbdicts: Sequence[Dict[int, Any]],
    ) -> BackendConfigDict:
        """Returns a dictionary representation expected by the backend.

        The backend expects the configuration in a specific format, and the
        config is also used to store additional metadata about the run.
        That metadata is written under the "_wandb" key of the returned
        dictionary only; this ConfigState's own tree is not modified.

        Args:
            telemetry_record: Telemetry information to insert.
            framework: The detected framework used in the run (e.g. TensorFlow).
            start_time_millis: The run's start time in Unix milliseconds.
            metric_pbdicts: List of dict representations of metric protobuffers.

        Returns:
            A dict mapping each top-level config key (including "_wandb") to
            ``{"desc": None, "value": <value>}``.
        """
        backend_dict = self._tree.copy()

        # The copy above is shallow, so copy the internal subtree too; writing
        # into the shared dict would leak this metadata into self._tree.
        existing_internal = backend_dict.get(_WANDB_INTERNAL_KEY)
        wandb_internal: Dict[str, Any] = (
            dict(existing_internal) if isinstance(existing_internal, dict) else {}
        )
        backend_dict[_WANDB_INTERNAL_KEY] = wandb_internal

        ###################################################
        # Telemetry information
        ###################################################
        py_version = telemetry_record.python_version
        if py_version:
            wandb_internal["python_version"] = py_version

        cli_version = telemetry_record.cli_version
        if cli_version:
            wandb_internal["cli_version"] = cli_version

        if framework:
            wandb_internal["framework"] = framework

        huggingface_version = telemetry_record.huggingface_version
        if huggingface_version:
            wandb_internal["huggingface_version"] = huggingface_version

        wandb_internal["is_jupyter_run"] = telemetry_record.env.jupyter
        wandb_internal["is_kaggle_kernel"] = telemetry_record.env.kaggle
        wandb_internal["start_time"] = start_time_millis

        # The full telemetry record.
        wandb_internal["t"] = proto_util.proto_encode_to_dict(telemetry_record)

        ###################################################
        # Metrics
        ###################################################
        if metric_pbdicts:
            wandb_internal["m"] = metric_pbdicts

        return BackendConfigDict(
            {
                key: {
                    # Configurations can be stored in a hand-written YAML file,
                    # and users can add descriptions to their hyperparameters
                    # there. However, we don't support a way to set descriptions
                    # via code, so this is always None.
                    "desc": None,
                    "value": value,
                }
                # Iterate the augmented copy, not self._tree, so the "_wandb"
                # metadata assembled above is always included.
                for key, value in backend_dict.items()
            }
        )

    def _update_at_path(
        self,
        key_path: Sequence[str],
        value: Any,
    ) -> None:
        """Sets the value at the path in the config tree.

        Intermediate nodes along ``key_path`` are created as needed.

        Args:
            key_path: Non-empty path of keys to the value.
            value: The (JSON-decoded) value to store.
        """
        subtree = _subtree(self._tree, key_path[:-1], create=True)
        assert subtree is not None

        subtree[key_path[-1]] = value

    def _delete_at_path(
        self,
        key_path: Sequence[str],
    ) -> None:
        """Removes the subtree at the path in the config tree.

        Removing a path that does not exist is a no-op.

        Args:
            key_path: Non-empty path of keys to the entry to remove.
        """
        subtree = _subtree(self._tree, key_path[:-1], create=False)
        if isinstance(subtree, dict):
            # pop() rather than del: a remove command for an unset key must
            # not raise KeyError.
            subtree.pop(key_path[-1], None)
def _key_path(config_item: wandb_internal_pb2.ConfigItem) -> Sequence[str]:
    """Returns the config-tree path that a ConfigItem refers to.

    ``nested_key`` takes precedence when set; otherwise the single ``key`` is
    wrapped in a one-element list.

    Raises:
        AssertionError: If neither ``nested_key`` nor ``key`` is set.
    """
    nested = config_item.nested_key
    if nested:
        return nested

    single = config_item.key
    if not single:
        raise AssertionError(
            "Invalid ConfigItem: either key or nested_key must be set",
        )
    return [single]
def _subtree(
tree: Dict[str, Any],
key_path: Sequence[str],
*,
create: bool = False,
) -> Optional[Dict[str, Any]]:
"""Returns a subtree at the given path."""
for key in key_path:
subtree = tree.get(key)
if not subtree:
if create:
subtree = {}
tree[key] = subtree
else:
return None
tree = subtree
return tree
|