"""monkeypatch: patch code to add tensorboard hooks.""" import os import re import socket from typing import Any, Optional import wandb import wandb.util TENSORBOARD_C_MODULE = "tensorflow.python.ops.gen_summary_ops" TENSORBOARD_X_MODULE = "tensorboardX.writer" TENSORFLOW_PY_MODULE = "tensorflow.python.summary.writer.writer" TENSORBOARD_WRITER_MODULE = "tensorboard.summary.writer.event_file_writer" TENSORBOARD_PYTORCH_MODULE = "torch.utils.tensorboard.writer" def unpatch() -> None: for module, method in wandb.patched["tensorboard"]: writer = wandb.util.get_module(module, lazy=False) setattr(writer, method, getattr(writer, f"orig_{method}")) wandb.patched["tensorboard"] = [] def patch( save: bool = True, tensorboard_x: Optional[bool] = None, pytorch: Optional[bool] = None, root_logdir: str = "", ) -> None: if len(wandb.patched["tensorboard"]) > 0: raise ValueError( "Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; " "remove `sync_tensorboard=True` from `wandb.init`; " "or only call `wandb.tensorboard.patch` once." ) # TODO: Some older versions of tensorflow don't require tensorboard to be present. # we may want to lift this requirement, but it's safer to have it for now wandb.util.get_module( "tensorboard", required="Please install tensorboard package", lazy=False ) c_writer = wandb.util.get_module(TENSORBOARD_C_MODULE, lazy=False) py_writer = wandb.util.get_module(TENSORFLOW_PY_MODULE, lazy=False) tb_writer = wandb.util.get_module(TENSORBOARD_WRITER_MODULE, lazy=False) pt_writer = wandb.util.get_module(TENSORBOARD_PYTORCH_MODULE, lazy=False) tbx_writer = wandb.util.get_module(TENSORBOARD_X_MODULE, lazy=False) if not pytorch and not tensorboard_x and c_writer: _patch_tensorflow2( writer=c_writer, module=TENSORBOARD_C_MODULE, save=save, root_logdir=root_logdir, ) # This is for tensorflow <= 1.15 (tf.compat.v1.summary.FileWriter) if py_writer: _patch_file_writer( writer=py_writer, module=TENSORFLOW_PY_MODULE, save=save, root_logdir=root_logdir, ) if tb_writer: _patch_file_writer( writer=tb_writer, module=TENSORBOARD_WRITER_MODULE, save=save, root_logdir=root_logdir, ) if pt_writer: _patch_file_writer( writer=pt_writer, module=TENSORBOARD_PYTORCH_MODULE, save=save, root_logdir=root_logdir, ) if tbx_writer: _patch_file_writer( writer=tbx_writer, module=TENSORBOARD_X_MODULE, save=save, root_logdir=root_logdir, ) if not c_writer and not tb_writer and not tb_writer: wandb.termerror("Unsupported tensorboard configuration") def _patch_tensorflow2( writer: Any, module: Any, save: bool = True, root_logdir: str = "", ) -> None: # This configures TensorFlow 2 style Tensorboard logging old_csfw_func = writer.create_summary_file_writer logdir_hist = [] def new_csfw_func(*args: Any, **kwargs: Any) -> Any: logdir = ( kwargs["logdir"].numpy().decode("utf8") if hasattr(kwargs["logdir"], "numpy") else kwargs["logdir"] ) logdir_hist.append(logdir) root_logdir_arg = root_logdir if len(set(logdir_hist)) > 1 and root_logdir == "": wandb.termwarn( "When using several event log directories, " 'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`' ) # if the logdir contains the hostname, the writer was not given a logdir. # In this case, the generated logdir # is generated and ends with the hostname, update the root_logdir to match. 
        hostname = socket.gethostname()
        search = re.search(rf"-\d+_{hostname}", logdir)
        if search:
            root_logdir_arg = logdir[: search.span()[1]]
        elif root_logdir is not None and not os.path.abspath(logdir).startswith(
            os.path.abspath(root_logdir)
        ):
            wandb.termwarn(
                "Found log directory outside of given root_logdir, "
                f"dropping given root_logdir for event file in {logdir}"
            )
            root_logdir_arg = ""

        _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)

        return old_csfw_func(*args, **kwargs)

    writer.orig_create_summary_file_writer = old_csfw_func
    writer.create_summary_file_writer = new_csfw_func
    wandb.patched["tensorboard"].append([module, "create_summary_file_writer"])


def _patch_file_writer(
    writer: Any,
    module: Any,
    save: bool = True,
    root_logdir: str = "",
) -> None:
    # This configures non-TensorFlow Tensorboard logging, or tensorflow <= 1.15
    logdir_hist = []

    class TBXEventFileWriter(writer.EventFileWriter):
        def __init__(self, logdir: str, *args: Any, **kwargs: Any) -> None:
            logdir_hist.append(logdir)
            root_logdir_arg = root_logdir
            if len(set(logdir_hist)) > 1 and root_logdir == "":
                wandb.termwarn(
                    "When using several event log directories, "
                    'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
                )

            # If the logdir contains the hostname, the writer was not given a
            # logdir. In this case the logdir is auto-generated and ends with
            # the hostname, so update the root_logdir to match.
            hostname = socket.gethostname()
            search = re.search(rf"-\d+_{hostname}", logdir)
            if search:
                root_logdir_arg = logdir[: search.span()[1]]
            elif root_logdir is not None and not os.path.abspath(logdir).startswith(
                os.path.abspath(root_logdir)
            ):
                wandb.termwarn(
                    "Found log directory outside of given root_logdir, "
                    f"dropping given root_logdir for event file in {logdir}"
                )
                root_logdir_arg = ""

            _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)

            super().__init__(logdir, *args, **kwargs)

    writer.orig_EventFileWriter = writer.EventFileWriter
    writer.EventFileWriter = TBXEventFileWriter
    wandb.patched["tensorboard"].append([module, "EventFileWriter"])


def _notify_tensorboard_logdir(
    logdir: str, save: bool = True, root_logdir: str = ""
) -> None:
    if wandb.run is not None:
        wandb.run._tensorboard_callback(logdir, save=save, root_logdir=root_logdir)
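

# A minimal end-to-end sketch (an illustration, not part of this module's API
# surface): normally `wandb.init(sync_tensorboard=True)` invokes `patch()` for
# you; here we exercise `patch`/`unpatch` directly against tensorboard's own
# EventFileWriter. Assumes the `tensorboard` package is installed; the logdir
# is a throwaway temp directory.
if __name__ == "__main__":
    import tempfile

    demo_root = tempfile.mkdtemp(prefix="tb_demo_")
    patch(root_logdir=demo_root)
    try:
        # Import after patching so the hooked TBXEventFileWriter is picked up;
        # with no active wandb run, _notify_tensorboard_logdir is a no-op.
        from tensorboard.summary.writer.event_file_writer import EventFileWriter

        event_writer = EventFileWriter(os.path.join(demo_root, "run1"))
        event_writer.close()
    finally:
        unpatch()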