|
"""tensorboard watcher.""" |
|
|
|
import glob |
|
import logging |
|
import os |
|
import queue |
|
import socket |
|
import sys |
|
import threading |
|
import time |
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional |
|
|
|
import wandb |
|
from wandb import util |
|
from wandb.plot import CustomChart |
|
from wandb.sdk.interface.interface import GlobStr |
|
from wandb.sdk.lib import filesystem |
|
|
|
from . import run as internal_run |
|
|
|
if TYPE_CHECKING: |
|
from queue import PriorityQueue |
|
|
|
from tensorboard.backend.event_processing.event_file_loader import EventFileLoader |
|
from tensorboard.compat.proto.event_pb2 import ProtoEvent |
|
|
|
from wandb.proto.wandb_internal_pb2 import RunRecord |
|
from wandb.sdk.interface.interface import FilesDict |
|
|
|
from ..interface.interface_queue import InterfaceQueue |
|
from .settings_static import SettingsStatic |
|
|
|
HistoryDict = Dict[str, Any] |
|
|
|
|
|
# Seconds a TBDirWatcher keeps polling after shutdown has been requested,
# presumably so that late tensorboard writes are still picked up.
SHUTDOWN_DELAY = 5
# Seconds to back off after a directory-watcher error (skipped during shutdown).
ERROR_DELAY = 5
# Substring marking a remote (URL-style) path; such tfevent files are not
# persisted into the run's files directory.
REMOTE_FILE_TOKEN = "://"
logger = logging.getLogger(__name__)
|
|
|
|
|
def _link_and_save_file(
    path: str, base_path: str, interface: "InterfaceQueue", settings: "SettingsStatic"
) -> None:
    """Symlink ``path`` into the run's files directory and publish it as "live".

    The link is created at ``<settings.files_dir>/<path relative to base_path>``.
    An existing symlink that points elsewhere is replaced. An existing regular
    file, or a symlink that already points at ``path``, is left untouched.

    Args:
        path: File to link; a relative path is resolved against the CWD.
        base_path: Directory ``path`` is made relative to; this decides where the
            link lands under ``settings.files_dir``.
        interface: Used to publish the linked file with the "live" policy.
        settings: Provides ``files_dir``.
    """
    files_dir = settings.files_dir
    file_name = os.path.relpath(path, base_path)
    abs_path = os.path.abspath(path)
    wandb_path = os.path.join(files_dir, file_name)
    filesystem.mkdir_exists_ok(os.path.dirname(wandb_path))

    if os.path.islink(wandb_path):
        # Decide on the link's target, not on whether the target exists. A
        # dangling link that already points at abs_path is reported as missing
        # by os.path.exists, and os.symlink over it would raise FileExistsError.
        if os.readlink(wandb_path) != abs_path:
            os.remove(wandb_path)
            os.symlink(abs_path, wandb_path)
    elif not os.path.exists(wandb_path):
        os.symlink(abs_path, wandb_path)

    # glob.escape so special characters in the file name are matched literally.
    interface.publish_files(dict(files=[(GlobStr(glob.escape(file_name)), "live")]))
|
|
|
|
|
def is_tfevents_file_created_by(
    path: str, hostname: Optional[str], start_time: Optional[float]
) -> bool:
    """Return whether ``path`` names a tfevents file, optionally from this host/run.

    The basename must contain a ``tfevents`` dot-separated component. The
    component after it is the creation timestamp; the host name's dot-separated
    parts follow. Files ending in ``.profile-empty`` or ``.sagemaker-uploaded``
    never match.

    When ``hostname`` is given, its parts must appear right after the timestamp.
    When ``start_time`` is given, the timestamp must parse as an int no smaller
    than ``int(start_time)``.

    tensorboard tfevents filename format:
    https://github.com/tensorflow/tensorboard/blob/f3f26b46981da5bd46a5bb93fcf02d9eb7608bc1/tensorboard/summary/writer/event_file_writer.py#L81
    tensorflow tfevents filename format:
    https://github.com/tensorflow/tensorflow/blob/8f597046dc30c14b5413813d02c0e0aed399c177/tensorflow/core/util/events_writer.cc#L68

    Raises:
        ValueError: if ``path`` is empty.
    """
    if not path:
        raise ValueError("Path must be a nonempty string")

    name = os.path.basename(path)
    if name.endswith((".profile-empty", ".sagemaker-uploaded")):
        return False

    pieces = name.split(".")
    if "tfevents" not in pieces:
        return False
    marker = pieces.index("tfevents")

    if hostname is not None:
        host_pieces = hostname.split(".")
        first = marker + 2
        # A too-short filename yields a shorter slice, which cannot compare equal.
        if pieces[first : first + len(host_pieces)] != host_pieces:
            return False

    if start_time is not None:
        try:
            stamp = int(pieces[marker + 1])
        except (ValueError, IndexError):
            return False
        # Files stamped before this run started belong to someone else.
        if stamp < int(start_time):
            return False

    return True
|
|
|
|
|
class TBWatcher:
    """Watch tensorboard log directories and forward their events to a run.

    Every directory passed to ``add`` gets its own ``TBDirWatcher`` thread. All
    of them put events on one shared priority queue, which a single
    ``TBEventConsumer`` (created on the first ``add``) drains.
    """

    _logdirs: "Dict[str, TBDirWatcher]"
    _watcher_queue: "PriorityQueue"

    def __init__(
        self,
        settings: "SettingsStatic",
        run_proto: "RunRecord",
        interface: "InterfaceQueue",
        force: bool = False,
    ) -> None:
        """Create the watcher; no threads are started until ``add`` is called.

        Args:
            settings: Run settings shared with the directory watchers and consumer.
            run_proto: Run record handed to the event consumer.
            interface: Interface used to publish files, config and history.
            force: If True, watch every tfevents file regardless of the hostname
                and creation time encoded in its filename.
        """
        self._logdirs = {}
        self._consumer: Optional[TBEventConsumer] = None
        self._settings = settings
        self._interface = interface
        self._run_proto = run_proto
        self._force = force

        self._watcher_queue = queue.PriorityQueue()
        wandb.tensorboard.reset_state()

    @staticmethod
    def _strip_logdir(logdir: str, filename: str, rootdir: str) -> str:
        """Return ``logdir`` without its trailing ``filename`` and leading ``rootdir``.

        Paths are in forward-slash form; surrounding "/" are stripped from the
        result. Only a true suffix/prefix is removed, so a root or file name that
        reappears deeper in the path is left intact.
        """
        relative = logdir
        if filename and relative.endswith(filename):
            relative = relative[: -len(filename)]
        if relative.startswith(rootdir):
            relative = relative[len(rootdir) :]
        else:
            # rootdir is not a prefix (e.g. a relative root for an absolute
            # logdir): keep the historical substring removal for that case.
            relative = relative.replace(rootdir, "")
        return relative.strip("/")

    def _calculate_namespace(self, logdir: str, rootdir: str) -> Optional[str]:
        """Return the namespace used to prefix metrics coming from ``logdir``.

        The namespace is ``logdir`` relative to ``rootdir``, without a trailing
        tfevents filename (when ``logdir`` is a file) and without surrounding
        slashes.

        If ``rootdir`` is empty, it is guessed as the parent of the common prefix
        of all watched directories. In that case, while only one directory is
        watched, namespaces other than "train"/"validation" are discarded (None).
        Presumably this is because the real root is ambiguous with one directory.
        """
        namespace: Optional[str]
        dirs = list(self._logdirs) + [logdir]
        filename = os.path.basename(logdir) if os.path.isfile(logdir) else ""

        if rootdir == "":
            rootdir = util.to_forward_slash_path(
                os.path.dirname(os.path.commonprefix(dirs))
            )
            namespace = self._strip_logdir(logdir, filename, rootdir)
            # TODO: revisit this heuristic for the single-directory case.
            if len(dirs) == 1 and namespace not in ["train", "validation"]:
                namespace = None
        else:
            namespace = self._strip_logdir(logdir, filename, rootdir)

        return namespace

    def add(self, logdir: str, save: bool, root_dir: str) -> None:
        """Start watching ``logdir`` unless it is already being watched.

        Args:
            logdir: Tensorboard log directory (or tfevents file) to watch.
            save: If True, tfevent files found there are also saved to the run.
            root_dir: Root used to compute the namespace; "" to infer it.
        """
        logdir = util.to_forward_slash_path(logdir)
        root_dir = util.to_forward_slash_path(root_dir)
        if logdir in self._logdirs:
            return
        namespace = self._calculate_namespace(logdir, root_dir)

        # The consumer is shared by all directory watchers; create it lazily.
        if not self._consumer:
            self._consumer = TBEventConsumer(
                self, self._watcher_queue, self._run_proto, self._settings
            )
            self._consumer.start()

        tbdir_watcher = TBDirWatcher(
            self, logdir, save, namespace, self._watcher_queue, self._force
        )
        self._logdirs[logdir] = tbdir_watcher
        tbdir_watcher.start()

    def finish(self) -> None:
        """Stop all directory watchers, then the consumer."""
        # Request shutdown of every watcher before joining any of them, so their
        # post-shutdown polling delays overlap instead of running back to back.
        for tbdirwatcher in self._logdirs.values():
            tbdirwatcher.shutdown()
        for tbdirwatcher in self._logdirs.values():
            tbdirwatcher.finish()
        if self._consumer:
            self._consumer.finish()
|
|
|
|
|
class TBDirWatcher:
    """Watch one tensorboard log directory and enqueue the events found there.

    A background thread polls the directory about once a second through
    tensorboard's ``DirectoryWatcher``. Every event that carries a ``summary`` is
    put on the shared priority queue, wrapped in ``Event`` together with this
    watcher's namespace.
    """

    def __init__(
        self,
        tbwatcher: "TBWatcher",
        logdir: str,
        save: bool,
        namespace: Optional[str],
        queue: "PriorityQueue",
        force: bool = False,
    ) -> None:
        """Set up (but do not start) the polling thread for ``logdir``.

        Args:
            tbwatcher: Parent watcher; supplies the interface and settings.
            logdir: Directory (forward-slash path) to watch.
            save: If True, each tfevent file is symlinked into the run's files
                directory and published when it is opened.
            namespace: Namespace attached to every event from this directory.
            queue: Shared priority queue that ``Event`` wrappers are put on.
            force: If True, accept every tfevents file regardless of the hostname
                and start time encoded in its filename.
        """
        self.directory_watcher = util.get_module(
            "tensorboard.backend.event_processing.directory_watcher",
            required="Please install tensorboard package",
        )
        self.tf_compat = util.get_module(
            "tensorboard.compat", required="Please install tensorboard package"
        )
        self._tbwatcher = tbwatcher
        # Arguments: the directory, a loader factory (the class built by _loader)
        # and a filter deciding which paths are loaded.
        self._generator = self.directory_watcher.DirectoryWatcher(
            logdir, self._loader(save, namespace), self._is_our_tfevents_file
        )
        self._thread = threading.Thread(target=self._thread_except_body)
        self._first_event_timestamp = None
        self._shutdown = threading.Event()
        self._queue = queue
        self._file_version = None
        self._namespace = namespace
        self._logdir = logdir
        self._hostname = socket.gethostname()
        self._force = force
        # Serializes _process_events between the polling thread and shutdown(),
        # which loads events from the caller's thread.
        self._process_events_lock = threading.Lock()

    def start(self) -> None:
        """Start the background polling thread."""
        self._thread.start()

    def _is_our_tfevents_file(self, path: str) -> bool:
        """Return whether ``path`` is a tfevents file this watcher should load.

        Unless ``force`` was set, the filename must also show that the file was
        created on this host no earlier than the run's start time.

        Raises:
            ValueError: if ``path`` is empty.
        """
        if not path:
            raise ValueError("Path must be a nonempty string")
        path = self.tf_compat.tf.compat.as_str_any(path)
        if self._force:
            return is_tfevents_file_created_by(path, None, None)
        else:
            return is_tfevents_file_created_by(
                path, self._hostname, self._tbwatcher._settings.x_start_time
            )

    def _loader(
        self, save: bool = True, namespace: Optional[str] = None
    ) -> "EventFileLoader":
        """Build an EventFileLoader subclass that can save the files it opens.

        Returns the class itself (not an instance). When ``save`` is True, each
        local tfevent file opened by the class is symlinked into the run's files
        directory and published. Remote paths (containing "://") are only
        logged and skipped.

        Raises:
            Exception: if tensorboard is not installed.
        """
        _loader_interface = self._tbwatcher._interface
        _loader_settings = self._tbwatcher._settings
        try:
            from tensorboard.backend.event_processing import event_file_loader
        except ImportError:
            raise Exception("Please install tensorboard package")

        class EventFileLoader(event_file_loader.EventFileLoader):
            def __init__(self, file_path: str) -> None:
                super().__init__(file_path)
                if save:
                    if REMOTE_FILE_TOKEN in file_path:
                        logger.warning(
                            "Not persisting remote tfevent file: %s", file_path
                        )
                    else:
                        logdir = os.path.dirname(file_path)
                        parts = list(os.path.split(logdir))
                        # If the file sits in a directory named after the
                        # namespace, link relative to its parent so the saved
                        # path keeps the namespace subdirectory.
                        if namespace and parts[-1] == namespace:
                            parts.pop()
                            logdir = os.path.join(*parts)
                        _link_and_save_file(
                            path=file_path,
                            base_path=logdir,
                            interface=_loader_interface,
                            settings=_loader_settings,
                        )

        return EventFileLoader

    def _process_events(self, shutdown_call: bool = False) -> None:
        """Load and handle all events currently available in the directory.

        Watcher errors (e.g. the directory being deleted) are logged at debug
        level. Unless a shutdown is in progress, polling then backs off for
        ERROR_DELAY seconds.
        """
        try:
            with self._process_events_lock:
                for event in self._generator.Load():
                    self.process_event(event)
        except (
            self.directory_watcher.DirectoryDeletedError,
            StopIteration,
            RuntimeError,
            OSError,
        ) as e:
            logger.debug("Encountered tensorboard directory watcher error: %s", e)
            if not self._shutdown.is_set() and not shutdown_call:
                time.sleep(ERROR_DELAY)

    def _thread_except_body(self) -> None:
        """Thread target: run _thread_body, logging any exception and re-raising."""
        try:
            self._thread_body()
        except Exception:
            logger.exception("generic exception in TBDirWatcher thread")
            raise

    def _thread_body(self) -> None:
        """Poll for new events about once a second.

        Once shutdown is requested, polling continues for SHUTDOWN_DELAY more
        seconds before the loop exits.
        """
        shutdown_time: Optional[float] = None
        while True:
            self._process_events()
            if self._shutdown.is_set():
                now = time.time()
                if not shutdown_time:
                    shutdown_time = now + SHUTDOWN_DELAY
                elif now > shutdown_time:
                    break
            time.sleep(1)

    def process_event(self, event: "ProtoEvent") -> None:
        """Record bookkeeping fields of ``event`` and enqueue it if it has a summary."""
        if self._first_event_timestamp is None:
            self._first_event_timestamp = event.wall_time

        if event.HasField("file_version"):
            self._file_version = event.file_version

        if event.HasField("summary"):
            self._queue.put(Event(event, self._namespace))

    def shutdown(self) -> None:
        """Load pending events from the calling thread, then signal the poller to stop."""
        self._process_events(shutdown_call=True)
        self._shutdown.set()

    def finish(self) -> None:
        """Shut down and wait for the polling thread to exit."""
        self.shutdown()
        self._thread.join()
|
|
|
|
|
class Event:
    """Wrapper that makes tfevents orderable by wall time in a PriorityQueue."""

    def __init__(self, event: "ProtoEvent", namespace: Optional[str]):
        # The raw proto and the namespace its metrics will be logged under.
        self.event = event
        self.namespace = namespace
        # When this wrapper was built (local clock), not the event's own time.
        self.created_at = time.time()

    def __lt__(self, other: "Event") -> bool:
        """Compare by the ``wall_time`` of the wrapped protos."""
        mine, theirs = self.event.wall_time, other.event.wall_time
        return bool(mine < theirs)
|
|
|
|
|
class TBEventConsumer:
    """Log tfevents taken from the shared priority queue into the run's history.

    Exactly one consumer should exist per ``TBWatcher``. For the first ``delay``
    seconds after ``start`` (10 by default), events are only buffered. This gives
    the priority queue a chance to order events from several tfevent files by
    wall time before any step is logged, reducing out-of-order steps.
    """

    def __init__(
        self,
        tbwatcher: "TBWatcher",
        queue: "PriorityQueue",
        run_proto: "RunRecord",
        settings: "SettingsStatic",
        delay: int = 10,
    ) -> None:
        self._tbwatcher = tbwatcher
        self._queue = queue
        self._thread = threading.Thread(target=self._thread_except_body)
        self._shutdown = threading.Event()
        self.tb_history = TBHistory()
        self._delay = delay

        # There is no user-process run object here, so supply the callback that
        # InternalRun uses to report files. Each reported file is published with
        # the "now" policy.
        def datatypes_cb(fname: GlobStr) -> None:
            files: FilesDict = dict(files=[(fname, "now")])
            self._tbwatcher._interface.publish_files(files)

        self._internal_run = internal_run.InternalRun(run_proto, settings, datatypes_cb)
        self._internal_run._set_internal_run_interface(self._tbwatcher._interface)

    def start(self) -> None:
        """Record the start time (used for the initial delay) and start the thread."""
        self._start_time = time.time()
        self._thread.start()

    def finish(self) -> None:
        """Stop the consumer thread, then log anything still left in the queue."""
        self._delay = 0
        self._shutdown.set()
        self._thread.join()
        drained = False
        while not self._queue.empty():
            event = self._queue.get(True, 1)
            if event:
                drained = True
                self._handle_event(event, history=self.tb_history)
                self._publish_finished_rows()
        if drained:
            # The row of the last drained step is still open and nothing will
            # flush it after this point, so close and publish it here.
            self.tb_history._flush()
            self._publish_finished_rows()

    def _thread_except_body(self) -> None:
        """Thread target: run _thread_body, logging any exception and re-raising."""
        try:
            self._thread_body()
        except Exception:
            logger.exception("generic exception in TBEventConsumer thread")
            raise

    def _thread_body(self) -> None:
        """Log queued events until shutdown is requested and the queue is empty."""
        while True:
            try:
                event = self._queue.get(True, 1)
                # Hold events back until the start-up delay has elapsed; the
                # delay is skipped once shutdown has been requested.
                if (
                    time.time() < self._start_time + self._delay
                    and not self._shutdown.is_set()
                ):
                    self._queue.put(event)
                    time.sleep(0.1)
                    continue
            except queue.Empty:
                event = None
                # Exit only when the queue is empty, so nothing queued is dropped.
                if self._shutdown.is_set():
                    break
            if event:
                self._handle_event(event, history=self.tb_history)
                self._publish_finished_rows()

        # Close and publish the row of the last step, which is still open.
        self.tb_history._flush()
        self._publish_finished_rows()

    def _publish_finished_rows(self) -> None:
        """Publish every row the history has closed since the last call."""
        for row in self.tb_history._get_and_reset():
            self._save_row(row)

    def _handle_event(
        self, event: "Event", history: Optional["TBHistory"] = None
    ) -> None:
        """Log one queued ``Event`` wrapper (proto plus namespace) into ``history``."""
        wandb.tensorboard._log(
            event.event,
            step=event.event.step,
            namespace=event.namespace,
            history=history,
        )

    def _save_row(self, row: "HistoryDict") -> None:
        """Publish one history row, turning custom charts into config + tables.

        For every CustomChart value, the chart's config is published and the chart
        is replaced in ``row`` by its table under ``spec.table_key``.
        """
        chart_keys = set()
        for k, v in row.items():
            if isinstance(v, CustomChart):
                chart_keys.add(k)
                v.set_key(k)
                self._tbwatcher._interface.publish_config(
                    key=v.spec.config_key,
                    val=v.spec.config_value,
                )

        # Mutate the row only after iterating over it.
        for k in chart_keys:
            chart = row.pop(k)
            if isinstance(chart, CustomChart):
                row[chart.spec.table_key] = chart.table

        self._tbwatcher._interface.publish_history(
            self._internal_run,
            row,
            publish_step=False,
        )
|
|
|
|
|
class TBHistory:
    """Group logged tensorboard values into history rows, one row per step.

    Values for the step in progress collect in ``_data``. ``_flush`` closes that
    row, stamps it with ``_step`` and queues it in ``_added``. The consumer then
    takes closed rows from there with ``_get_and_reset``.
    """

    _data: "HistoryDict"
    _added: "List[HistoryDict]"

    def __init__(self) -> None:
        self._step = 0
        # Running total of sys.getsizeof() over values added to the open row. It
        # is a shallow estimate, used only to detect oversized steps.
        self._step_size = 0
        self._data = dict()
        self._added = []

    def _flush(self) -> None:
        """Close the open row, if it holds any data, and start a new empty one.

        If the row's estimated size exceeds ``util.MAX_LINE_BYTES``, its largest
        values are dropped first (with a warning).
        """
        if not self._data:
            return
        if self._step_size > util.MAX_LINE_BYTES:
            self._drop_largest_keys()
        self._data["_step"] = self._step
        self._added.append(self._data)
        # Start a fresh row. The closed row now belongs to ``_added``: it must not
        # be mutated by later updates, nor appended again (under a new ``_step``)
        # by a second flush.
        self._data = dict()
        self._step += 1
        self._step_size = 0

    def _drop_largest_keys(self) -> None:
        """Remove the largest values from the open row until it fits the limit."""
        metrics = [(k, sys.getsizeof(v)) for k, v in self._data.items()]
        metrics.sort(key=lambda t: t[1], reverse=True)
        bad = 0
        dropped_keys = []
        for k, v in metrics:
            # Stop once the row is 100000 bytes below the limit, as a safety margin.
            if self._step_size - bad < util.MAX_LINE_BYTES - 100000:
                break
            bad += v
            dropped_keys.append(k)
            del self._data[k]
        wandb.termwarn(
            f"Step {self._step} exceeds max data limit, dropping {len(dropped_keys)} of the largest keys:"
        )
        print("\t" + ("\n\t".join(dropped_keys)))

    def add(self, d: "HistoryDict") -> None:
        """Close the current row and start a new one holding the values of ``d``."""
        self._flush()
        self._data = self._track_history_dict(d)

    def _track_history_dict(self, d: "HistoryDict") -> "HistoryDict":
        """Return a shallow copy of ``d``, adding its value sizes to ``_step_size``."""
        e = {}
        for k in d.keys():
            e[k] = d[k]
            self._step_size += sys.getsizeof(e[k])
        return e

    def _row_update(self, d: "HistoryDict") -> None:
        """Merge the values of ``d`` into the open row."""
        self._data.update(self._track_history_dict(d))

    def _get_and_reset(self) -> "List[HistoryDict]":
        """Return all closed rows and clear the internal list."""
        added, self._added = self._added, []
        return added
|
|