|
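"""Run configuration objects.

Defines `Config`, the mutable hyperparameter store attached to a run, and
`ConfigStatic`, a read-only snapshot of a config's items.
"""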
|
|
|
import logging |
|
from typing import Optional |
|
|
|
import wandb |
|
from wandb.util import ( |
|
_is_artifact_representation, |
|
check_dict_contains_nested_artifact, |
|
json_friendly_val, |
|
) |
|
|
|
from . import wandb_helper |
|
from .lib import config_util |
|
|
|
# Shared "wandb" logger (not a per-module __name__ logger); used below to
# log config assignments.
logger = logging.getLogger("wandb")
|
|
|
|
|
|
|
|
|
|
|
class Config:
    """Config object.

    Config objects are intended to hold all of the hyperparameters associated
    with a wandb run and are saved with the run object when `wandb.init` is
    called.

    We recommend setting the config once when initializing your run by passing
    the `config` parameter to `init`:

    ```python
    wandb.init(config=my_config_dict)
    ```

    You can create a file called `config-defaults.yaml`, and it will
    automatically be loaded as each run's config. You can also pass the name
    of the file as the `config` parameter to `init`:

    ```python
    wandb.init(config="my_config.yaml")
    ```

    See https://docs.wandb.ai/guides/track/config#file-based-configs.

    Keys that begin with an underscore are treated as private: they are left
    out of `keys()`, `items()`, `as_dict()` and `repr()`, but remain reachable
    through indexing, `get()`, `in` and plain iteration.

    Examples:
        Basic usage
        ```python
        with wandb.init(config={"epochs": 4}) as run:
            for x in range(run.config.epochs):
                ...  # train
        ```

        Nested values
        ```python
        with wandb.init(config={"train": {"epochs": 4}}) as run:
            for x in range(run.config["train"]["epochs"]):
                ...  # train
        ```

        Using absl flags
        ```python
        flags.DEFINE_string("model", None, "model to run")  # name, default, help

        with wandb.init() as run:
            run.config.update(flags.FLAGS)  # adds all absl flags to config
        ```

        Argparse flags
        ```python
        with wandb.init(config={"epochs": 4}) as run:
            parser = argparse.ArgumentParser()
            parser.add_argument(
                "-b",
                "--batch-size",
                type=int,
                default=8,
                metavar="N",
                help="input batch size for training (default: 8)",
            )
            args = parser.parse_args()
            run.config.update(args)
        ```

        Using TensorFlow flags (deprecated in tensorflow v2)
        ```python
        flags = tf.app.flags
        flags.DEFINE_string("data_dir", "/tmp/data", "Data directory.")
        flags.DEFINE_integer("batch_size", 128, "Batch size.")

        with wandb.init() as run:
            run.config.update(flags.FLAGS)
        ```
    """

    def __init__(self):
        # `__setattr__` is aliased to `__setitem__` (see below), so internal
        # state must be installed with object.__setattr__ to bypass it.
        object.__setattr__(self, "_items", {})  # key -> stored value
        object.__setattr__(self, "_locked", {})  # key -> id of locking user
        object.__setattr__(self, "_users", {})  # user -> id
        object.__setattr__(self, "_users_inv", {})  # id -> user
        object.__setattr__(self, "_users_cnt", 0)  # next id to hand out
        object.__setattr__(self, "_callback", None)
        object.__setattr__(self, "_settings", None)
        object.__setattr__(self, "_artifact_callback", None)

        self._load_defaults()

    def _set_callback(self, cb):
        """Install the change callback.

        It is invoked as ``cb(key=..., val=...)`` after a single assignment
        and as ``cb(data=...)`` after bulk updates.
        """
        object.__setattr__(self, "_callback", cb)

    def _set_artifact_callback(self, cb):
        """Install the callback `_sanitize` uses as ``cb(key, val)`` to turn an
        artifact representation into the value that is stored."""
        object.__setattr__(self, "_artifact_callback", cb)

    def _set_settings(self, settings):
        """Attach run settings; only ``settings._jupyter`` is read (in `_sanitize`)."""
        object.__setattr__(self, "_settings", settings)

    def __repr__(self):
        # dict(self) goes through keys(), so private keys are omitted.
        return str(dict(self))

    def keys(self):
        """Return the public (non-underscore) keys as a list."""
        return [k for k in self._items.keys() if not k.startswith("_")]

    def _as_dict(self):
        """Return the live internal dict, private keys included (not a copy)."""
        return self._items

    def as_dict(self):
        """Return a new dict containing only the public items."""
        return dict(self)

    def __getitem__(self, key):
        return self._items[key]

    def __iter__(self):
        # Iterates over every key, including private ones.
        return iter(self._items)

    def _check_locked(self, key, ignore_locked=False) -> bool:
        """Return True if `key` is locked by some user.

        A warning naming the locking user is printed unless `ignore_locked`
        is truthy; the return value is the same either way.
        """
        locked = self._locked.get(key)
        if locked is not None:
            locked_user = self._users_inv[locked]
            if not ignore_locked:
                wandb.termwarn(
                    f"Config item '{key}' was locked by '{locked_user}' (ignored update)."
                )
            return True
        return False

    def __setitem__(self, key, val):
        """Set a single item, unless its key is locked.

        The value is sanitized (see `_sanitize`) before being stored, then
        reported to the callback as ``cb(key=key, val=val)``.

        Raises:
            ValueError: if `val` is a dict that nests an artifact.
        """
        if self._check_locked(key):
            return
        with wandb.sdk.lib.telemetry.context() as tel:
            tel.feature.set_config_item = True
        self._raise_value_error_on_nested_artifact(val, nested=True)
        key, val = self._sanitize(key, val)
        self._items[key] = val
        logger.info("config set %s = %s - %s", key, val, self._callback)
        if self._callback:
            self._callback(key=key, val=val)

    def items(self):
        """Return the public (non-underscore) items as a list of pairs."""
        return [(k, v) for k, v in self._items.items() if not k.startswith("_")]

    # Attribute assignment (``config.lr = 0.1``) follows the same path as
    # item assignment.
    __setattr__ = __setitem__

    def __getattr__(self, key):
        """Expose config items as attributes.

        Only reached when normal attribute lookup fails.

        Raises:
            AttributeError: if `key` is not a config item.
        """
        # An instance created through __new__ (e.g. by copy/pickle) has no
        # `_items` yet. Reading `self._items` would then re-enter __getattr__
        # and recurse until RecursionError, breaking hasattr() checks such as
        # copy's `hasattr(obj, "__setstate__")`. Fail with AttributeError.
        if "_items" not in self.__dict__:
            raise AttributeError(
                f"{self.__class__!r} object has no attribute {key!r}"
            )
        try:
            return self.__getitem__(key)
        except KeyError as ke:
            raise AttributeError(
                f"{self.__class__!r} object has no attribute {key!r}"
            ) from ke

    def __contains__(self, key):
        return key in self._items

    def _update(self, d, allow_val_change=None, ignore_locked=None):
        """Merge `d` into the config without invoking the callback.

        Locked keys are skipped (a warning is printed for each one unless
        `ignore_locked` is truthy).

        Returns:
            The sanitized dict of items that were actually stored.
        """
        parsed_dict = wandb_helper.parse_config(d)
        locked_keys = set()
        for key in list(parsed_dict):
            if self._check_locked(key, ignore_locked=ignore_locked):
                locked_keys.add(key)
        sanitized = self._sanitize_dict(
            parsed_dict, allow_val_change, ignore_keys=locked_keys
        )
        self._items.update(sanitized)
        return sanitized

    def update(self, d, allow_val_change=None):
        """Merge `d` into the config and report the stored items to the callback."""
        sanitized = self._update(d, allow_val_change)
        if self._callback:
            self._callback(data=sanitized)

    def get(self, *args):
        """Same as ``dict.get`` on the internal items (private keys included)."""
        return self._items.get(*args)

    def persist(self):
        """Call the callback if it's set, passing the full internal dict."""
        if self._callback:
            self._callback(data=self._as_dict())

    def setdefaults(self, d):
        """Add the entries of `d` whose keys are not already present.

        Presence is checked on the parsed keys before sanitization; the
        callback receives only the newly added, sanitized items.
        """
        d = wandb_helper.parse_config(d)

        # Keep only keys that are not yet set.
        d = {k: v for k, v in d.items() if k not in self._items}
        d = self._sanitize_dict(d)
        self._items.update(d)
        if self._callback:
            self._callback(data=d)

    def _get_user_id(self, user) -> int:
        """Return a stable integer id for `user`, allocating one on first use."""
        if user not in self._users:
            self._users[user] = self._users_cnt
            self._users_inv[self._users_cnt] = user
            object.__setattr__(self, "_users_cnt", self._users_cnt + 1)

        return self._users[user]

    def update_locked(self, d, user=None, _allow_val_change=None):
        """Shallow-update config with `d` and lock config updates on d's keys.

        Each pair is sanitized, stored, and locked under `user`. The callback
        receives the sanitized pairs that were stored (matching `update`,
        `setdefaults` and `merge_locked`), not the raw input.
        """
        num = self._get_user_id(user)
        callback_d = {}

        for k, v in d.items():
            k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
            self._locked[k] = num
            self._items[k] = v
            callback_d[k] = v

        if self._callback:
            self._callback(data=callback_d)

    def merge_locked(self, d, user=None, _allow_val_change=None):
        """Recursively merge-update config with `d` and lock config updates on d's keys.

        When both the existing and the new value for a key are dicts, they
        are combined with ``config_util.merge_dicts``; otherwise the new value
        replaces the old one. The callback receives the resulting stored
        value for every key in `d`.
        """
        num = self._get_user_id(user)
        callback_d = {}

        for k, v in d.items():
            k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
            self._locked[k] = num

            if (
                k in self._items
                and isinstance(self._items[k], dict)
                and isinstance(v, dict)
            ):
                self._items[k] = config_util.merge_dicts(self._items[k], v)
            else:
                self._items[k] = v

            callback_d[k] = self._items[k]

        if self._callback:
            self._callback(data=callback_d)

    def _load_defaults(self):
        """Load ``config-defaults.yaml`` through `update` when config_util returns a dict for it."""
        conf_dict = config_util.dict_from_config_file("config-defaults.yaml")
        if conf_dict is not None:
            self.update(conf_dict)

    def _sanitize_dict(
        self,
        config_dict,
        allow_val_change=None,
        ignore_keys: Optional[set] = None,
    ):
        """Return a new dict with every pair of `config_dict` sanitized.

        Keys listed in `ignore_keys` (compared before sanitization) are
        dropped.

        Raises:
            ValueError: if `config_dict` nests an artifact.
        """
        sanitized = {}
        self._raise_value_error_on_nested_artifact(config_dict)
        for k, v in config_dict.items():
            if ignore_keys and k in ignore_keys:
                continue
            k, v = self._sanitize(k, v, allow_val_change)
            sanitized[k] = v
        return sanitized

    def _sanitize(self, key, val, allow_val_change=None):
        """Normalize one key/value pair before it is stored.

        Steps, in order:
        1. Reject ``WBValue`` instances (wandb data types) with TypeError.
        2. If ``self._settings._jupyter`` is truthy, `allow_val_change`
           defaults to True when it was not given.
        3. Strip ``-`` characters from both ends of the key.
        4. Resolve artifact representations via ``self._artifact_callback``.
        5. Make non-Artifact values JSON friendly.
        6. Unless value changes are allowed, raise ConfigError when the key
           already holds a different value.

        Returns:
            The ``(key, val)`` pair to store.
        """
        if isinstance(val, wandb.sdk.data_types.base_types.wb_value.WBValue):
            raise TypeError("WBValue objects cannot be added to the run config")

        if self._settings and self._settings._jupyter and allow_val_change is None:
            allow_val_change = True

        key = key.strip("-")
        if _is_artifact_representation(val):
            val = self._artifact_callback(key, val)

        if not isinstance(val, wandb.Artifact):
            val = json_friendly_val(val)
        if not allow_val_change:
            if key in self._items and val != self._items[key]:
                raise config_util.ConfigError(
                    f'Attempted to change value of key "{key}" '
                    f"from {self._items[key]} to {val}\n"
                    "If you really want to do this, pass"
                    " allow_val_change=True to config.update()"
                )
        return key, val

    def _raise_value_error_on_nested_artifact(self, v, nested=False):
        """Raise ValueError if `v` is a dict flagged by
        ``check_dict_contains_nested_artifact(v, nested)``.

        NOTE(review): the precise meaning of `nested` is defined in
        wandb.util, not here; the error text says artifacts may only be
        top-level config values.
        """
        if isinstance(v, dict) and check_dict_contains_nested_artifact(v, nested):
            raise ValueError(
                "Instances of wandb.Artifact can only be top level keys in"
                " a run's config"
            )
|
|
|
|
|
class ConfigStatic:
    """Read-only snapshot of a config.

    ``dict(config)`` is copied into the instance ``__dict__`` at
    construction, so items are readable both as attributes (``static.lr``)
    and by key (``static["lr"]``). Attribute assignment, attribute deletion
    and item assignment all raise AttributeError.
    """

    def __init__(self, config):
        # __setattr__ is blocked below, so bypass it to install the snapshot.
        object.__setattr__(self, "__dict__", dict(config))

    def __setattr__(self, name, value):
        raise AttributeError("Error: run.config_static is a readonly object")

    def __delattr__(self, name):
        # Without this, `del static.attr` would silently remove an item from
        # the supposedly read-only snapshot.
        raise AttributeError("Error: run.config_static is a readonly object")

    def __setitem__(self, key, val):
        raise AttributeError("Error: run.config_static is a readonly object")

    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, key):
        return self.__dict__[key]

    def __contains__(self, key):
        # Without this, `in` falls back to the legacy __getitem__(0, 1, ...)
        # iteration protocol and raises KeyError.
        return key in self.__dict__

    def __iter__(self):
        # Iterate over keys, consistent with Config.__iter__.
        return iter(self.__dict__)

    def __str__(self):
        return str(self.__dict__)
|
|