jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
"""config."""
import logging
from typing import Optional
import wandb
from wandb.util import (
_is_artifact_representation,
check_dict_contains_nested_artifact,
json_friendly_val,
)
from . import wandb_helper
from .lib import config_util
# Module-level logger; it uses the fixed logger name "wandb" rather than this
# module's ``__name__``.
logger = logging.getLogger("wandb")
# TODO(jhr): consider a callback for persisting changes?
# if this is done right we might make sure this is pickle-able
# we might be able to do this on other objects like Run?
class Config:
    """Config object.

    Config objects are intended to hold all of the hyperparameters associated
    with a wandb run and are saved with the run object when `wandb.init` is
    called.

    We recommend setting the config once when initializing your run by passing
    the `config` parameter to `init`:

    ```
    wandb.init(config=my_config_dict)
    ```

    You can create a file called `config-defaults.yaml`, and it will
    automatically be loaded as each run's config. You can also pass the name
    of the file as the `config` parameter to `init`:

    ```
    wandb.init(config="my_config.yaml")
    ```

    See https://docs.wandb.com/guides/track/config#file-based-configs.

    Examples:
        Basic usage

        ```
        with wandb.init(config={"epochs": 4}) as run:
            for x in range(run.config.epochs):
                # train
        ```

        Nested values

        ```
        with wandb.init(config={"train": {"epochs": 4}}) as run:
            for x in range(run.config["train"]["epochs"]):
                # train
        ```

        Using absl flags

        ```
        flags.DEFINE_string("model", None, "model to run")  # name, default, help
        with wandb.init() as run:
            run.config.update(flags.FLAGS)  # adds all absl flags to config
        ```

        Argparse flags

        ```python
        with wandb.init(config={"epochs": 4}) as run:
            parser = argparse.ArgumentParser()
            parser.add_argument(
                "-b",
                "--batch-size",
                type=int,
                default=8,
                metavar="N",
                help="input batch size for training (default: 8)",
            )
            args = parser.parse_args()
            run.config.update(args)
        ```

        Using TensorFlow flags (deprecated in tensorflow v2)

        ```python
        flags = tf.app.flags
        flags.DEFINE_string("data_dir", "/tmp/data")
        flags.DEFINE_integer("batch_size", 128, "Batch size.")
        with wandb.init() as run:
            run.config.update(flags.FLAGS)
        ```
    """

    def __init__(self):
        """Create an empty config, then apply defaults via `_load_defaults`."""
        # Internal state is written with ``object.__setattr__`` because
        # ``__setattr__`` is aliased to ``__setitem__`` further down; a plain
        # ``self._items = ...`` would be stored as a config entry instead.
        object.__setattr__(self, "_items", {})  # config key -> value
        object.__setattr__(self, "_locked", {})  # config key -> locking user id
        object.__setattr__(self, "_users", {})  # user -> numeric id
        object.__setattr__(self, "_users_inv", {})  # numeric id -> user
        object.__setattr__(self, "_users_cnt", 0)  # next numeric id to hand out
        object.__setattr__(self, "_callback", None)  # invoked after changes
        object.__setattr__(self, "_settings", None)  # read by `_sanitize`
        object.__setattr__(self, "_artifact_callback", None)  # used by `_sanitize`
        self._load_defaults()

    def _set_callback(self, cb):
        """Install the callable invoked after config changes."""
        object.__setattr__(self, "_callback", cb)

    def _set_artifact_callback(self, cb):
        """Install the callable `_sanitize` uses for artifact representations."""
        object.__setattr__(self, "_artifact_callback", cb)

    def _set_settings(self, settings):
        """Attach the settings object consulted by `_sanitize`."""
        object.__setattr__(self, "_settings", settings)

    def __repr__(self):
        # ``dict(self)`` goes through `keys()`, so keys starting with "_" are
        # not part of the representation.
        return str(dict(self))

    def keys(self):
        """Return the config keys as a list, omitting keys that start with "_"."""
        return [k for k in self._items if not k.startswith("_")]

    def _as_dict(self):
        """Return the underlying item dict itself (not a copy, no filtering)."""
        return self._items

    def as_dict(self):
        """Return a new dict of the entries exposed through `keys()`."""
        # TODO: add telemetry, deprecate, then remove
        return dict(self)

    def __getitem__(self, key):
        return self._items[key]

    def __iter__(self):
        # Unlike `keys()`, iteration covers every stored key, including those
        # that start with "_".
        return iter(self._items)

    def _check_locked(self, key, ignore_locked=False) -> bool:
        """Report whether `key` is locked.

        Args:
            key: Config key to look up in the lock table.
            ignore_locked: When falsy, a warning naming the locking user is
                emitted for a locked key; when truthy the warning is skipped.

        Returns:
            True if `key` has a lock entry, False otherwise.
        """
        locked = self._locked.get(key)
        if locked is not None:
            locked_user = self._users_inv[locked]
            if not ignore_locked:
                wandb.termwarn(
                    f"Config item '{key}' was locked by '{locked_user}' (ignored update)."
                )
            return True
        return False

    def __setitem__(self, key, val):
        """Set a single config entry and notify the callback.

        Updates to locked keys are ignored (after a warning). Raises whatever
        `_raise_value_error_on_nested_artifact` or `_sanitize` raise for `val`.
        """
        if self._check_locked(key):
            return
        with wandb.sdk.lib.telemetry.context() as tel:
            tel.feature.set_config_item = True
        self._raise_value_error_on_nested_artifact(val, nested=True)
        key, val = self._sanitize(key, val)
        self._items[key] = val
        logger.info("config set %s = %s - %s", key, val, self._callback)
        if self._callback:
            self._callback(key=key, val=val)

    def items(self):
        """Return `(key, value)` pairs as a list, omitting keys starting with "_"."""
        return [(k, v) for k, v in self._items.items() if not k.startswith("_")]

    # Attribute assignment (``config.lr = 0.1``) behaves like item assignment.
    __setattr__ = __setitem__

    def __getattr__(self, key):
        """Resolve unknown attributes as config keys.

        Raises:
            AttributeError: If `key` is not a stored config key, or if the
                internal item store itself has not been set on this instance.
        """
        # ``__getattr__`` only runs after normal lookup has failed. If the
        # missing name is ``_items`` itself (for example on an instance that
        # ``copy``/``pickle`` created through ``__new__`` and is still probing
        # before restoring state), delegating to ``__getitem__`` would read
        # ``self._items`` again and recurse until RecursionError. Fail with a
        # regular AttributeError instead so ``hasattr``/``getattr`` work.
        if key == "_items":
            raise AttributeError(
                f"{self.__class__!r} object has no attribute {key!r}"
            )
        try:
            return self.__getitem__(key)
        except KeyError as ke:
            raise AttributeError(
                f"{self.__class__!r} object has no attribute {key!r}"
            ) from ke

    def __contains__(self, key):
        return key in self._items

    def _update(self, d, allow_val_change=None, ignore_locked=None):
        """Merge `d` into the config without invoking the callback.

        `d` is first passed through `wandb_helper.parse_config`. Keys that are
        locked are left untouched; the rest are sanitized and stored.

        Returns:
            The dict of sanitized entries that were applied.
        """
        parsed_dict = wandb_helper.parse_config(d)
        locked_keys = set()
        for key in list(parsed_dict):
            if self._check_locked(key, ignore_locked=ignore_locked):
                locked_keys.add(key)
        sanitized = self._sanitize_dict(
            parsed_dict, allow_val_change, ignore_keys=locked_keys
        )
        self._items.update(sanitized)
        return sanitized

    def update(self, d, allow_val_change=None):
        """Merge `d` into the config and pass the applied entries to the callback."""
        sanitized = self._update(d, allow_val_change)
        if self._callback:
            self._callback(data=sanitized)

    def get(self, *args):
        """Forward to `dict.get` on the underlying item dict."""
        return self._items.get(*args)

    def persist(self):
        """Call the callback if it's set."""
        if self._callback:
            self._callback(data=self._as_dict())

    def setdefaults(self, d):
        """Add entries from `d` only for keys that are not already stored."""
        d = wandb_helper.parse_config(d)
        # strip out keys already configured
        d = {k: v for k, v in d.items() if k not in self._items}
        d = self._sanitize_dict(d)
        self._items.update(d)
        if self._callback:
            self._callback(data=d)

    def _get_user_id(self, user) -> int:
        """Return the numeric id for `user`, assigning the next free id if new."""
        if user not in self._users:
            self._users[user] = self._users_cnt
            self._users_inv[self._users_cnt] = user
            # Bypass the aliased ``__setattr__`` so the counter is not stored
            # as a config entry.
            object.__setattr__(self, "_users_cnt", self._users_cnt + 1)
        return self._users[user]

    def update_locked(self, d, user=None, _allow_val_change=None):
        """Shallow-update config with `d` and lock config updates on d's keys."""
        num = self._get_user_id(user)
        for k, v in d.items():
            k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
            self._locked[k] = num
            self._items[k] = v
        # NOTE(review): the callback receives the caller's `d`, not the
        # sanitized keys/values that were stored (`merge_locked` passes the
        # stored values) -- confirm this is intended.
        if self._callback:
            self._callback(data=d)

    def merge_locked(self, d, user=None, _allow_val_change=None):
        """Recursively merge-update config with `d` and lock config updates on d's keys."""
        num = self._get_user_id(user)
        callback_d = {}
        for k, v in d.items():
            k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
            self._locked[k] = num
            if (
                k in self._items
                and isinstance(self._items[k], dict)
                and isinstance(v, dict)
            ):
                # Both the stored and the incoming value are dicts: combine
                # them with `config_util.merge_dicts` instead of overwriting.
                self._items[k] = config_util.merge_dicts(self._items[k], v)
            else:
                self._items[k] = v
            callback_d[k] = self._items[k]
        if self._callback:
            self._callback(data=callback_d)

    def _load_defaults(self):
        """Apply the values read for ``config-defaults.yaml``, if any."""
        conf_dict = config_util.dict_from_config_file("config-defaults.yaml")
        if conf_dict is not None:
            self.update(conf_dict)

    def _sanitize_dict(
        self,
        config_dict,
        allow_val_change=None,
        ignore_keys: Optional[set] = None,
    ):
        """Return a new dict with every entry of `config_dict` sanitized.

        Args:
            config_dict: Mapping of raw keys to raw values.
            allow_val_change: Forwarded to `_sanitize`.
            ignore_keys: Raw keys to leave out of the result.

        Raises:
            ValueError: If `config_dict` contains a nested artifact.
        """
        sanitized = {}
        self._raise_value_error_on_nested_artifact(config_dict)
        for k, v in config_dict.items():
            if ignore_keys and k in ignore_keys:
                continue
            k, v = self._sanitize(k, v, allow_val_change)
            sanitized[k] = v
        return sanitized

    def _sanitize(self, key, val, allow_val_change=None):
        """Normalize one key/value pair before it is stored.

        Returns:
            The `(key, val)` pair with leading/trailing "-" stripped from the
            key and non-artifact values passed through `json_friendly_val`.

        Raises:
            TypeError: If `val` is a `WBValue`.
            config_util.ConfigError: If value changes are not allowed and the
                key already holds a different value.
        """
        # TODO: enable WBValues in the config in the future
        # refuse all WBValues which is all Media and Histograms
        if isinstance(val, wandb.sdk.data_types.base_types.wb_value.WBValue):
            raise TypeError("WBValue objects cannot be added to the run config")
        # Let jupyter change config freely by default
        if self._settings and self._settings._jupyter and allow_val_change is None:
            allow_val_change = True
        # We always normalize keys by stripping '-'
        key = key.strip("-")
        if _is_artifact_representation(val):
            # NOTE(review): `_artifact_callback` is None until
            # `_set_artifact_callback` is called; this assumes it has been
            # set by the time an artifact representation shows up -- confirm.
            val = self._artifact_callback(key, val)
        # if the user inserts an artifact into the config
        if not isinstance(val, wandb.Artifact):
            val = json_friendly_val(val)
        if not allow_val_change:
            if key in self._items and val != self._items[key]:
                raise config_util.ConfigError(
                    f'Attempted to change value of key "{key}" '
                    f"from {self._items[key]} to {val}\n"
                    "If you really want to do this, pass"
                    " allow_val_change=True to config.update()"
                )
        return key, val

    def _raise_value_error_on_nested_artifact(self, v, nested=False):
        """Raise ValueError if `v` is a dict that contains a nested artifact."""
        # we can't swap nested artifacts because their root key can be locked by other values
        # best if we don't allow nested artifacts until we can lock nested keys in the config
        if isinstance(v, dict) and check_dict_contains_nested_artifact(v, nested):
            raise ValueError(
                "Instances of wandb.Artifact can only be top level keys in"
                " a run's config"
            )
class ConfigStatic:
    """Read-only snapshot of a config, readable by attribute or by key."""

    def __init__(self, config):
        # Install a copy of ``config`` as the instance namespace. This goes
        # through ``object.__setattr__`` because this class's own
        # ``__setattr__`` always raises.
        snapshot = dict(config)
        object.__setattr__(self, "__dict__", snapshot)

    def __setattr__(self, name, value):
        """Reject attribute assignment."""
        message = "Error: run.config_static is a readonly object"
        raise AttributeError(message)

    def __setitem__(self, key, val):
        """Reject item assignment."""
        message = "Error: run.config_static is a readonly object"
        raise AttributeError(message)

    def keys(self):
        """Return a view of the snapshot's keys."""
        namespace = vars(self)
        return namespace.keys()

    def __getitem__(self, key):
        namespace = vars(self)
        return namespace[key]

    def __str__(self):
        namespace = vars(self)
        return str(namespace)