import json
import os
import time
from wandb_gql import gql
import wandb
from wandb import util
from wandb.apis.internal import Api
from wandb.sdk import lib as wandb_lib
from wandb.sdk.data_types.utils import val_to_json
DEEP_SUMMARY_FNAME = "wandb.h5"
H5_TYPES = ("numpy.ndarray", "tensorflow.Tensor", "torch.Tensor")
h5py = util.get_module("h5py")
np = util.get_module("numpy")
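# h5py and numpy are optional dependencies: util.get_module() returns the
# module if it can be imported and None otherwise, and every h5 code path
# below checks for None before touching the file.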
class SummarySubDict:
"""Nested dict-like object that proxies read and write operations through a root object.
This lets us do synchronous serialization and lazy loading of large values.
"""
def __init__(self, root=None, path=()):
self._path = tuple(path)
if root is None:
self._root = self
self._json_dict = {}
else:
self._root = root
json_dict = root._json_dict
for k in path:
json_dict = json_dict.get(k, {})
self._json_dict = json_dict
self._dict = {}
# We use this to track which keys the user has set explicitly
# so that we don't automatically overwrite them when we update
# the summary from the history.
self._locked_keys = set()
def __setattr__(self, k, v):
k = k.strip()
if k.startswith("_"):
object.__setattr__(self, k, v)
else:
self[k] = v
def __getattr__(self, k):
k = k.strip()
if k.startswith("_"):
return object.__getattribute__(self, k)
else:
return self[k]
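    # Attribute access on non-underscore names proxies to item access, so the
    # two forms below are equivalent (a hypothetical sketch; `summary` is
    # assumed to be a concrete subclass instance such as FileSummary):
    #
    #     summary.accuracy = 0.9   # same as summary["accuracy"] = 0.9
    #     summary.accuracy         # same as summary["accuracy"]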
def _root_get(self, path, child_dict):
"""Load a value at a particular path from the root.
This should only be implemented by the "_root" child class.
We pass the child_dict so the item can be set on it or not as
appropriate. Returning None for a nonexistent path wouldn't be
distinguishable from that path being set to the value None.
"""
raise NotImplementedError
def _root_set(self, path, new_keys_values):
"""Set a value at a particular path in the root.
This should only be implemented by the "_root" child class.
"""
raise NotImplementedError
def _root_del(self, path):
"""Delete a value at a particular path in the root.
This should only be implemented by the "_root" child class.
"""
raise NotImplementedError
def _write(self, commit=False):
# should only be implemented on the root summary
raise NotImplementedError
def keys(self):
# _json_dict has the full set of keys, including those for h5 objects
# that may not have been loaded yet
return self._json_dict.keys()
def get(self, k, default=None):
if isinstance(k, str):
k = k.strip()
if k not in self._dict:
self._root._root_get(self._path + (k,), self._dict)
return self._dict.get(k, default)
def items(self):
# not all items may be loaded into self._dict, so we
# have to build the sequence of items from scratch
for k in self.keys():
yield k, self[k]
def __getitem__(self, k):
if isinstance(k, str):
k = k.strip()
self.get(k) # load the value into _dict if it should be there
res = self._dict[k]
return res
def __contains__(self, k):
if isinstance(k, str):
k = k.strip()
return k in self._json_dict
def __setitem__(self, k, v):
if isinstance(k, str):
k = k.strip()
path = self._path
if isinstance(v, dict):
self._dict[k] = SummarySubDict(self._root, path + (k,))
self._root._root_set(path, [(k, {})])
self._dict[k].update(v)
else:
self._dict[k] = v
self._root._root_set(path, [(k, v)])
self._locked_keys.add(k)
self._root._write()
return v
    def __delitem__(self, k):
        if isinstance(k, str):
            k = k.strip()
del self._dict[k]
self._root._root_del(self._path + (k,))
self._root._write()
def __repr__(self):
# use a copy of _dict, except add placeholders for h5 objects, etc.
# that haven't been loaded yet
repr_dict = dict(self._dict)
for k in self._json_dict:
v = self._json_dict[k]
if (
k not in repr_dict
and isinstance(v, dict)
and v.get("_type") in H5_TYPES
):
# unloaded h5 objects may be very large. use a placeholder for them
# if we haven't already loaded them
repr_dict[k] = "..."
else:
repr_dict[k] = self[k]
return repr(repr_dict)
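    # An unloaded tensor therefore reprs as a placeholder, e.g. (illustrative
    # output only): {'loss': 0.1, 'weights': '...'}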
def update(self, key_vals=None, overwrite=True):
"""Locked keys will be overwritten unless overwrite=False.
Otherwise, written keys will be added to the "locked" list.
"""
if key_vals:
write_items = self._update(key_vals, overwrite)
self._root._root_set(self._path, write_items)
self._root._write(commit=True)
def _update(self, key_vals, overwrite):
if not key_vals:
return
key_vals = {k.strip(): v for k, v in key_vals.items()}
if overwrite:
write_items = list(key_vals.items())
self._locked_keys.update(key_vals.keys())
else:
write_keys = set(key_vals.keys()) - self._locked_keys
write_items = [(k, key_vals[k]) for k in write_keys]
for key, value in write_items:
if isinstance(value, dict):
self._dict[key] = SummarySubDict(self._root, self._path + (key,))
self._dict[key]._update(value, overwrite)
else:
self._dict[key] = value
return write_items
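# Locking semantics in practice (a sketch; `summary` is assumed to be a
# concrete instance such as FileSummary):
#
#     summary["loss"] = 0.5                           # explicit set: "loss" is locked
#     summary.update({"loss": 0.9}, overwrite=False)  # skipped: "loss" is locked
#     summary.update({"loss": 0.9})                   # default overwrite=True: written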
class Summary(SummarySubDict):
"""Store summary metrics (eg. accuracy) during and after a run.
You can manipulate this as if it's a Python dictionary but the keys
get mangled. .strip() is called on them, so spaces at the beginning
and end are removed.
"""
def __init__(self, run, summary=None):
super().__init__()
self._run = run
self._h5_path = os.path.join(self._run.dir, DEEP_SUMMARY_FNAME)
# Lazy load the h5 file
self._h5 = None
# Mirrored version of self._dict with versions of values that get written
# to JSON kept up to date by self._root_set() and self._root_del().
self._json_dict = {}
if summary is not None:
self._json_dict = summary
def _json_get(self, path):
        # Intentionally a no-op: path-based reads go through _root_get().
        pass
def _root_get(self, path, child_dict):
json_dict = self._json_dict
for key in path[:-1]:
json_dict = json_dict[key]
key = path[-1]
if key in json_dict:
child_dict[key] = self._decode(path, json_dict[key])
def _root_del(self, path):
json_dict = self._json_dict
for key in path[:-1]:
json_dict = json_dict[key]
val = json_dict[path[-1]]
del json_dict[path[-1]]
if isinstance(val, dict) and val.get("_type") in H5_TYPES:
if not h5py:
wandb.termerror("Deleting tensors in summary requires h5py")
else:
self.open_h5()
h5_key = "summary/" + ".".join(path)
del self._h5[h5_key]
self._h5.flush()
def _root_set(self, path, new_keys_values):
json_dict = self._json_dict
for key in path:
json_dict = json_dict[key]
for new_key, new_value in new_keys_values:
json_dict[new_key] = self._encode(new_value, path + (new_key,))
def write_h5(self, path, val):
# ensure the file is open
self.open_h5()
if not self._h5:
wandb.termerror("Storing tensors in summary requires h5py")
else:
try:
del self._h5["summary/" + ".".join(path)]
except KeyError:
pass
self._h5["summary/" + ".".join(path)] = val
self._h5.flush()
def read_h5(self, path, val=None):
# ensure the file is open
self.open_h5()
if not self._h5:
wandb.termerror("Reading tensors from summary requires h5py")
else:
return self._h5.get("summary/" + ".".join(path), val)
def open_h5(self):
if not self._h5 and h5py:
self._h5 = h5py.File(self._h5_path, "a", libver="latest")
def _decode(self, path, json_value):
"""Decode a `dict` encoded by `Summary._encode()`, loading h5 objects.
h5 objects may be very large, so we won't have loaded them automatically.
"""
if isinstance(json_value, dict):
if json_value.get("_type") in H5_TYPES:
return self.read_h5(path, json_value)
elif json_value.get("_type") == "data-frame":
wandb.termerror(
"This data frame was saved via the wandb data API. Contact support@wandb.com for help."
)
return None
# TODO: transform wandb objects and plots
else:
return SummarySubDict(self, path)
else:
return json_value
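    # Decode dispatch, roughly:
    #   {"_type": <h5 type>, ...}    -> self.read_h5(path)  (lazy h5 load)
    #   {"_type": "data-frame", ...} -> error message, returns None
    #   any other dict               -> SummarySubDict proxy rooted at `path`
    #   scalars, lists, etc.         -> returned unchanged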
def _encode(self, value, path_from_root):
"""Normalize, compress, and encode sub-objects for backend storage.
value: Object to encode.
path_from_root: `tuple` of key strings from the top-level summary to the
current `value`.
Returns:
A new tree of dict's with large objects replaced with dictionaries
with "_type" entries that say which type the original data was.
"""
# Constructs a new `dict` tree in `json_value` that discards and/or
# encodes objects that aren't JSON serializable.
        if isinstance(value, dict):
            json_value = {}
            for key, sub_value in value.items():
                json_value[key] = self._encode(sub_value, path_from_root + (key,))
            return json_value
else:
path = ".".join(path_from_root)
friendly_value, converted = util.json_friendly(
val_to_json(self._run, path, value, namespace="summary")
)
json_value, compressed = util.maybe_compress_summary(
friendly_value, util.get_h5_typename(value)
)
if compressed:
self.write_h5(path_from_root, friendly_value)
return json_value
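# End-to-end encoding sketch (hypothetical; assumes numpy is installed and
# `summary` is a FileSummary). A large array is converted by val_to_json and
# util.json_friendly; if util.maybe_compress_summary decides it is too big
# for JSON, the raw values land in wandb.h5 and the JSON side keeps only a
# typed stub such as {"_type": "numpy.ndarray", ...}:
#
#     summary["embedding"] = np.random.rand(10000)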
def download_h5(run_id, entity=None, project=None, out_dir=None):
api = Api()
meta = api.download_url(
project or api.settings("project"),
DEEP_SUMMARY_FNAME,
entity=entity or api.settings("entity"),
run=run_id,
)
if meta and "md5" in meta and meta["md5"] is not None:
# TODO: make this non-blocking
wandb.termlog("Downloading summary data...")
path, res = api.download_write_file(meta, out_dir=out_dir)
return path
def upload_h5(file, run_id, entity=None, project=None):
api = Api()
wandb.termlog("Uploading summary data...")
with open(file, "rb") as f:
api.push(
{os.path.basename(file): f}, run=run_id, project=project, entity=entity
)
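# Usage sketch for the helpers above (hypothetical identifiers):
#
#     path = download_h5("abc123", entity="my-team", project="my-project")
#     if path:
#         upload_h5(path, "abc123", entity="my-team", project="my-project")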
class FileSummary(Summary):
def __init__(self, run):
super().__init__(run)
self._fname = os.path.join(run.dir, wandb_lib.filenames.SUMMARY_FNAME)
self.load()
def load(self):
try:
with open(self._fname) as f:
self._json_dict = json.load(f)
except (OSError, ValueError):
self._json_dict = {}
def _write(self, commit=False):
        # TODO: we currently ignore `commit` for backward compatibility
with open(self._fname, "w") as f:
f.write(util.json_dumps_safer(self._json_dict))
f.write("\n")
f.flush()
os.fsync(f.fileno())
if self._h5:
self._h5.close()
self._h5 = None
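# A minimal FileSummary round trip (a sketch; `run` is assumed to be an
# object with a `dir` attribute pointing at the run directory):
#
#     summary = FileSummary(run)
#     summary["accuracy"] = 0.93  # synchronously rewrites the summary JSON file
#     summary.load()              # re-reads the file from disk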
class HTTPSummary(Summary):
def __init__(self, run, client, summary=None):
super().__init__(run, summary=summary)
self._run = run
self._client = client
self._started = time.time()
def __delitem__(self, key):
if key not in self._json_dict:
raise KeyError(key)
del self._json_dict[key]
def load(self):
pass
def open_h5(self):
if not self._h5 and h5py:
download_h5(
self._run.id,
entity=self._run.entity,
project=self._run.project,
out_dir=self._run.dir,
)
super().open_h5()
def _write(self, commit=False):
mutation = gql(
"""
mutation UpsertBucket( $id: String, $summaryMetrics: JSONString) {
upsertBucket(input: { id: $id, summaryMetrics: $summaryMetrics}) {
bucket { id }
}
}
"""
)
if commit:
if self._h5:
self._h5.close()
self._h5 = None
res = self._client.execute(
mutation,
variable_values={
"id": self._run.storage_id,
"summaryMetrics": util.json_dumps_safer(self._json_dict),
},
)
assert res["upsertBucket"]["bucket"]["id"]
entity, project, run = self._run.path
if (
os.path.exists(self._h5_path)
and os.path.getmtime(self._h5_path) >= self._started
):
upload_h5(self._h5_path, run, entity=entity, project=project)
else:
return False
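# HTTPSummary persists through the UpsertBucket GraphQL mutation rather than a
# local file. A usage sketch (hypothetical; `run` is an API run object and
# `client` a wandb_gql client):
#
#     summary = HTTPSummary(run, client, summary=json.loads(run.summary_metrics))
#     summary.update({"accuracy": 0.95})  # commits via the UpsertBucket mutation
#     summary["accuracy"] = 0.96          # staged only: _write() without commit is a no-op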