"""tensorboard watcher."""
import glob
import logging
import os
import queue
import socket
import sys
import threading
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import wandb
from wandb import util
from wandb.plot import CustomChart
from wandb.sdk.interface.interface import GlobStr
from wandb.sdk.lib import filesystem
from . import run as internal_run
if TYPE_CHECKING:
from queue import PriorityQueue
from tensorboard.backend.event_processing.event_file_loader import EventFileLoader
    from tensorboard.compat.proto.event_pb2 import Event as ProtoEvent
from wandb.proto.wandb_internal_pb2 import RunRecord
from wandb.sdk.interface.interface import FilesDict
from ..interface.interface_queue import InterfaceQueue
from .settings_static import SettingsStatic
HistoryDict = Dict[str, Any]
# Seconds to keep polling after shutdown so TensorBoard data can flush.
SHUTDOWN_DELAY = 5
# Seconds to back off after a directory watcher error.
ERROR_DELAY = 5
REMOTE_FILE_TOKEN = "://"
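# A path containing this token (e.g. "s3://bucket/logs") is treated as remote;
# remote tfevents files are read but never persisted via symlink (see the
# EventFileLoader subclass in TBDirWatcher._loader).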
logger = logging.getLogger(__name__)
def _link_and_save_file(
path: str, base_path: str, interface: "InterfaceQueue", settings: "SettingsStatic"
) -> None:
    # TODO(jhr): should this logic be merged with Run.save()?
files_dir = settings.files_dir
file_name = os.path.relpath(path, base_path)
abs_path = os.path.abspath(path)
wandb_path = os.path.join(files_dir, file_name)
filesystem.mkdir_exists_ok(os.path.dirname(wandb_path))
    # We overwrite existing symlinks because namespaces can change in TensorBoard
if os.path.islink(wandb_path) and abs_path != os.readlink(wandb_path):
os.remove(wandb_path)
os.symlink(abs_path, wandb_path)
elif not os.path.exists(wandb_path):
os.symlink(abs_path, wandb_path)
# TODO(jhr): need to figure out policy, live/throttled?
interface.publish_files(dict(files=[(GlobStr(glob.escape(file_name)), "live")]))
def is_tfevents_file_created_by(
path: str, hostname: Optional[str], start_time: Optional[float]
) -> bool:
"""Check if a path is a tfevents file.
Optionally checks that it was created by [hostname] after [start_time].
tensorboard tfevents filename format:
https://github.com/tensorflow/tensorboard/blob/f3f26b46981da5bd46a5bb93fcf02d9eb7608bc1/tensorboard/summary/writer/event_file_writer.py#L81
tensorflow tfevents filename format:
https://github.com/tensorflow/tensorflow/blob/8f597046dc30c14b5413813d02c0e0aed399c177/tensorflow/core/util/events_writer.cc#L68
"""
if not path:
raise ValueError("Path must be a nonempty string")
basename = os.path.basename(path)
if basename.endswith((".profile-empty", ".sagemaker-uploaded")):
return False
fname_components = basename.split(".")
try:
tfevents_idx = fname_components.index("tfevents")
except ValueError:
return False
# check the hostname, which may have dots
if hostname is not None:
for i, part in enumerate(hostname.split(".")):
try:
fname_component_part = fname_components[tfevents_idx + 2 + i]
except IndexError:
return False
if part != fname_component_part:
return False
if start_time is not None:
try:
created_time = int(fname_components[tfevents_idx + 1])
except (ValueError, IndexError):
return False
        # Ensure that the file is newer than our start time; the hostname
        # check above ensures it was created on the same machine.
# TODO: we should also check the PID (also contained in the tfevents
# filename). Can we assume that our parent pid is the user process
# that wrote these files?
if created_time < int(start_time):
return False
return True
class TBWatcher:
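    """Watch TensorBoard logdirs and stream their events to wandb.

    One TBDirWatcher is started per logdir; a single TBEventConsumer drains
    the shared priority queue.
    """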
_logdirs: "Dict[str, TBDirWatcher]"
_watcher_queue: "PriorityQueue"
def __init__(
self,
settings: "SettingsStatic",
run_proto: "RunRecord",
interface: "InterfaceQueue",
force: bool = False,
) -> None:
self._logdirs = {}
self._consumer: Optional[TBEventConsumer] = None
self._settings = settings
self._interface = interface
self._run_proto = run_proto
self._force = force
# TODO(jhr): do we need locking in this queue?
self._watcher_queue = queue.PriorityQueue()
wandb.tensorboard.reset_state() # type: ignore
def _calculate_namespace(self, logdir: str, rootdir: str) -> Optional[str]:
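        """Derive a wandb namespace for a logdir relative to the root directory.

        A minimal illustration (hypothetical paths): with logdir
        "runs/exp1/train" and rootdir "runs/exp1", the namespace is "train",
        so values from that directory are nested under "train/" in wandb.
        """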
namespace: Optional[str]
dirs = list(self._logdirs) + [logdir]
if os.path.isfile(logdir):
filename = os.path.basename(logdir)
else:
filename = ""
if rootdir == "":
rootdir = util.to_forward_slash_path(
os.path.dirname(os.path.commonprefix(dirs))
)
            # TensorBoard loads all tfevents files in a directory and prefixes
            # their values with the path. Passing a namespace to log allows us
            # to nest the values in wandb.
            # Note that we strip '/' instead of os.sep, because elsewhere we've
            # converted paths to forward slashes.
namespace = logdir.replace(filename, "").replace(rootdir, "").strip("/")
            # TODO: revisit this heuristic; it exists because we don't know the
            # root log directory until more than one tfevents file has been written
if len(dirs) == 1 and namespace not in ["train", "validation"]:
namespace = None
else:
namespace = logdir.replace(filename, "").replace(rootdir, "").strip("/")
return namespace
def add(self, logdir: str, save: bool, root_dir: str) -> None:
logdir = util.to_forward_slash_path(logdir)
root_dir = util.to_forward_slash_path(root_dir)
if logdir in self._logdirs:
return
namespace = self._calculate_namespace(logdir, root_dir)
# TODO(jhr): implement the deferred tbdirwatcher to find namespace
if not self._consumer:
self._consumer = TBEventConsumer(
self, self._watcher_queue, self._run_proto, self._settings
)
self._consumer.start()
tbdir_watcher = TBDirWatcher(
self, logdir, save, namespace, self._watcher_queue, self._force
)
self._logdirs[logdir] = tbdir_watcher
tbdir_watcher.start()
def finish(self) -> None:
for tbdirwatcher in self._logdirs.values():
tbdirwatcher.shutdown()
for tbdirwatcher in self._logdirs.values():
tbdirwatcher.finish()
if self._consumer:
self._consumer.finish()
class TBDirWatcher:
def __init__(
self,
tbwatcher: "TBWatcher",
logdir: str,
save: bool,
namespace: Optional[str],
queue: "PriorityQueue",
force: bool = False,
) -> None:
self.directory_watcher = util.get_module(
"tensorboard.backend.event_processing.directory_watcher",
required="Please install tensorboard package",
)
# self.event_file_loader = util.get_module(
# "tensorboard.backend.event_processing.event_file_loader",
# required="Please install tensorboard package",
# )
self.tf_compat = util.get_module(
"tensorboard.compat", required="Please install tensorboard package"
)
self._tbwatcher = tbwatcher
self._generator = self.directory_watcher.DirectoryWatcher(
logdir, self._loader(save, namespace), self._is_our_tfevents_file
)
self._thread = threading.Thread(target=self._thread_except_body)
self._first_event_timestamp = None
self._shutdown = threading.Event()
self._queue = queue
self._file_version = None
self._namespace = namespace
self._logdir = logdir
self._hostname = socket.gethostname()
self._force = force
self._process_events_lock = threading.Lock()
def start(self) -> None:
self._thread.start()
def _is_our_tfevents_file(self, path: str) -> bool:
"""Check if a path has been modified since launch and contains tfevents."""
if not path:
raise ValueError("Path must be a nonempty string")
path = self.tf_compat.tf.compat.as_str_any(path)
if self._force:
return is_tfevents_file_created_by(path, None, None)
else:
return is_tfevents_file_created_by(
path, self._hostname, self._tbwatcher._settings.x_start_time
)
    def _loader(
        self, save: bool = True, namespace: Optional[str] = None
    ) -> "Type[EventFileLoader]":
"""Incredibly hacky class generator to optionally save / prefix tfevent files."""
_loader_interface = self._tbwatcher._interface
_loader_settings = self._tbwatcher._settings
try:
from tensorboard.backend.event_processing import event_file_loader
except ImportError:
raise Exception("Please install tensorboard package")
class EventFileLoader(event_file_loader.EventFileLoader):
def __init__(self, file_path: str) -> None:
super().__init__(file_path)
if save:
if REMOTE_FILE_TOKEN in file_path:
logger.warning(
"Not persisting remote tfevent file: %s", file_path
)
else:
# TODO: save plugins?
logdir = os.path.dirname(file_path)
parts = list(os.path.split(logdir))
if namespace and parts[-1] == namespace:
parts.pop()
logdir = os.path.join(*parts)
_link_and_save_file(
path=file_path,
base_path=logdir,
interface=_loader_interface,
settings=_loader_settings,
)
return EventFileLoader
def _process_events(self, shutdown_call: bool = False) -> None:
try:
with self._process_events_lock:
for event in self._generator.Load():
self.process_event(event)
except (
self.directory_watcher.DirectoryDeletedError,
StopIteration,
RuntimeError,
OSError,
) as e:
            # When listing S3, the directory may not exist yet or may be empty.
logger.debug("Encountered tensorboard directory watcher error: %s", e)
if not self._shutdown.is_set() and not shutdown_call:
time.sleep(ERROR_DELAY)
def _thread_except_body(self) -> None:
try:
self._thread_body()
except Exception:
logger.exception("generic exception in TBDirWatcher thread")
raise
def _thread_body(self) -> None:
"""Check for new events every second."""
shutdown_time: Optional[float] = None
while True:
self._process_events()
if self._shutdown.is_set():
now = time.time()
if not shutdown_time:
shutdown_time = now + SHUTDOWN_DELAY
elif now > shutdown_time:
break
time.sleep(1)
def process_event(self, event: "ProtoEvent") -> None:
# print("\nEVENT:::", self._logdir, self._namespace, event, "\n")
if self._first_event_timestamp is None:
self._first_event_timestamp = event.wall_time
if event.HasField("file_version"):
self._file_version = event.file_version
if event.HasField("summary"):
self._queue.put(Event(event, self._namespace))
def shutdown(self) -> None:
self._process_events(shutdown_call=True)
self._shutdown.set()
def finish(self) -> None:
self.shutdown()
self._thread.join()
class Event:
"""An event wrapper to enable priority queueing."""
def __init__(self, event: "ProtoEvent", namespace: Optional[str]):
self.event = event
self.namespace = namespace
self.created_at = time.time()
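    # PriorityQueue orders its entries via __lt__, so events pop in wall_time
    # order even when several tfevents files are being tailed at once.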
    def __lt__(self, other: "Event") -> bool:
        return self.event.wall_time < other.event.wall_time
class TBEventConsumer:
"""Consume tfevents from a priority queue.
There should always only be one of these per run_manager. We wait for 10 seconds of
queued events to reduce the chance of multiple tfevent files triggering out of order
steps.
"""
def __init__(
self,
tbwatcher: TBWatcher,
queue: "PriorityQueue",
run_proto: "RunRecord",
settings: "SettingsStatic",
delay: int = 10,
) -> None:
self._tbwatcher = tbwatcher
self._queue = queue
self._thread = threading.Thread(target=self._thread_except_body)
self._shutdown = threading.Event()
self.tb_history = TBHistory()
self._delay = delay
# This is a bit of a hack to get file saving to work as it does in the user
# process. Since we don't have a real run object, we have to define the
# datatypes callback ourselves.
def datatypes_cb(fname: GlobStr) -> None:
files: FilesDict = dict(files=[(fname, "now")])
self._tbwatcher._interface.publish_files(files)
# this is only used for logging artifacts
self._internal_run = internal_run.InternalRun(run_proto, settings, datatypes_cb)
self._internal_run._set_internal_run_interface(self._tbwatcher._interface)
def start(self) -> None:
self._start_time = time.time()
self._thread.start()
def finish(self) -> None:
self._delay = 0
self._shutdown.set()
self._thread.join()
while not self._queue.empty():
event = self._queue.get(True, 1)
if event:
self._handle_event(event, history=self.tb_history)
items = self.tb_history._get_and_reset()
for item in items:
                    self._save_row(item)
def _thread_except_body(self) -> None:
try:
self._thread_body()
except Exception:
logger.exception("generic exception in TBEventConsumer thread")
raise
def _thread_body(self) -> None:
while True:
try:
event = self._queue.get(True, 1)
# Wait self._delay seconds from consumer start before logging events
if (
time.time() < self._start_time + self._delay
and not self._shutdown.is_set()
):
self._queue.put(event)
time.sleep(0.1)
continue
except queue.Empty:
event = None
if self._shutdown.is_set():
break
if event:
self._handle_event(event, history=self.tb_history)
items = self.tb_history._get_and_reset()
for item in items:
                    self._save_row(item)
# flush uncommitted data
self.tb_history._flush()
items = self.tb_history._get_and_reset()
for item in items:
self._save_row(item)
def _handle_event(
self, event: "ProtoEvent", history: Optional["TBHistory"] = None
) -> None:
wandb.tensorboard._log( # type: ignore
event.event,
step=event.event.step,
namespace=event.namespace,
history=history,
)
    def _save_row(self, row: "HistoryDict") -> None:
        # Custom charts are published as a config entry plus their backing table.
        chart_keys = set()
for k, v in row.items():
if isinstance(v, CustomChart):
chart_keys.add(k)
v.set_key(k)
self._tbwatcher._interface.publish_config(
key=v.spec.config_key,
val=v.spec.config_value,
)
for k in chart_keys:
chart = row.pop(k)
if isinstance(chart, CustomChart):
row[chart.spec.table_key] = chart.table
self._tbwatcher._interface.publish_history(
self._internal_run,
row,
publish_step=False,
)
class TBHistory:
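    """Accumulate TensorBoard values into per-step history rows.

    add() flushes the current step and starts a new one, _row_update() merges
    values into the current step, and _flush() finalizes a step, dropping the
    largest keys if the row exceeds the maximum line size.
    """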
_data: "HistoryDict"
_added: "List[HistoryDict]"
def __init__(self) -> None:
self._step = 0
self._step_size = 0
self._data = dict()
self._added = []
def _flush(self) -> None:
if not self._data:
return
        # A single TensorBoard step may contain too much data; if it does, we
        # drop the largest keys in the step.
        # TODO: we could flush the data across multiple steps
if self._step_size > util.MAX_LINE_BYTES:
metrics = [(k, sys.getsizeof(v)) for k, v in self._data.items()]
metrics.sort(key=lambda t: t[1], reverse=True)
bad = 0
dropped_keys = []
for k, v in metrics:
# TODO: (cvp) Added a buffer of 100KiB, this feels rather brittle.
if self._step_size - bad < util.MAX_LINE_BYTES - 100000:
break
else:
bad += v
dropped_keys.append(k)
del self._data[k]
wandb.termwarn(
f"Step {self._step} exceeds max data limit, dropping {len(dropped_keys)} of the largest keys:"
)
print("\t" + ("\n\t".join(dropped_keys))) # noqa: T201
self._data["_step"] = self._step
self._added.append(self._data)
self._step += 1
self._step_size = 0
def add(self, d: "HistoryDict") -> None:
self._flush()
self._data = dict()
self._data.update(self._track_history_dict(d))
    def _track_history_dict(self, d: "HistoryDict") -> "HistoryDict":
        # Copy the incoming values, accumulating an approximate byte size for
        # the current step (consumed by _flush).
        e = {}
        for k, v in d.items():
            e[k] = v
            self._step_size += sys.getsizeof(v)
        return e
def _row_update(self, d: "HistoryDict") -> None:
self._data.update(self._track_history_dict(d))
def _get_and_reset(self) -> "List[HistoryDict]":
added = self._added[:]
self._added = []
return added