"""monkeypatch: patch code to add tensorboard hooks."""
import os
import re
import socket
from typing import Any, Optional

import wandb
import wandb.util

TENSORBOARD_C_MODULE = "tensorflow.python.ops.gen_summary_ops"
TENSORBOARD_X_MODULE = "tensorboardX.writer"
TENSORFLOW_PY_MODULE = "tensorflow.python.summary.writer.writer"
TENSORBOARD_WRITER_MODULE = "tensorboard.summary.writer.event_file_writer"
TENSORBOARD_PYTORCH_MODULE = "torch.utils.tensorboard.writer"
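# The constants above name the modules that own a summary writer; each is
# resolved lazily via `wandb.util.get_module`, so none of these packages is a
# hard dependency.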


def unpatch() -> None:
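    # Restore each patched attribute from the `orig_<method>` copy saved by
    # the patching helpers below, then clear the bookkeeping list.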
for module, method in wandb.patched["tensorboard"]:
writer = wandb.util.get_module(module, lazy=False)
setattr(writer, method, getattr(writer, f"orig_{method}"))
wandb.patched["tensorboard"] = []


def patch(
save: bool = True,
tensorboard_x: Optional[bool] = None,
pytorch: Optional[bool] = None,
root_logdir: str = "",
) -> None:
if len(wandb.patched["tensorboard"]) > 0:
raise ValueError(
"Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
"remove `sync_tensorboard=True` from `wandb.init`; "
"or only call `wandb.tensorboard.patch` once."
)
    # TODO: Some older versions of tensorflow don't require tensorboard to be
    # present. We may want to lift this requirement, but it's safer to keep it
    # for now.
wandb.util.get_module(
"tensorboard", required="Please install tensorboard package", lazy=False
)
c_writer = wandb.util.get_module(TENSORBOARD_C_MODULE, lazy=False)
py_writer = wandb.util.get_module(TENSORFLOW_PY_MODULE, lazy=False)
tb_writer = wandb.util.get_module(TENSORBOARD_WRITER_MODULE, lazy=False)
pt_writer = wandb.util.get_module(TENSORBOARD_PYTORCH_MODULE, lazy=False)
tbx_writer = wandb.util.get_module(TENSORBOARD_X_MODULE, lazy=False)
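    # Each lookup above returns None when the package is not importable, so
    # the branches below only patch the writers that are actually present.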
if not pytorch and not tensorboard_x and c_writer:
_patch_tensorflow2(
writer=c_writer,
module=TENSORBOARD_C_MODULE,
save=save,
root_logdir=root_logdir,
)
# This is for tensorflow <= 1.15 (tf.compat.v1.summary.FileWriter)
if py_writer:
_patch_file_writer(
writer=py_writer,
module=TENSORFLOW_PY_MODULE,
save=save,
root_logdir=root_logdir,
)
if tb_writer:
_patch_file_writer(
writer=tb_writer,
module=TENSORBOARD_WRITER_MODULE,
save=save,
root_logdir=root_logdir,
)
if pt_writer:
_patch_file_writer(
writer=pt_writer,
module=TENSORBOARD_PYTORCH_MODULE,
save=save,
root_logdir=root_logdir,
)
if tbx_writer:
_patch_file_writer(
writer=tbx_writer,
module=TENSORBOARD_X_MODULE,
save=save,
root_logdir=root_logdir,
)
    if not c_writer and not tb_writer and not tbx_writer:
        wandb.termerror("Unsupported tensorboard configuration")


def _patch_tensorflow2(
writer: Any,
module: Any,
save: bool = True,
root_logdir: str = "",
) -> None:
# This configures TensorFlow 2 style Tensorboard logging
old_csfw_func = writer.create_summary_file_writer
logdir_hist = []
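
    # Record every logdir the writer is created with so we can detect (and
    # warn about) runs that write to several event log directories.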
def new_csfw_func(*args: Any, **kwargs: Any) -> Any:
logdir = (
kwargs["logdir"].numpy().decode("utf8")
if hasattr(kwargs["logdir"], "numpy")
else kwargs["logdir"]
)
logdir_hist.append(logdir)
root_logdir_arg = root_logdir
if len(set(logdir_hist)) > 1 and root_logdir == "":
wandb.termwarn(
"When using several event log directories, "
'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
)
        # If the logdir contains the hostname, the writer was not given an
        # explicit logdir. In that case the logdir is auto-generated and ends
        # with the hostname, so update root_logdir to match.
hostname = socket.gethostname()
search = re.search(rf"-\d+_{hostname}", logdir)
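        # A match means the logdir ends in "-<digits>_<hostname>" and was
        # auto-generated; keep everything up to the end of that suffix as the
        # effective root logdir.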
if search:
root_logdir_arg = logdir[: search.span()[1]]
elif root_logdir is not None and not os.path.abspath(logdir).startswith(
os.path.abspath(root_logdir)
):
wandb.termwarn(
"Found log directory outside of given root_logdir, "
f"dropping given root_logdir for event file in {logdir}"
)
root_logdir_arg = ""
_notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
return old_csfw_func(*args, **kwargs)

    writer.orig_create_summary_file_writer = old_csfw_func
writer.create_summary_file_writer = new_csfw_func
wandb.patched["tensorboard"].append([module, "create_summary_file_writer"])


def _patch_file_writer(
writer: Any,
module: Any,
save: bool = True,
root_logdir: str = "",
) -> None:
# This configures non-TensorFlow Tensorboard logging, or tensorflow <= 1.15
logdir_hist = []
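
    # Subclass the module's EventFileWriter so the logdir can be observed at
    # construction time before delegating to the original implementation.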
class TBXEventFileWriter(writer.EventFileWriter):
def __init__(self, logdir: str, *args: Any, **kwargs: Any) -> None:
logdir_hist.append(logdir)
root_logdir_arg = root_logdir
if len(set(logdir_hist)) > 1 and root_logdir == "":
wandb.termwarn(
"When using several event log directories, "
'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
)
            # If the logdir contains the hostname, the writer was not given an
            # explicit logdir. In that case the logdir is auto-generated and
            # ends with the hostname, so update root_logdir to match.
hostname = socket.gethostname()
search = re.search(rf"-\d+_{hostname}", logdir)
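            # For example, torch.utils.tensorboard's default logdir has the
            # form "runs/May04_22-14-54_<hostname>", whose trailing
            # "-54_<hostname>" matches this pattern.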
if search:
root_logdir_arg = logdir[: search.span()[1]]
elif root_logdir is not None and not os.path.abspath(logdir).startswith(
os.path.abspath(root_logdir)
):
wandb.termwarn(
"Found log directory outside of given root_logdir, "
f"dropping given root_logdir for event file in {logdir}"
)
root_logdir_arg = ""
_notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
super().__init__(logdir, *args, **kwargs)
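
    # Keep a handle to the original class so `unpatch` can restore it later.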
writer.orig_EventFileWriter = writer.EventFileWriter
writer.EventFileWriter = TBXEventFileWriter
wandb.patched["tensorboard"].append([module, "EventFileWriter"])


def _notify_tensorboard_logdir(
logdir: str, save: bool = True, root_logdir: str = ""
) -> None:
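    # Hand the discovered logdir to the active run, if any; the run's callback
    # then begins syncing event files written under that directory.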
if wandb.run is not None:
wandb.run._tensorboard_callback(logdir, save=save, root_logdir=root_logdir)