File size: 14,076 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import io
import re
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import wandb
import wandb.util
from wandb.sdk.lib import telemetry

if TYPE_CHECKING:
    import numpy as np

    from wandb.sdk.internal.tb_watcher import TBHistory

# Per-namespace step bookkeeping used by `_log`.
# We always track at least the default namespace ("") and a "global" entry;
# the "global" entry additionally remembers the timestamp of the last log
# ("last_log") for rate limiting.
# `reset_state()` rebinds this structure to its initial contents.
STEPS: Dict[str, Dict[str, Any]] = {
    "": {"step": 0},
    "global": {"step": 0, "last_log": None},
}
# TODO(cling): Set these when tensorboard behavior is configured.
# Rate-limited logging is enabled by setting this to a number of seconds
# (int or float); None disables rate limiting.
RATE_LIMIT_SECONDS: Optional[Union[float, int]] = None
# Summary value kinds (as reported by `WhichOneof("value")`) that are skipped.
IGNORE_KINDS = ["graphs"]
# Tensor conversion helpers from tensorboard; the code below treats a falsy
# value here as "tensorboard is unavailable".
tensor_util = wandb.util.get_module("tensorboard.util.tensor_util")


# Prefer tensorboard's protobuf definitions and fall back to the ones shipped
# with tensorflow when tensorboard isn't available.
pb = wandb.util.get_module(
    "tensorboard.compat.proto.summary_pb2"
) or wandb.util.get_module("tensorflow.core.framework.summary_pb2")

# Summary protobuf class, or None when neither module could be loaded.
Summary = pb.Summary if pb else None


def make_ndarray(tensor: Any) -> Optional["np.ndarray"]:
    """Convert a tensor proto to a numpy array via tensorboard's tensor_util.

    Returns ``None`` (after emitting a warning) when tensorboard's
    ``tensor_util`` module is unavailable, and also ``None`` when the
    converted array has the ``object`` dtype.
    """
    if not tensor_util:
        wandb.termwarn(
            "Can't convert tensor summary, upgrade tensorboard with `pip"
            " install tensorboard --upgrade`"
        )
        return None

    array = tensor_util.make_ndarray(tensor)
    # Tensorboard can log generic objects, and we don't want to save them
    return None if array.dtype == "object" else array  # type: ignore


def namespaced_tag(tag: str, namespace: str = "") -> str:
    """Return ``tag`` prefixed with ``namespace`` and a ``/`` separator.

    An empty namespace leaves the tag unchanged.
    """
    if namespace:
        return namespace + "/" + tag
    return tag


def history_image_key(key: str, namespace: str = "") -> str:
    """Build a History key for an image tag, replacing path separators with _.

    Forward and back slashes in ``key`` are replaced before the namespace is
    applied, so distinct image keys that differ only in those characters end
    up colliding silently. The mapping is done here, on the TensorFlow side,
    rather than in History so that no table from original keys to sanitized
    keys has to be stored anywhere.
    """
    safe_key = re.sub(r"[/\\]", "_", key)
    return namespaced_tag(safe_key, namespace)


def tf_summary_to_dict(  # noqa: C901
    tf_summary_str_or_pb: Any, namespace: str = ""
) -> Optional[Dict[str, Any]]:
    """Convert a Tensorboard Summary to a dictionary.

    Accepts a tensorflow.summary.Summary, one encoded as a string/bytes,
    or an iterable of such encoded summaries (which are merged into one).
    Objects exposing a ``summary`` attribute are also accepted; their ``step``
    and ``wall_time`` are recorded under ``global_step`` and ``_timestamp``.

    Args:
        tf_summary_str_or_pb: The summary, serialized summary, iterable of
            serialized summaries, or event-like object to convert.
        namespace: Optional prefix applied to every emitted key except
            ``_timestamp``.

    Returns:
        A (possibly empty) dict mapping namespaced tags to loggable values, or
        ``None`` when the input carries no summary values. An empty iterable
        also yields ``None``; the caller is responsible for handling it.
    """
    values: Dict[str, Any] = {}
    if hasattr(tf_summary_str_or_pb, "summary"):
        summary_pb = tf_summary_str_or_pb.summary
        values[namespaced_tag("global_step", namespace)] = tf_summary_str_or_pb.step
        values["_timestamp"] = tf_summary_str_or_pb.wall_time
    elif isinstance(tf_summary_str_or_pb, (str, bytes, bytearray)):
        summary_pb = Summary()
        summary_pb.ParseFromString(tf_summary_str_or_pb)
    elif hasattr(tf_summary_str_or_pb, "__iter__"):
        # Merge every serialized summary into the first one. Iterating instead
        # of indexing means an empty iterable (or one without __len__, such as
        # a generator) is handled: nothing is parsed, summary_pb stays None and
        # the emptiness check below returns None.
        summary_pb = None
        for encoded in tf_summary_str_or_pb:
            parsed = Summary()
            parsed.ParseFromString(encoded)
            if summary_pb is None:
                summary_pb = parsed
            else:
                summary_pb.MergeFrom(parsed)
    else:
        summary_pb = tf_summary_str_or_pb

    if not hasattr(summary_pb, "value") or len(summary_pb.value) == 0:
        # Ignore these, caller is responsible for handling None
        return None

    def encode_images(_img_strs: List[bytes], _value: Any) -> None:
        """Decode encoded images and store them in ``values`` under the tag."""
        try:
            from PIL import Image
        except ImportError:
            wandb.termwarn(
                "Install pillow if you are logging images with Tensorboard. "
                "To install, run `pip install pillow`.",
                repeat=False,
            )
            return None

        if len(_img_strs) == 0:
            return None

        images: List[Union[wandb.Video, wandb.Image]] = []
        for _img_str in _img_strs:
            # Supports gifs from TensorboardX
            if _img_str.startswith(b"GIF"):
                images.append(wandb.Video(io.BytesIO(_img_str), format="gif"))
            else:
                images.append(wandb.Image(Image.open(io.BytesIO(_img_str))))

        # Tags of the form "<name>/<digits>" are grouped under "<name>" so
        # that numbered images accumulate in a single list.
        tag_idx = _value.tag.rsplit("/", 1)
        if len(tag_idx) > 1 and tag_idx[1].isdigit():
            values.setdefault(history_image_key(tag_idx[0], namespace), []).extend(
                images
            )
        else:
            values[history_image_key(_value.tag, namespace)] = images

        return None

    for value in summary_pb.value:
        kind = value.WhichOneof("value")
        if kind in IGNORE_KINDS:
            continue
        if kind == "simple_value":
            values[namespaced_tag(value.tag, namespace)] = value.simple_value
        elif kind == "tensor":
            plugin_name = value.metadata.plugin_data.plugin_name
            if plugin_name == "scalars" or plugin_name == "":
                values[namespaced_tag(value.tag, namespace)] = make_ndarray(
                    value.tensor
                )
            elif plugin_name == "images":
                img_strs = value.tensor.string_val[2:]  # First two items are dims.
                encode_images(img_strs, value)
            elif plugin_name == "histograms":
                # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/summary_v2.py#L15-L26
                ndarray = make_ndarray(value.tensor)
                if ndarray is None:
                    continue
                shape = ndarray.shape
                counts = []
                bins = []
                # Each row is read as [left edge, right edge, count].
                if shape[0] > 1:
                    bins.append(ndarray[0][0])  # Add the left most edge
                    for v in ndarray:
                        counts.append(v[2])
                        bins.append(v[1])  # Add the right most edges
                elif shape[0] == 1:
                    counts = [ndarray[0][2]]
                    bins = ndarray[0][:2]
                if len(counts) > 0:
                    hist_key = namespaced_tag(value.tag, namespace)
                    try:
                        # TODO: we should just re-bin if there are too many buckets
                        values[hist_key] = wandb.Histogram(
                            np_histogram=(counts, bins)  # type: ignore
                        )
                    except ValueError:
                        wandb.termwarn(
                            f'Not logging key "{hist_key}". '
                            f"Histograms must have fewer than {wandb.Histogram.MAX_LENGTH} bins",
                            repeat=False,
                        )
            elif plugin_name == "pr_curves":
                pr_curve_data = make_ndarray(value.tensor)
                if pr_curve_data is None:
                    continue
                # The last two rows are read as precision and recall.
                precision = pr_curve_data[-2, :].tolist()
                recall = pr_curve_data[-1, :].tolist()
                # TODO: (kdg) implement spec for showing additional info in tool tips
                # true_pos = pr_curve_data[1,:]
                # false_pos = pr_curve_data[2,:]
                # true_neg = pr_curve_data[1,:]
                # false_neg = pr_curve_data[1,:]
                # threshold = [1.0 / n for n in range(len(true_pos), 0, -1)]
                # zip stops at the shorter sequence, in case tensorboard ever
                # changes their pr_curve to allow for different length outputs.
                # Points where both values are zero are dropped (additional
                # threshold values).
                data = [
                    (rec, prec)
                    for rec, prec in zip(recall, precision)
                    if prec != 0 or rec != 0
                ]
                # sort data so custom chart looks the same as tb generated pr curve
                # ascending recall, descending precision for the same recall values
                data.sort(key=lambda point: (point[0], -point[1]))
                data_table = wandb.Table(data=data, columns=["recall", "precision"])
                name = namespaced_tag(value.tag, namespace)

                values[name] = wandb.plot_table(
                    "wandb/line/v0",
                    data_table,
                    {"x": "recall", "y": "precision"},
                    {"title": f"{name} Precision v. Recall"},
                )
        elif kind == "image":
            img_str = value.image.encoded_image_string
            encode_images([img_str], value)
        # Coming soon...
        # elif kind == "audio":
        #     audio = wandb.Audio(
        #         six.BytesIO(value.audio.encoded_audio_string),
        #         sample_rate=value.audio.sample_rate,
        #         content_type=value.audio.content_type,
        #     )
        elif kind == "histo":
            tag = namespaced_tag(value.tag, namespace)
            if len(value.histo.bucket_limit) >= 3:
                # Extrapolate the outermost bin edges by mirroring the width of
                # the neighbouring bucket: edge + (edge - next_inner_edge).
                first = (
                    value.histo.bucket_limit[0]
                    + value.histo.bucket_limit[0]
                    - value.histo.bucket_limit[1]
                )
                last = (
                    value.histo.bucket_limit[-2]
                    + value.histo.bucket_limit[-2]
                    - value.histo.bucket_limit[-3]
                )
                np_histogram = (
                    list(value.histo.bucket),
                    [first] + value.histo.bucket_limit[:-1] + [last],
                )
                try:
                    # TODO: we should just re-bin if there are too many buckets
                    values[tag] = wandb.Histogram(np_histogram=np_histogram)  # type: ignore
                except ValueError:
                    wandb.termwarn(
                        f"Not logging key {tag!r}. "
                        f"Histograms must have fewer than {wandb.Histogram.MAX_LENGTH} bins",
                        repeat=False,
                    )
            else:
                # TODO: is there a case where we can render this?
                wandb.termwarn(
                    f"Not logging key {tag!r}. Found a histogram with only 2 bins.",
                    repeat=False,
                )
        # TODO(jhr): figure out how to share this between userspace and internal process or dont
        # elif value.tag == "_hparams_/session_start_info":
        #     if wandb.util.get_module("tensorboard.plugins.hparams"):
        #         from tensorboard.plugins.hparams import plugin_data_pb2
        #
        #         plugin_data = plugin_data_pb2.HParamsPluginData()
        #         plugin_data.ParseFromString(value.metadata.plugin_data.content)
        #         for key, param in six.iteritems(plugin_data.session_start_info.hparams):
        #             if not wandb.run.config.get(key):
        #                 wandb.run.config[key] = (
        #                     param.number_value or param.string_value or param.bool_value
        #                 )
        #     else:
        #         wandb.termerror(
        #             "Received hparams tf.summary, but could not import "
        #             "the hparams plugin from tensorboard"
        #         )
    return values


def reset_state() -> None:
    """Rebind the module-level step tracker to its initial contents.

    Internal method, called by wandb.finish().
    """
    global STEPS
    fresh_steps = {
        "": {"step": 0},
        "global": {"step": 0, "last_log": None},
    }
    STEPS = fresh_steps


def _log(
    tf_summary_str_or_pb: Any,
    history: Optional["TBHistory"] = None,
    step: int = 0,
    namespace: str = "",
    **kwargs: Any,
) -> None:
    """Logs a tfsummary to wandb.

    Can accept a tf summary string or parsed event.  Will log through wandb.run
    unless a history object is passed.  Can optionally namespace events.  Results
    are committed when step increases for this namespace, except when the new
    step equals the current global step (events sharing a step across
    namespaces are rolled up into one row).

    NOTE: This assumes that events being passed in are in chronological order

    Args:
        tf_summary_str_or_pb: Summary input, converted with `tf_summary_to_dict`.
        history: Optional history object; when given, rows are written to it
            instead of the active wandb run.
        step: Step associated with this summary.
        namespace: Namespace used to prefix keys and to track steps.
        **kwargs: Accepted for call compatibility and ignored.
    """
    # To handle multiple global_steps, we keep track of them here instead
    # of the global log
    previous = STEPS.get(namespace, {"step": 0})

    # Commit our existing data if this namespace increased its step
    commit = previous["step"] < step

    log_dict = tf_summary_to_dict(tf_summary_str_or_pb, namespace)
    if log_dict is None:
        # not an event, just return
        return

    global_state = STEPS["global"]

    # Pass timestamp to history for loading historic data
    timestamp = log_dict.get("_timestamp", time.time())
    # Store our initial timestamp
    if global_state["last_log"] is None:
        global_state["last_log"] = timestamp
    # Rollup events that share the same step across namespaces
    if commit and step == global_state["step"]:
        commit = False
    # Always add the biggest global_step key for non-default namespaces
    if step > global_state["step"]:
        global_state["step"] = step
    if namespace != "":
        log_dict["global_step"] = global_state["step"]

    # Keep internal step counter. Update the entry in place rather than
    # replacing it: replacing would drop "last_log" when the namespace happens
    # to be "global" and make the next call fail with a KeyError.
    STEPS.setdefault(namespace, {})["step"] = step

    if commit:
        # Only commit our data if we're below the rate limit or don't have one
        if (
            RATE_LIMIT_SECONDS is None
            or timestamp - global_state["last_log"] >= RATE_LIMIT_SECONDS
        ):
            if history is None:
                if wandb.run is not None:
                    wandb.run._log({})
            else:
                history.add({})

        # NOTE(review): last_log advances on every commit attempt, including
        # ones suppressed by the rate limit above, so events arriving faster
        # than RATE_LIMIT_SECONDS never accumulate enough elapsed time to
        # commit. Behaviour kept as-is; confirm whether this is intended.
        global_state["last_log"] = timestamp

    if history is None:
        if wandb.run is not None:
            wandb.run._log(log_dict, commit=False)
    else:
        history._row_update(log_dict)


def log(tf_summary_str_or_pb: Any, step: int = 0, namespace: str = "") -> None:
    """Log a TensorFlow summary to the active wandb run.

    Flags the ``tensorboard_log`` telemetry feature and then delegates to
    `_log` with the given step and namespace.

    Raises:
        wandb.Error: If there is no active run (``wandb.run`` is None).
    """
    no_active_run = wandb.run is None
    if no_active_run:
        raise wandb.Error(
            "You must call `wandb.init()` before calling `wandb.tensorflow.log`"
        )

    with telemetry.context() as tel:
        tel.feature.tensorboard_log = True

    _log(tf_summary_str_or_pb, step=step, namespace=namespace)