import json
import os
import time

from wandb_gql import gql

import wandb
from wandb import util
from wandb.apis.internal import Api
from wandb.sdk import lib as wandb_lib
from wandb.sdk.data_types.utils import val_to_json

DEEP_SUMMARY_FNAME = "wandb.h5"
H5_TYPES = ("numpy.ndarray", "tensorflow.Tensor", "torch.Tensor")

# Optional dependencies: util.get_module() returns None when the package
# is not installed.
h5py = util.get_module("h5py")
np = util.get_module("numpy")


class SummarySubDict:
    """Nested dict-like object that proxies read and write operations through a root object.

    This lets us do synchronous serialization and lazy loading of large values.
    """

    def __init__(self, root=None, path=()):
        self._path = tuple(path)
        if root is None:
            self._root = self
            self._json_dict = {}
        else:
            self._root = root
            json_dict = root._json_dict
            for k in path:
                json_dict = json_dict.get(k, {})

            self._json_dict = json_dict
        self._dict = {}

        # Keys that have been written explicitly; update(overwrite=False)
        # will not clobber them.
        self._locked_keys = set()

    def __setattr__(self, k, v):
        k = k.strip()
        if k.startswith("_"):
            object.__setattr__(self, k, v)
        else:
            self[k] = v

    def __getattr__(self, k):
        k = k.strip()
        if k.startswith("_"):
            return object.__getattribute__(self, k)
        else:
            return self[k]

    def _root_get(self, path, child_dict):
        """Load a value at a particular path from the root.

        This should only be implemented by the "_root" child class.

        We pass the child_dict so the item can be set on it or not as
        appropriate. Returning None for a nonexistent path wouldn't be
        distinguishable from that path being set to the value None.
        """
        raise NotImplementedError

    def _root_set(self, path, new_keys_values):
        """Set a value at a particular path in the root.

        This should only be implemented by the "_root" child class.
        """
        raise NotImplementedError

    def _root_del(self, path):
        """Delete a value at a particular path in the root.

        This should only be implemented by the "_root" child class.
        """
        raise NotImplementedError

    def _write(self, commit=False):
        """Persist the summary to the backing store.

        This should only be implemented by the "_root" child class.
        """
        raise NotImplementedError

    def keys(self):
        # Iterate over the backing JSON dict so we don't force values
        # that haven't been accessed yet to load.
        return self._json_dict.keys()

    def get(self, k, default=None):
        if isinstance(k, str):
            k = k.strip()
        if k not in self._dict:
            self._root._root_get(self._path + (k,), self._dict)
        return self._dict.get(k, default)

    def items(self):
        for k in self.keys():
            yield k, self[k]

    def __getitem__(self, k):
        if isinstance(k, str):
            k = k.strip()
        # get() lazily loads the value into self._dict if it exists.
        self.get(k)
        res = self._dict[k]
        return res

    def __contains__(self, k):
        if isinstance(k, str):
            k = k.strip()
        return k in self._json_dict

    def __setitem__(self, k, v):
        if isinstance(k, str):
            k = k.strip()

        path = self._path

        if isinstance(v, dict):
            self._dict[k] = SummarySubDict(self._root, path + (k,))
            self._root._root_set(path, [(k, {})])
            self._dict[k].update(v)
        else:
            self._dict[k] = v
            self._root._root_set(path, [(k, v)])

        self._locked_keys.add(k)

        self._root._write()

        return v

    def __delitem__(self, k):
        if isinstance(k, str):
            k = k.strip()
        del self._dict[k]
        self._root._root_del(self._path + (k,))

        self._root._write()

    def __repr__(self):
        # Build the repr from values that are already loaded so printing
        # the summary doesn't pull every value out of storage.
        repr_dict = dict(self._dict)
        for k in self._json_dict:
            v = self._json_dict[k]
            if (
                k not in repr_dict
                and isinstance(v, dict)
                and v.get("_type") in H5_TYPES
            ):
                # Elide tensors that haven't been loaded from the h5 file.
                repr_dict[k] = "..."
            else:
                repr_dict[k] = self[k]

        return repr(repr_dict)

    def update(self, key_vals=None, overwrite=True):
        """Update the summary from a dict of key/value pairs.

        Locked keys are overwritten unless overwrite=False, in which case
        they are skipped. Keys written with overwrite=True are added to
        the locked set.
        """
        if key_vals:
            write_items = self._update(key_vals, overwrite)
            self._root._root_set(self._path, write_items)
            self._root._write(commit=True)

    def _update(self, key_vals, overwrite):
        if not key_vals:
            return []
        key_vals = {k.strip(): v for k, v in key_vals.items()}
        if overwrite:
            write_items = list(key_vals.items())
            self._locked_keys.update(key_vals.keys())
        else:
            write_keys = set(key_vals.keys()) - self._locked_keys
            write_items = [(k, key_vals[k]) for k in write_keys]

        for key, value in write_items:
            if isinstance(value, dict):
                self._dict[key] = SummarySubDict(self._root, self._path + (key,))
                self._dict[key]._update(value, overwrite)
            else:
                self._dict[key] = value

        return write_items
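
    # Illustration of the overwrite semantics (hypothetical values):
    #     s.update({"loss": 0.5})                  # writes and locks "loss"
    #     s.update({"loss": 0.7}, overwrite=False) # skipped: "loss" is locked
    #     s.update({"loss": 0.7})                  # overwrites; stays locked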


class Summary(SummarySubDict):
    """Store summary metrics (e.g. accuracy) during and after a run.

    You can manipulate this as if it were a Python dictionary, but the
    keys get mangled: .strip() is called on them, so leading and trailing
    whitespace is removed.
    """

    def __init__(self, run, summary=None):
        super().__init__()
        self._run = run
        self._h5_path = os.path.join(self._run.dir, DEEP_SUMMARY_FNAME)
        # The h5 file that holds large values is opened lazily; see open_h5().
        self._h5 = None

        # JSON-serializable mirror of the summary tree; this is what
        # _write() persists.
        self._json_dict = {}

        if summary is not None:
            self._json_dict = summary

    def _json_get(self, path):
        # Intentionally a no-op.
        pass

    def _root_get(self, path, child_dict):
        json_dict = self._json_dict
        for key in path[:-1]:
            json_dict = json_dict[key]

        key = path[-1]
        if key in json_dict:
            child_dict[key] = self._decode(path, json_dict[key])

    def _root_del(self, path):
        json_dict = self._json_dict
        for key in path[:-1]:
            json_dict = json_dict[key]

        val = json_dict[path[-1]]
        del json_dict[path[-1]]
        # If the deleted value was a tensor, drop its payload from the h5
        # file as well.
        if isinstance(val, dict) and val.get("_type") in H5_TYPES:
            if not h5py:
                wandb.termerror("Deleting tensors in summary requires h5py")
            else:
                self.open_h5()
                h5_key = "summary/" + ".".join(path)
                del self._h5[h5_key]
                self._h5.flush()

    def _root_set(self, path, new_keys_values):
        json_dict = self._json_dict
        for key in path:
            json_dict = json_dict[key]

        for new_key, new_value in new_keys_values:
            json_dict[new_key] = self._encode(new_value, path + (new_key,))

    def write_h5(self, path, val):
        # Make sure the h5 file is open before writing.
        self.open_h5()

        if not self._h5:
            wandb.termerror("Storing tensors in summary requires h5py")
        else:
            # Overwrite any existing dataset at this key.
            try:
                del self._h5["summary/" + ".".join(path)]
            except KeyError:
                pass
            self._h5["summary/" + ".".join(path)] = val
            self._h5.flush()

    def read_h5(self, path, val=None):
        # Make sure the h5 file is open before reading.
        self.open_h5()

        if not self._h5:
            wandb.termerror("Reading tensors from summary requires h5py")
        else:
            return self._h5.get("summary/" + ".".join(path), val)

    def open_h5(self):
        if not self._h5 and h5py:
            # "a" opens read/write and creates the file if it doesn't exist.
            self._h5 = h5py.File(self._h5_path, "a", libver="latest")

    def _decode(self, path, json_value):
        """Decode a `dict` encoded by `Summary._encode()`, loading h5 objects.

        h5 objects may be very large, so we won't have loaded them automatically.
        """
        if isinstance(json_value, dict):
            if json_value.get("_type") in H5_TYPES:
                return self.read_h5(path, json_value)
            elif json_value.get("_type") == "data-frame":
                wandb.termerror(
                    "This data frame was saved via the wandb data API. Contact support@wandb.com for help."
                )
                return None
            else:
                return SummarySubDict(self, path)
        else:
            return json_value

    def _encode(self, value, path_from_root):
        """Normalize, compress, and encode sub-objects for backend storage.

        value: Object to encode.
        path_from_root: `tuple` of key strings from the top-level summary to the
            current `value`.

        Returns:
            A new tree of dicts with large objects replaced by dictionaries
            whose "_type" entries say what type the original data was.
        """
        # Build a new dict tree so the caller's value is never mutated.
        if isinstance(value, dict):
            json_value = {}
            for key, sub_value in value.items():
                json_value[key] = self._encode(sub_value, path_from_root + (key,))
            return json_value
        else:
            path = ".".join(path_from_root)
            friendly_value, converted = util.json_friendly(
                val_to_json(self._run, path, value, namespace="summary")
            )
            json_value, compressed = util.maybe_compress_summary(
                friendly_value, util.get_h5_typename(value)
            )
            if compressed:
                # The full payload goes to the h5 file; json_value keeps
                # only a small placeholder with a "_type" entry.
                self.write_h5(path_from_root, friendly_value)

            return json_value
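
    # Sketch of the round trip (hypothetical values): a large numpy array is
    # compressed by util.maybe_compress_summary(), its data is stored under
    # "summary/<dotted.path>" in wandb.h5 via write_h5(), and the JSON dict
    # keeps a placeholder like {"_type": "numpy.ndarray", ...} that
    # _decode() later resolves through read_h5().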


def download_h5(run_id, entity=None, project=None, out_dir=None):
    api = Api()
    meta = api.download_url(
        project or api.settings("project"),
        DEEP_SUMMARY_FNAME,
        entity=entity or api.settings("entity"),
        run=run_id,
    )
    if meta and "md5" in meta and meta["md5"] is not None:
        # An md5 means the file exists upstream; fetch it.
        wandb.termlog("Downloading summary data...")
        path, res = api.download_write_file(meta, out_dir=out_dir)
        return path
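
# Hypothetical usage: fetch a run's deep summary file into a local directory.
#     path = download_h5("abc123", entity="my-team", project="my-proj", out_dir="/tmp")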


def upload_h5(file, run_id, entity=None, project=None):
    api = Api()
    wandb.termlog("Uploading summary data...")
    with open(file, "rb") as f:
        api.push(
            {os.path.basename(file): f}, run=run_id, project=project, entity=entity
        )


class FileSummary(Summary):
    def __init__(self, run):
        super().__init__(run)
        self._fname = os.path.join(run.dir, wandb_lib.filenames.SUMMARY_FNAME)
        self.load()

    def load(self):
        try:
            with open(self._fname) as f:
                self._json_dict = json.load(f)
        except (OSError, ValueError):
            self._json_dict = {}

    def _write(self, commit=False):
        # `commit` is ignored: every write goes straight to disk.
        with open(self._fname, "w") as f:
            f.write(util.json_dumps_safer(self._json_dict))
            f.write("\n")
            f.flush()
            os.fsync(f.fileno())
        if self._h5:
            self._h5.close()
            self._h5 = None


class HTTPSummary(Summary):
    def __init__(self, run, client, summary=None):
        super().__init__(run, summary=summary)
        self._run = run
        self._client = client
        self._started = time.time()

    def __delitem__(self, key):
        if key not in self._json_dict:
            raise KeyError(key)
        del self._json_dict[key]

    def load(self):
        pass

    def open_h5(self):
        if not self._h5 and h5py:
            # Pull the remote wandb.h5 down before opening it locally.
            download_h5(
                self._run.id,
                entity=self._run.entity,
                project=self._run.project,
                out_dir=self._run.dir,
            )
        super().open_h5()

    def _write(self, commit=False):
        mutation = gql(
            """
            mutation UpsertBucket($id: String, $summaryMetrics: JSONString) {
                upsertBucket(input: {id: $id, summaryMetrics: $summaryMetrics}) {
                    bucket { id }
                }
            }
            """
        )
        if commit:
            if self._h5:
                self._h5.close()
                self._h5 = None
            res = self._client.execute(
                mutation,
                variable_values={
                    "id": self._run.storage_id,
                    "summaryMetrics": util.json_dumps_safer(self._json_dict),
                },
            )
            assert res["upsertBucket"]["bucket"]["id"]
            entity, project, run = self._run.path
            # Upload the h5 payload only if it changed during this session.
            if (
                os.path.exists(self._h5_path)
                and os.path.getmtime(self._h5_path) >= self._started
            ):
                upload_h5(self._h5_path, run, entity=entity, project=project)
        else:
            return False
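

# Minimal usage sketch (hypothetical run object; these classes are normally
# constructed by the library rather than by user code):
#     summary = FileSummary(run)       # backed by the run's summary JSON file
#     summary["accuracy"] = 0.93       # written to disk immediately
#     summary.update({"loss": 0.1})    # batched update, committed at the end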