"""monkeypatch: patch code to add tensorboard hooks."""
import os
import re
import socket
from typing import Any, Optional
import wandb
import wandb.util
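# Import paths of the summary/event writer implementations this module knows how to patch.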
TENSORBOARD_C_MODULE = "tensorflow.python.ops.gen_summary_ops"
TENSORBOARD_X_MODULE = "tensorboardX.writer"
TENSORFLOW_PY_MODULE = "tensorflow.python.summary.writer.writer"
TENSORBOARD_WRITER_MODULE = "tensorboard.summary.writer.event_file_writer"
TENSORBOARD_PYTORCH_MODULE = "torch.utils.tensorboard.writer"
def unpatch() -> None:
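    """Remove the tensorboard hooks installed by `patch`.

    Each patched writer method is restored from its saved `orig_*` attribute and
    the bookkeeping list in `wandb.patched["tensorboard"]` is cleared.
    """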
for module, method in wandb.patched["tensorboard"]:
writer = wandb.util.get_module(module, lazy=False)
setattr(writer, method, getattr(writer, f"orig_{method}"))
wandb.patched["tensorboard"] = []
def patch(
save: bool = True,
tensorboard_x: Optional[bool] = None,
pytorch: Optional[bool] = None,
root_logdir: str = "",
) -> None:
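    """Patch the known tensorboard writer implementations to report logdirs to wandb.

    Args:
        save: Passed through to the run's tensorboard callback; controls whether
            the generated tfevent files are saved to wandb.
        tensorboard_x: If truthy, skip patching the TensorFlow 2 C-level writer and
            rely on the tensorboardX writer patch instead.
        pytorch: If truthy, skip patching the TensorFlow 2 C-level writer and rely
            on the torch.utils.tensorboard writer patch instead.
        root_logdir: Root of the event log directories; used to compute relative
            paths when several subdirectories are logged to.
    """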
if len(wandb.patched["tensorboard"]) > 0:
raise ValueError(
"Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
"remove `sync_tensorboard=True` from `wandb.init`; "
"or only call `wandb.tensorboard.patch` once."
)
    # TODO: Some older versions of tensorflow don't require tensorboard to be present.
    # We may want to lift this requirement, but it's safer to keep it for now.
wandb.util.get_module(
"tensorboard", required="Please install tensorboard package", lazy=False
)
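    # Resolve each known writer module; `get_module` returns None when the
    # corresponding package is not installed.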
c_writer = wandb.util.get_module(TENSORBOARD_C_MODULE, lazy=False)
py_writer = wandb.util.get_module(TENSORFLOW_PY_MODULE, lazy=False)
tb_writer = wandb.util.get_module(TENSORBOARD_WRITER_MODULE, lazy=False)
pt_writer = wandb.util.get_module(TENSORBOARD_PYTORCH_MODULE, lazy=False)
tbx_writer = wandb.util.get_module(TENSORBOARD_X_MODULE, lazy=False)
if not pytorch and not tensorboard_x and c_writer:
_patch_tensorflow2(
writer=c_writer,
module=TENSORBOARD_C_MODULE,
save=save,
root_logdir=root_logdir,
)
# This is for tensorflow <= 1.15 (tf.compat.v1.summary.FileWriter)
if py_writer:
_patch_file_writer(
writer=py_writer,
module=TENSORFLOW_PY_MODULE,
save=save,
root_logdir=root_logdir,
)
if tb_writer:
_patch_file_writer(
writer=tb_writer,
module=TENSORBOARD_WRITER_MODULE,
save=save,
root_logdir=root_logdir,
)
if pt_writer:
_patch_file_writer(
writer=pt_writer,
module=TENSORBOARD_PYTORCH_MODULE,
save=save,
root_logdir=root_logdir,
)
if tbx_writer:
_patch_file_writer(
writer=tbx_writer,
module=TENSORBOARD_X_MODULE,
save=save,
root_logdir=root_logdir,
)
    if not any((c_writer, py_writer, tb_writer, pt_writer, tbx_writer)):
wandb.termerror("Unsupported tensorboard configuration")
def _patch_tensorflow2(
writer: Any,
module: Any,
save: bool = True,
root_logdir: str = "",
) -> None:
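    """Wrap `create_summary_file_writer` so each new logdir is reported to wandb."""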
# This configures TensorFlow 2 style Tensorboard logging
old_csfw_func = writer.create_summary_file_writer
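    # Record every logdir seen so we can warn when multiple directories are used
    # without an explicit root_logdir.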
logdir_hist = []
def new_csfw_func(*args: Any, **kwargs: Any) -> Any:
logdir = (
kwargs["logdir"].numpy().decode("utf8")
if hasattr(kwargs["logdir"], "numpy")
else kwargs["logdir"]
)
logdir_hist.append(logdir)
root_logdir_arg = root_logdir
if len(set(logdir_hist)) > 1 and root_logdir == "":
wandb.termwarn(
"When using several event log directories, "
'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
)
        # If the logdir contains the hostname, the writer was not given an explicit
        # logdir; the auto-generated logdir ends with the hostname, so update
        # root_logdir to match.
hostname = socket.gethostname()
search = re.search(rf"-\d+_{hostname}", logdir)
if search:
root_logdir_arg = logdir[: search.span()[1]]
elif root_logdir is not None and not os.path.abspath(logdir).startswith(
os.path.abspath(root_logdir)
):
wandb.termwarn(
"Found log directory outside of given root_logdir, "
f"dropping given root_logdir for event file in {logdir}"
)
root_logdir_arg = ""
_notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
return old_csfw_func(*args, **kwargs)
writer.orig_create_summary_file_writer = old_csfw_func
writer.create_summary_file_writer = new_csfw_func
wandb.patched["tensorboard"].append([module, "create_summary_file_writer"])
def _patch_file_writer(
writer: Any,
module: Any,
save: bool = True,
root_logdir: str = "",
) -> None:
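    """Replace `EventFileWriter` with a subclass that reports each logdir to wandb."""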
# This configures non-TensorFlow Tensorboard logging, or tensorflow <= 1.15
logdir_hist = []
class TBXEventFileWriter(writer.EventFileWriter):
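        """EventFileWriter that notifies wandb of its logdir before delegating to the original."""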
def __init__(self, logdir: str, *args: Any, **kwargs: Any) -> None:
logdir_hist.append(logdir)
root_logdir_arg = root_logdir
if len(set(logdir_hist)) > 1 and root_logdir == "":
wandb.termwarn(
"When using several event log directories, "
'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
)
            # If the logdir contains the hostname, the writer was not given an explicit
            # logdir; the auto-generated logdir ends with the hostname, so update
            # root_logdir to match.
hostname = socket.gethostname()
search = re.search(rf"-\d+_{hostname}", logdir)
if search:
root_logdir_arg = logdir[: search.span()[1]]
elif root_logdir is not None and not os.path.abspath(logdir).startswith(
os.path.abspath(root_logdir)
):
wandb.termwarn(
"Found log directory outside of given root_logdir, "
f"dropping given root_logdir for event file in {logdir}"
)
root_logdir_arg = ""
_notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
super().__init__(logdir, *args, **kwargs)
writer.orig_EventFileWriter = writer.EventFileWriter
writer.EventFileWriter = TBXEventFileWriter
wandb.patched["tensorboard"].append([module, "EventFileWriter"])
def _notify_tensorboard_logdir(
logdir: str, save: bool = True, root_logdir: str = ""
) -> None:
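    """Forward a discovered logdir to the active run's tensorboard callback, if any."""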
if wandb.run is not None:
wandb.run._tensorboard_callback(logdir, save=save, root_logdir=root_logdir)