File size: 6,711 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import json
from typing import Any, Dict, NewType, Optional, Sequence

from wandb.proto import wandb_internal_pb2
from wandb.sdk.lib import proto_util, telemetry

BackendConfigDict = NewType("BackendConfigDict", Dict[str, Any])
"""Run config dictionary in the format used by the backend."""

_WANDB_INTERNAL_KEY = "_wandb"


class ConfigState:
    """The configuration of a run.

    Wraps a nested dict of config values and knows how to apply protobuf
    config records to it, merge in a resumed run's config, and convert it
    to the format expected by the backend.
    """

    def __init__(self, tree: Optional[Dict[str, Any]] = None) -> None:
        # NOTE: a non-empty ``tree`` is used as-is (not copied), so later
        # updates mutate the caller's dict.
        self._tree: Dict[str, Any] = tree or {}
        """A tree with string-keyed dict nodes and JSON leaves.

        Leaves are Python objects that are valid JSON values:

        * Primitives like strings and numbers
        * Dictionaries from strings to JSON objects
        * Lists of JSON objects
        """

    def non_internal_config(self) -> Dict[str, Any]:
        """Returns the top-level config entries, excluding the "_wandb" key.

        The returned dict is new, but its values are shared with the tree.
        """
        return {k: v for k, v in self._tree.items() if k != _WANDB_INTERNAL_KEY}

    def update_from_proto(
        self,
        config_record: wandb_internal_pb2.ConfigRecord,
    ) -> None:
        """Applies the update and remove commands in a config record.

        All updates are applied first, in order, followed by all removals.
        Each update's value is decoded from its ``value_json`` field.

        Raises:
            AssertionError: If an item sets neither ``key`` nor ``nested_key``.
            json.JSONDecodeError: If an update's ``value_json`` is invalid.
        """
        for config_item in config_record.update:
            self._update_at_path(
                _key_path(config_item),
                json.loads(config_item.value_json),
            )

        for config_item in config_record.remove:
            self._delete_at_path(_key_path(config_item))

    def merge_resumed_config(self, old_config_tree: Dict[str, Any]) -> None:
        """Merges the config from a run that's being resumed.

        Values already set in this config take precedence; only keys that
        are missing here are copied over from ``old_config_tree``.
        """
        # Add any top-level keys that aren't already set.
        self._add_unset_keys_from_subtree(old_config_tree, [])

        # When resuming a run, some of the old config's internal keys must be
        # preserved, so add back any keys under these "_wandb" subtrees that
        # were in the old config but are not in the new one.
        for key in ["viz", "visualize", "mask/class_labels"]:
            self._add_unset_keys_from_subtree(
                old_config_tree,
                [_WANDB_INTERNAL_KEY, key],
            )

    def _add_unset_keys_from_subtree(
        self,
        old_config_tree: Dict[str, Any],
        path: Sequence[str],
    ) -> None:
        """Copies keys from the old subtree at ``path`` that aren't set here.

        Does nothing if the old tree has no (or an empty) subtree at
        ``path``. Values are copied by reference, not deep-copied.
        """
        old_subtree = _subtree(old_config_tree, path, create=False)
        if not old_subtree:
            return

        new_subtree = _subtree(self._tree, path, create=True)
        assert new_subtree is not None  # create=True always yields a dict

        for key, value in old_subtree.items():
            if key not in new_subtree:
                new_subtree[key] = value

    def to_backend_dict(
        self,
        telemetry_record: telemetry.TelemetryRecord,
        framework: Optional[str],
        start_time_millis: int,
        metric_pbdicts: Sequence[Dict[int, Any]],
    ) -> BackendConfigDict:
        """Returns a dictionary representation expected by the backend.

        The backend expects the configuration in a specific format, and the
        config is also used to store additional metadata about the run.
        This method does not modify the config state.

        Args:
            telemetry_record: Telemetry information to insert.
            framework: The detected framework used in the run (e.g. TensorFlow).
            start_time_millis: The run's start time in Unix milliseconds.
            metric_pbdicts: List of dict representations of metric protobuffers.

        Returns:
            A dict mapping every top-level key, including "_wandb" with the
            metadata added here, to ``{"desc": None, "value": value}``.
        """
        # Copy both the top level and the "_wandb" subtree, so that the
        # metadata added below neither leaks into ``self._tree`` nor gets
        # lost when the tree had no "_wandb" key of its own.
        backend_dict = self._tree.copy()
        existing_internal = backend_dict.get(_WANDB_INTERNAL_KEY)
        wandb_internal: Dict[str, Any] = (
            dict(existing_internal) if isinstance(existing_internal, dict) else {}
        )
        backend_dict[_WANDB_INTERNAL_KEY] = wandb_internal

        ###################################################
        # Telemetry information
        ###################################################
        py_version = telemetry_record.python_version
        if py_version:
            wandb_internal["python_version"] = py_version

        cli_version = telemetry_record.cli_version
        if cli_version:
            wandb_internal["cli_version"] = cli_version

        if framework:
            wandb_internal["framework"] = framework

        huggingface_version = telemetry_record.huggingface_version
        if huggingface_version:
            wandb_internal["huggingface_version"] = huggingface_version

        wandb_internal["is_jupyter_run"] = telemetry_record.env.jupyter
        wandb_internal["is_kaggle_kernel"] = telemetry_record.env.kaggle
        wandb_internal["start_time"] = start_time_millis

        # The full telemetry record.
        wandb_internal["t"] = proto_util.proto_encode_to_dict(telemetry_record)

        ###################################################
        # Metrics
        ###################################################
        if metric_pbdicts:
            wandb_internal["m"] = metric_pbdicts

        return BackendConfigDict(
            {
                key: {
                    # Configurations can be stored in a hand-written YAML file,
                    # and users can add descriptions to their hyperparameters
                    # there. However, we don't support a way to set descriptions
                    # via code, so this is always None.
                    "desc": None,
                    "value": value,
                }
                # Iterate the augmented copy, not ``self._tree``, so the
                # "_wandb" metadata built above is part of the result.
                for key, value in backend_dict.items()
            }
        )

    def _update_at_path(
        self,
        key_path: Sequence[str],
        value: Any,
    ) -> None:
        """Sets the value at the path in the config tree.

        Intermediate nodes along the path are created as needed.
        """
        subtree = _subtree(self._tree, key_path[:-1], create=True)
        assert subtree is not None  # create=True always yields a dict

        subtree[key_path[-1]] = value

    def _delete_at_path(
        self,
        key_path: Sequence[str],
    ) -> None:
        """Removes the subtree at the path in the config tree, if present.

        Removing a key that does not exist (or whose parent does not exist)
        is a no-op rather than an error.
        """
        subtree = _subtree(self._tree, key_path[:-1], create=False)
        if isinstance(subtree, dict):
            subtree.pop(key_path[-1], None)


def _key_path(config_item: wandb_internal_pb2.ConfigItem) -> Sequence[str]:
    """Gets the path of keys that a config item refers to.

    A non-empty ``nested_key`` takes precedence; otherwise the single
    ``key`` is wrapped in a one-element list.

    Raises:
        AssertionError: If neither ``nested_key`` nor ``key`` is set.
    """
    nested = config_item.nested_key
    if nested:
        return nested

    single = config_item.key
    if not single:
        raise AssertionError(
            "Invalid ConfigItem: either key or nested_key must be set",
        )
    return [single]


def _subtree(
    tree: Dict[str, Any],
    key_path: Sequence[str],
    *,
    create: bool = False,
) -> Optional[Dict[str, Any]]:
    """Returns a subtree at the given path."""
    for key in key_path:
        subtree = tree.get(key)

        if not subtree:
            if create:
                subtree = {}
                tree[key] = subtree
            else:
                return None

        tree = subtree

    return tree