"""monkeypatch: patch code to add tensorboard hooks."""

import os
import re
import socket
from typing import Any, Optional

import wandb
import wandb.util

TENSORBOARD_C_MODULE = "tensorflow.python.ops.gen_summary_ops"
TENSORBOARD_X_MODULE = "tensorboardX.writer"
TENSORFLOW_PY_MODULE = "tensorflow.python.summary.writer.writer"
TENSORBOARD_WRITER_MODULE = "tensorboard.summary.writer.event_file_writer"
TENSORBOARD_PYTORCH_MODULE = "torch.utils.tensorboard.writer"
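
# Typical usage -- a sketch, not a prescription. `wandb.tensorboard.patch`,
# `wandb.tensorboard.unpatch`, and `wandb.init(sync_tensorboard=True)` are real
# wandb entry points; the project name and logdir below are placeholders:
#
#   import wandb
#
#   wandb.tensorboard.patch(root_logdir="logs")  # hook writers before they are created
#   wandb.init(project="my-project")
#   # ... create tf.summary writers / SummaryWriter objects as usual ...
#   wandb.tensorboard.unpatch()  # restore the original writer classes/functions
#
# Passing `sync_tensorboard=True` to `wandb.init` applies this patch automatically.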


def unpatch() -> None:
    for module, method in wandb.patched["tensorboard"]:
        writer = wandb.util.get_module(module, lazy=False)
        setattr(writer, method, getattr(writer, f"orig_{method}"))
    wandb.patched["tensorboard"] = []


def patch(
    save: bool = True,
    tensorboard_x: Optional[bool] = None,
    pytorch: Optional[bool] = None,
    root_logdir: str = "",
) -> None:
    if len(wandb.patched["tensorboard"]) > 0:
        raise ValueError(
            "Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; "
            "remove `sync_tensorboard=True` from `wandb.init`; "
            "or only call `wandb.tensorboard.patch` once."
        )

    # TODO: Some older versions of tensorflow don't require tensorboard to be present.
    # We may want to lift this requirement, but it's safer to keep it for now.
    wandb.util.get_module(
        "tensorboard", required="Please install tensorboard package", lazy=False
    )
    c_writer = wandb.util.get_module(TENSORBOARD_C_MODULE, lazy=False)
    py_writer = wandb.util.get_module(TENSORFLOW_PY_MODULE, lazy=False)
    tb_writer = wandb.util.get_module(TENSORBOARD_WRITER_MODULE, lazy=False)
    pt_writer = wandb.util.get_module(TENSORBOARD_PYTORCH_MODULE, lazy=False)
    tbx_writer = wandb.util.get_module(TENSORBOARD_X_MODULE, lazy=False)

    if not pytorch and not tensorboard_x and c_writer:
        _patch_tensorflow2(
            writer=c_writer,
            module=TENSORBOARD_C_MODULE,
            save=save,
            root_logdir=root_logdir,
        )
    # This is for tensorflow <= 1.15 (tf.compat.v1.summary.FileWriter)
    if py_writer:
        _patch_file_writer(
            writer=py_writer,
            module=TENSORFLOW_PY_MODULE,
            save=save,
            root_logdir=root_logdir,
        )
    if tb_writer:
        _patch_file_writer(
            writer=tb_writer,
            module=TENSORBOARD_WRITER_MODULE,
            save=save,
            root_logdir=root_logdir,
        )
    if pt_writer:
        _patch_file_writer(
            writer=pt_writer,
            module=TENSORBOARD_PYTORCH_MODULE,
            save=save,
            root_logdir=root_logdir,
        )
    if tbx_writer:
        _patch_file_writer(
            writer=tbx_writer,
            module=TENSORBOARD_X_MODULE,
            save=save,
            root_logdir=root_logdir,
        )
    if not any([c_writer, py_writer, tb_writer, pt_writer, tbx_writer]):
        wandb.termerror("Unsupported tensorboard configuration")
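
# Note on `root_logdir` (illustrative; based on the documented behavior of
# `wandb.tensorboard.patch`): when event files end up in several subdirectories
# of one root, e.g.
#
#   logs/train/events.out.tfevents.*
#   logs/validation/events.out.tfevents.*
#
# calling `patch(root_logdir="logs")` lets wandb namespace the synced metrics by
# subdirectory ("train/...", "validation/...") instead of triggering the
# multiple-log-directory warning emitted by the patched writers below.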


def _patch_tensorflow2(
    writer: Any,
    module: Any,
    save: bool = True,
    root_logdir: str = "",
) -> None:
    # This configures TensorFlow 2 style Tensorboard logging
    old_csfw_func = writer.create_summary_file_writer
    logdir_hist = []

    def new_csfw_func(*args: Any, **kwargs: Any) -> Any:
        logdir = (
            kwargs["logdir"].numpy().decode("utf8")
            if hasattr(kwargs["logdir"], "numpy")
            else kwargs["logdir"]
        )
        logdir_hist.append(logdir)
        root_logdir_arg = root_logdir

        if len(set(logdir_hist)) > 1 and root_logdir == "":
            wandb.termwarn(
                "When using several event log directories, "
                'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
            )
        # If the logdir contains the hostname, the writer was not given an explicit
        # logdir. In that case the logdir was auto-generated and ends with the
        # hostname, so update root_logdir to match.
        hostname = socket.gethostname()
        search = re.search(rf"-\d+_{hostname}", logdir)
        if search:
            root_logdir_arg = logdir[: search.span()[1]]
        elif root_logdir is not None and not os.path.abspath(logdir).startswith(
            os.path.abspath(root_logdir)
        ):
            wandb.termwarn(
                "Found log directory outside of given root_logdir, "
                f"dropping given root_logdir for event file in {logdir}"
            )
            root_logdir_arg = ""

        _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
        return old_csfw_func(*args, **kwargs)

    writer.orig_create_summary_file_writer = old_csfw_func
    writer.create_summary_file_writer = new_csfw_func
    wandb.patched["tensorboard"].append([module, "create_summary_file_writer"])


def _patch_file_writer(
    writer: Any,
    module: Any,
    save: bool = True,
    root_logdir: str = "",
) -> None:
    # This configures non-TensorFlow Tensorboard logging, or tensorflow <= 1.15
    logdir_hist = []

    class TBXEventFileWriter(writer.EventFileWriter):
        def __init__(self, logdir: str, *args: Any, **kwargs: Any) -> None:
            logdir_hist.append(logdir)
            root_logdir_arg = root_logdir
            if len(set(logdir_hist)) > 1 and root_logdir == "":
                wandb.termwarn(
                    "When using several event log directories, "
                    'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`'
                )

            # if the logdir contains the hostname, the writer was not given a logdir.
            # In this case, the logdir is generated and ends with the hostname,
            # update the root_logdir to match.
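            # Illustrative example (assuming a PyTorch SummaryWriter default
            # logdir): "runs/Aug01_12-34-56_<hostname>" matches the pattern
            # below, so root_logdir_arg is set to the logdir up to and
            # including the hostname.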
            hostname = socket.gethostname()
            search = re.search(rf"-\d+_{hostname}", logdir)
            if search:
                root_logdir_arg = logdir[: search.span()[1]]

            elif root_logdir is not None and not os.path.abspath(logdir).startswith(
                os.path.abspath(root_logdir)
            ):
                wandb.termwarn(
                    "Found log directory outside of given root_logdir, "
                    f"dropping given root_logdir for event file in {logdir}"
                )
                root_logdir_arg = ""

            _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)

            super().__init__(logdir, *args, **kwargs)

    writer.orig_EventFileWriter = writer.EventFileWriter
    writer.EventFileWriter = TBXEventFileWriter
    wandb.patched["tensorboard"].append([module, "EventFileWriter"])


def _notify_tensorboard_logdir(
    logdir: str, save: bool = True, root_logdir: str = ""
) -> None:
    if wandb.run is not None:
        wandb.run._tensorboard_callback(logdir, save=save, root_logdir=root_logdir)