File size: 14,076 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 |
import io
import re
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import wandb
import wandb.util
from wandb.sdk.lib import telemetry
if TYPE_CHECKING:
import numpy as np
from wandb.sdk.internal.tb_watcher import TBHistory
# We have at least the default namestep and a global step to track
# TODO: reset this structure on wandb.finish
STEPS: Dict[str, Dict[str, Any]] = {
"": {"step": 0},
"global": {"step": 0, "last_log": None},
}
# TODO(cling): Set these when tensorboard behavior is configured.
# We support rate limited logging by setting this to number of seconds,
# can be a floating point.
RATE_LIMIT_SECONDS: Optional[Union[float, int]] = None
IGNORE_KINDS = ["graphs"]
tensor_util = wandb.util.get_module("tensorboard.util.tensor_util")
# prefer tensorboard, fallback to protobuf in tensorflow when tboard isn't available
pb = wandb.util.get_module(
"tensorboard.compat.proto.summary_pb2"
) or wandb.util.get_module("tensorflow.core.framework.summary_pb2")
Summary = pb.Summary if pb else None
def make_ndarray(tensor: Any) -> Optional["np.ndarray"]:
    """Convert a TensorBoard tensor proto into a numpy array.

    Returns None when tensorboard's tensor_util is unavailable or when the
    tensor holds generic Python objects (which we refuse to persist).
    """
    if not tensor_util:
        wandb.termwarn(
            "Can't convert tensor summary, upgrade tensorboard with `pip"
            " install tensorboard --upgrade`"
        )
        return None
    array = tensor_util.make_ndarray(tensor)
    # Tensorboard can log generic objects, and we don't want to save them
    if array.dtype == "object":
        return None
    return array  # type: ignore
def namespaced_tag(tag: str, namespace: str = "") -> str:
    """Return *tag* prefixed with "<namespace>/" when a namespace is given."""
    return f"{namespace}/{tag}" if namespace else tag
def history_image_key(key: str, namespace: str = "") -> str:
    """Convert invalid filesystem characters to _ for use in History keys.

    Unfortunately this means currently certain image keys will collide silently. We
    implement this mapping up here in the TensorFlow stuff rather than in the History
    stuff so that we don't have to store a mapping anywhere from the original keys to
    the safe ones.
    """
    safe_key = re.sub(r"[/\\]", "_", key)
    return namespaced_tag(safe_key, namespace)
def tf_summary_to_dict(  # noqa: C901
    tf_summary_str_or_pb: Any, namespace: str = ""
) -> Optional[Dict[str, Any]]:
    """Convert a Tensorboard Summary to a dictionary.

    Accepts a tensorflow.summary.Summary, one encoded as a string,
    or a list of such encoded as strings.

    Returns a dict mapping (namespaced) tags to wandb-loggable values, or
    None when the input contains no summary values.
    """
    values = {}
    # Normalize the four accepted input forms into a single Summary proto:
    # an Event-like object (has .summary), one serialized proto, a list of
    # serialized protos (merged into one), or an already-parsed proto.
    if hasattr(tf_summary_str_or_pb, "summary"):
        summary_pb = tf_summary_str_or_pb.summary
        # Event objects also carry the step and wall time; record both.
        values[namespaced_tag("global_step", namespace)] = tf_summary_str_or_pb.step
        values["_timestamp"] = tf_summary_str_or_pb.wall_time
    elif isinstance(tf_summary_str_or_pb, (str, bytes, bytearray)):
        summary_pb = Summary()
        summary_pb.ParseFromString(tf_summary_str_or_pb)
    elif hasattr(tf_summary_str_or_pb, "__iter__"):
        # Parse each serialized summary, then merge them all into the first.
        summary_pb = [Summary() for _ in range(len(tf_summary_str_or_pb))]
        for i, summary in enumerate(tf_summary_str_or_pb):
            summary_pb[i].ParseFromString(summary)
            if i > 0:
                summary_pb[0].MergeFrom(summary_pb[i])
        summary_pb = summary_pb[0]
    else:
        # Assume it's already a Summary proto.
        summary_pb = tf_summary_str_or_pb
    if not hasattr(summary_pb, "value") or len(summary_pb.value) == 0:
        # Ignore these, caller is responsible for handling None
        return None

    def encode_images(_img_strs: List[bytes], _value: Any) -> None:
        # Decode raw encoded image bytes into wandb media objects and store
        # them in the enclosing `values` dict under a filesystem-safe key.
        try:
            from PIL import Image
        except ImportError:
            wandb.termwarn(
                "Install pillow if you are logging images with Tensorboard. "
                "To install, run `pip install pillow`.",
                repeat=False,
            )
            return None
        if len(_img_strs) == 0:
            return None
        images: List[Union[wandb.Video, wandb.Image]] = []
        for _img_str in _img_strs:
            # Supports gifs from TensorboardX
            if _img_str.startswith(b"GIF"):
                images.append(wandb.Video(io.BytesIO(_img_str), format="gif"))
            else:
                images.append(wandb.Image(Image.open(io.BytesIO(_img_str))))
        tag_idx = _value.tag.rsplit("/", 1)
        if len(tag_idx) > 1 and tag_idx[1].isdigit():
            # Tags ending in "/<digit>" are batched under one key so multiple
            # per-step images accumulate in a single list.
            tag, idx = tag_idx
            values.setdefault(history_image_key(tag, namespace), []).extend(images)
        else:
            values[history_image_key(_value.tag, namespace)] = images
        return None

    for value in summary_pb.value:
        # `kind` is the populated oneof field of the Summary.Value proto.
        kind = value.WhichOneof("value")
        if kind in IGNORE_KINDS:
            continue
        if kind == "simple_value":
            values[namespaced_tag(value.tag, namespace)] = value.simple_value
        elif kind == "tensor":
            # TF2-style summaries: the tensor's meaning depends on the plugin.
            plugin_name = value.metadata.plugin_data.plugin_name
            if plugin_name == "scalars" or plugin_name == "":
                values[namespaced_tag(value.tag, namespace)] = make_ndarray(
                    value.tensor
                )
            elif plugin_name == "images":
                img_strs = value.tensor.string_val[2:]  # First two items are dims.
                encode_images(img_strs, value)
            elif plugin_name == "histograms":
                # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/summary_v2.py#L15-L26
                # Each row of the ndarray is (left_edge, right_edge, count).
                ndarray = make_ndarray(value.tensor)
                if ndarray is None:
                    continue
                shape = ndarray.shape
                counts = []
                bins = []
                if shape[0] > 1:
                    bins.append(ndarray[0][0])  # Add the left most edge
                    for v in ndarray:
                        counts.append(v[2])
                        bins.append(v[1])  # Add the right most edges
                elif shape[0] == 1:
                    counts = [ndarray[0][2]]
                    bins = ndarray[0][:2]
                if len(counts) > 0:
                    try:
                        # TODO: we should just re-bin if there are too many buckets
                        values[namespaced_tag(value.tag, namespace)] = wandb.Histogram(
                            np_histogram=(counts, bins)  # type: ignore
                        )
                    except ValueError:
                        wandb.termwarn(
                            f'Not logging key "{namespaced_tag(value.tag, namespace)}". '
                            f"Histograms must have fewer than {wandb.Histogram.MAX_LENGTH} bins",
                            repeat=False,
                        )
            elif plugin_name == "pr_curves":
                # Tensor rows hold PR-curve components; the last two rows are
                # precision and recall (see tensorboard's pr_curve plugin).
                pr_curve_data = make_ndarray(value.tensor)
                if pr_curve_data is None:
                    continue
                precision = pr_curve_data[-2, :].tolist()
                recall = pr_curve_data[-1, :].tolist()
                # TODO: (kdg) implement spec for showing additional info in tool tips
                # true_pos = pr_curve_data[1,:]
                # false_pos = pr_curve_data[2,:]
                # true_neg = pr_curve_data[1,:]
                # false_neg = pr_curve_data[1,:]
                # threshold = [1.0 / n for n in range(len(true_pos), 0, -1)]
                # min of each in case tensorboard ever changes their pr_curve
                # to allow for different length outputs
                data = []
                for i in range(min(len(precision), len(recall))):
                    # drop additional threshold values if they exist
                    if precision[i] != 0 or recall[i] != 0:
                        data.append((recall[i], precision[i]))
                # sort data so custom chart looks the same as tb generated pr curve
                # ascending recall, descending precision for the same recall values
                data = sorted(data, key=lambda x: (x[0], -x[1]))
                data_table = wandb.Table(data=data, columns=["recall", "precision"])
                name = namespaced_tag(value.tag, namespace)
                values[name] = wandb.plot_table(
                    "wandb/line/v0",
                    data_table,
                    {"x": "recall", "y": "precision"},
                    {"title": f"{name} Precision v. Recall"},
                )
        elif kind == "image":
            img_str = value.image.encoded_image_string
            encode_images([img_str], value)
        # Coming soon...
        # elif kind == "audio":
        #     audio = wandb.Audio(
        #         six.BytesIO(value.audio.encoded_audio_string),
        #         sample_rate=value.audio.sample_rate,
        #         content_type=value.audio.content_type,
        #     )
        elif kind == "histo":
            # TF1-style histogram: bucket_limit holds right edges; synthesize
            # outermost edges by mirroring the first/last bucket widths.
            tag = namespaced_tag(value.tag, namespace)
            if len(value.histo.bucket_limit) >= 3:
                first = (
                    value.histo.bucket_limit[0]
                    + value.histo.bucket_limit[0]
                    - value.histo.bucket_limit[1]
                )
                last = (
                    value.histo.bucket_limit[-2]
                    + value.histo.bucket_limit[-2]
                    - value.histo.bucket_limit[-3]
                )
                np_histogram = (
                    list(value.histo.bucket),
                    [first] + value.histo.bucket_limit[:-1] + [last],
                )
                try:
                    # TODO: we should just re-bin if there are too many buckets
                    values[tag] = wandb.Histogram(np_histogram=np_histogram)  # type: ignore
                except ValueError:
                    wandb.termwarn(
                        f"Not logging key {tag!r}. "
                        f"Histograms must have fewer than {wandb.Histogram.MAX_LENGTH} bins",
                        repeat=False,
                    )
            else:
                # TODO: is there a case where we can render this?
                wandb.termwarn(
                    f"Not logging key {tag!r}. Found a histogram with only 2 bins.",
                    repeat=False,
                )
        # TODO(jhr): figure out how to share this between userspace and internal process or dont
        # elif value.tag == "_hparams_/session_start_info":
        #     if wandb.util.get_module("tensorboard.plugins.hparams"):
        #         from tensorboard.plugins.hparams import plugin_data_pb2
        #
        #         plugin_data = plugin_data_pb2.HParamsPluginData() #
        #         plugin_data.ParseFromString(value.metadata.plugin_data.content)
        #         for key, param in six.iteritems(plugin_data.session_start_info.hparams):
        #             if not wandb.run.config.get(key):
        #                 wandb.run.config[key] = (
        #                     param.number_value or param.string_value or param.bool_value
        #                 )
        #     else:
        #         wandb.termerror(
        #             "Received hparams tf.summary, but could not import "
        #             "the hparams plugin from tensorboard"
        #         )
    return values
def reset_state() -> None:
    """Internal method for resetting state, called by wandb.finish()."""
    global STEPS
    STEPS = {
        "": {"step": 0},
        "global": {"step": 0, "last_log": None},
    }
def _log(
    tf_summary_str_or_pb: Any,
    history: Optional["TBHistory"] = None,
    step: int = 0,
    namespace: str = "",
    **kwargs: Any,  # accepted for caller convenience; currently unused
) -> None:
    """Logs a tfsummary to wandb.

    Can accept a tf summary string or parsed event. Will use wandb.run.history unless a
    history object is passed. Can optionally namespace events. Results are committed
    when step increases for this namespace.

    NOTE: This assumes that events being passed in are in chronological order
    """
    global STEPS
    global RATE_LIMIT_SECONDS
    # To handle multiple global_steps, we keep track of them here instead
    # of the global log
    last_step = STEPS.get(namespace, {"step": 0})
    # Commit our existing data if this namespace increased its step
    commit = False
    if last_step["step"] < step:
        commit = True
    log_dict = tf_summary_to_dict(tf_summary_str_or_pb, namespace)
    if log_dict is None:
        # not an event, just return
        return
    # Pass timestamp to history for loading historic data
    timestamp = log_dict.get("_timestamp", time.time())
    # Store our initial timestamp
    if STEPS["global"]["last_log"] is None:
        STEPS["global"]["last_log"] = timestamp
    # Rollup events that share the same step across namespaces
    if commit and step == STEPS["global"]["step"]:
        commit = False
    # Always add the biggest global_step key for non-default namespaces
    if step > STEPS["global"]["step"]:
        STEPS["global"]["step"] = step
    if namespace != "":
        log_dict["global_step"] = STEPS["global"]["step"]
    # Keep internal step counter
    STEPS[namespace] = {"step": step}
    if commit:
        # Only commit our data if we're below the rate limit or don't have one
        if (
            RATE_LIMIT_SECONDS is None
            or timestamp - STEPS["global"]["last_log"] >= RATE_LIMIT_SECONDS
        ):
            # An empty log/add flushes the previously accumulated row.
            if history is None:
                if wandb.run is not None:
                    wandb.run._log({})
            else:
                history.add({})
            STEPS["global"]["last_log"] = timestamp
    # Accumulate this event's data into the current (uncommitted) row.
    if history is None:
        if wandb.run is not None:
            wandb.run._log(log_dict, commit=False)
    else:
        history._row_update(log_dict)
def log(tf_summary_str_or_pb: Any, step: int = 0, namespace: str = "") -> None:
    """Log a TensorBoard summary to the active wandb run.

    Requires that `wandb.init()` has been called; records telemetry that
    tensorboard logging is in use and delegates to `_log`.
    """
    if wandb.run is None:
        raise wandb.Error(
            "You must call `wandb.init()` before calling `wandb.tensorflow.log`"
        )
    with telemetry.context() as tel:
        tel.feature.tensorboard_log = True
    _log(tf_summary_str_or_pb, step=step, namespace=namespace)
|