"""keras init."""
import logging
import operator
import os
import shutil
import sys
from itertools import chain
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K # noqa: N812
import wandb
from wandb.proto.wandb_deprecated import Deprecated
from wandb.sdk.integration_utils.data_logging import ValidationDataLogger
from wandb.sdk.lib.deprecate import deprecate
from wandb.util import add_import_hook
def _check_keras_version():
from keras import __version__ as keras_version
from packaging.version import parse
if parse(keras_version) < parse("2.4.0"):
        wandb.termwarn(
            f"Keras version {keras_version} is not fully supported; keras >= 2.4.0 is required."
        )
def _can_compute_flops() -> bool:
    """FLOPS computation is restricted to TF 2.x because it relies on the tf.compat.v1 profiler."""
from packaging.version import parse
    return parse(tf.__version__) >= parse("2.0.0")
if "keras" in sys.modules:
_check_keras_version()
else:
add_import_hook("keras", _check_keras_version)
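# The check above runs immediately when keras is already imported; otherwise
# wandb's import hook defers it until keras is first imported.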
logger = logging.getLogger(__name__)
def is_dataset(data):
dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
if dataset_ops and hasattr(dataset_ops, "DatasetV2"):
dataset_types = (dataset_ops.DatasetV2,)
if hasattr(dataset_ops, "DatasetV1"):
dataset_types = dataset_types + (dataset_ops.DatasetV1,)
return isinstance(data, dataset_types)
else:
return False
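# A hedged usage sketch (the helper below is illustrative and never called at
# import time): is_dataset() is True only for tf.data pipelines, so plain
# tuples/arrays fall through to the ordinary validation-data path.
def _demo_is_dataset():
    ds = tf.data.Dataset.from_tensor_slices(([1.0, 2.0, 3.0], [0, 1, 0]))
    assert is_dataset(ds)
    assert not is_dataset(([1.0, 2.0, 3.0], [0, 1, 0]))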
def is_generator_like(data):
# Checks if data is a generator, Sequence, or Iterator.
types = (tf.keras.utils.Sequence,)
iterator_ops = wandb.util.get_module("tensorflow.python.data.ops.iterator_ops")
if iterator_ops:
types = types + (iterator_ops.Iterator,)
# EagerIterator was in tensorflow < 2
if hasattr(iterator_ops, "EagerIterator"):
types = types + (iterator_ops.EagerIterator,)
elif hasattr(iterator_ops, "IteratorV2"):
types = types + (iterator_ops.IteratorV2,)
return hasattr(data, "next") or hasattr(data, "__next__") or isinstance(data, types)
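# Hedged sketch (the helper below is illustrative and never called at import
# time): generator objects, anything exposing next/__next__, keras Sequences,
# and tf iterator types all count as "generator-like".
def _demo_is_generator_like():
    def batches():
        while True:
            yield np.zeros((2, 4)), np.zeros((2, 1))

    assert is_generator_like(batches())
    assert not is_generator_like((np.zeros((2, 4)), np.zeros((2, 1))))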
def patch_tf_keras(): # noqa: C901
from packaging.version import parse
from tensorflow.python.eager import context
if parse("2.6.0") <= parse(tf.__version__) < parse("2.13.0"):
keras_engine = "keras.engine"
try:
from keras.engine import training
from keras.engine import training_arrays_v1 as training_arrays
from keras.engine import training_generator_v1 as training_generator
except (ImportError, AttributeError):
wandb.termerror("Unable to patch Tensorflow/Keras")
logger.exception("exception while trying to patch_tf_keras")
return
else:
keras_engine = "tensorflow.python.keras.engine"
from tensorflow.python.keras.engine import training
try:
from tensorflow.python.keras.engine import (
training_arrays_v1 as training_arrays,
)
from tensorflow.python.keras.engine import (
training_generator_v1 as training_generator,
)
except (ImportError, AttributeError):
try:
from tensorflow.python.keras.engine import (
training_arrays,
training_generator,
)
except (ImportError, AttributeError):
wandb.termerror("Unable to patch Tensorflow/Keras")
logger.exception("exception while trying to patch_tf_keras")
return
# Tensorflow 2.1
training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")
# Tensorflow 2.2
training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")
if training_v2_1:
old_v2 = training_v2_1.Loop.fit
elif training_v2_2:
old_v2 = training.Model.fit
old_arrays = training_arrays.fit_loop
old_generator = training_generator.fit_generator
def set_wandb_attrs(cbk, val_data):
if isinstance(cbk, WandbCallback):
if is_generator_like(val_data):
cbk.generator = val_data
elif is_dataset(val_data):
if context.executing_eagerly():
cbk.generator = iter(val_data)
else:
wandb.termwarn(
"Found a validation dataset in graph mode, can't patch Keras."
)
elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
# Graph mode dataset generator
def gen():
while True:
yield K.get_session().run(val_data)
cbk.generator = gen()
else:
cbk.validation_data = val_data
def new_arrays(*args, **kwargs):
cbks = kwargs.get("callbacks", [])
val_inputs = kwargs.get("val_inputs")
val_targets = kwargs.get("val_targets")
# TODO: these could be generators, why index 0?
if val_inputs and val_targets:
for cbk in cbks:
set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
return old_arrays(*args, **kwargs)
def new_generator(*args, **kwargs):
cbks = kwargs.get("callbacks", [])
val_data = kwargs.get("validation_data")
if val_data:
for cbk in cbks:
set_wandb_attrs(cbk, val_data)
return old_generator(*args, **kwargs)
def new_v2(*args, **kwargs):
cbks = kwargs.get("callbacks", [])
val_data = kwargs.get("validation_data")
if val_data:
for cbk in cbks:
set_wandb_attrs(cbk, val_data)
return old_v2(*args, **kwargs)
training_arrays.orig_fit_loop = old_arrays
training_arrays.fit_loop = new_arrays
training_generator.orig_fit_generator = old_generator
training_generator.fit_generator = new_generator
wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
wandb.patched["keras"].append(
[f"{keras_engine}.training_generator", "fit_generator"]
)
if training_v2_1:
training_v2_1.Loop.fit = new_v2
wandb.patched["keras"].append(
["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
)
elif training_v2_2:
training.Model.fit = new_v2
wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])
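# The patching above is a plain wrap-and-delegate monkey patch. A hedged,
# generic sketch of the same pattern (names here are illustrative only):
def _demo_wrap_fit(fit_fn):
    def wrapped(*args, **kwargs):
        # Peek at the kwargs of interest (callbacks, validation data) before
        # delegating to the original implementation unchanged.
        _callbacks = kwargs.get("callbacks", [])
        _val_data = kwargs.get("validation_data")
        return fit_fn(*args, **kwargs)

    return wrapped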
def _array_has_dtype(array):
return hasattr(array, "dtype")
def _update_if_numeric(metrics, key, values):
if not _array_has_dtype(values):
_warn_not_logging(key)
return
if not is_numeric_array(values):
_warn_not_logging_non_numeric(key)
return
metrics[key] = wandb.Histogram(values)
def is_numeric_array(array):
return np.issubdtype(array.dtype, np.number)
def _warn_not_logging_non_numeric(name):
wandb.termwarn(
f"Non-numeric values found in layer: {name}, not logging this layer",
repeat=False,
)
def _warn_not_logging(name):
    wandb.termwarn(
        f"Layer {name} has an undetermined datatype, not logging this layer",
        repeat=False,
    )
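# Hedged sketch (illustrative helper): only arrays with a numeric dtype become
# wandb.Histogram entries; anything else is skipped with a one-time warning.
def _demo_update_if_numeric():
    metrics = {}
    _update_if_numeric(metrics, "parameters/dense.weights", np.random.rand(8))
    _update_if_numeric(metrics, "parameters/lookup.keys", np.array(["a", "b"]))
    return metrics  # contains only the numeric "parameters/dense.weights" key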
tf_logger = tf.get_logger()
patch_tf_keras()
### For gradient logging ###
def _get_custom_optimizer_parent_class():
from packaging.version import parse
if parse(tf.__version__) >= parse("2.9.0"):
custom_optimizer_parent_class = tf.keras.optimizers.legacy.Optimizer
else:
custom_optimizer_parent_class = tf.keras.optimizers.Optimizer
return custom_optimizer_parent_class
_custom_optimizer_parent_class = _get_custom_optimizer_parent_class()
class _CustomOptimizer(_custom_optimizer_parent_class):
def __init__(self):
super().__init__(name="CustomOptimizer")
self._resource_apply_dense = tf.function(self._resource_apply_dense)
self._resource_apply_sparse = tf.function(self._resource_apply_sparse)
def _resource_apply_dense(self, grad, var):
var.assign(grad)
# this needs to be implemented to prevent a NotImplementedError when
# using Lookup layers.
def _resource_apply_sparse(self, grad, var, indices):
pass
def get_config(self):
return super().get_config()
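# The optimizer above is the heart of the gradient-logging trick: instead of
# applying an update rule, _resource_apply_dense overwrites each variable with
# its own gradient, so after one "training" step the model's weights *are*
# the gradients.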
class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
"""Accumulates gradients during a fit() call when used in conjunction with the CustomOptimizer above."""
def set_model(self, model):
super().set_model(model)
self.og_weights = model.get_weights()
self.grads = [np.zeros(tuple(w.shape)) for w in model.trainable_weights]
def on_batch_end(self, batch, logs=None):
for g, w in zip(self.grads, self.model.trainable_weights):
g += w.numpy()
self.model.set_weights(self.og_weights)
def get_grads(self):
return [g.copy() for g in self.grads]
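# Hedged sketch of how the two pieces combine (mirrors
# _build_grad_accumulator_model and _log_gradients below; the helper name is
# illustrative and assumes `model` is already compiled): fitting a shadow
# model with _CustomOptimizer writes each gradient into its variable, the
# callback accumulates those values batch by batch, and get_grads() reads
# them back.
def _demo_capture_gradients(model, x, y):
    shadow = tf.keras.models.Model(model.inputs, model(model.inputs))
    shadow.compile(loss=model.loss, optimizer=_CustomOptimizer())
    cbk = _GradAccumulatorCallback()
    shadow.fit(x, y, verbose=0, callbacks=[cbk])
    return cbk.get_grads()  # one accumulated ndarray per trainable weight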
###
class WandbCallback(tf.keras.callbacks.Callback):
"""`WandbCallback` automatically integrates keras with wandb.
Example:
```python
model.fit(
X_train,
y_train,
validation_data=(X_test, y_test),
callbacks=[WandbCallback()],
)
```
`WandbCallback` will automatically log history data from any
metrics collected by keras: loss and anything passed into `keras_model.compile()`.
`WandbCallback` will set summary metrics for the run associated with the "best" training
step, where "best" is defined by the `monitor` and `mode` attributes. This defaults
to the epoch with the minimum `val_loss`. `WandbCallback` will by default save the model
associated with the best `epoch`.
`WandbCallback` can optionally log gradient and parameter histograms.
`WandbCallback` can optionally save training and validation data for wandb to visualize.
Args:
monitor: (str) name of metric to monitor. Defaults to `val_loss`.
mode: (str) one of {`auto`, `min`, `max`}.
`min` - save model when monitor is minimized
`max` - save model when monitor is maximized
`auto` - try to guess when to save the model (default).
        save_model: (boolean)
            True - save a model when monitor beats all previous epochs
            False - don't save models
        save_graph: (boolean) if True, save the model graph to wandb (defaults to True).
save_weights_only: (boolean) if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
        log_weights: (boolean) if True, save histograms of the model's layers' weights.
        log_gradients: (boolean) if True, log histograms of the training gradients.
training_data: (tuple) Same format `(X,y)` as passed to `model.fit`. This is needed
for calculating gradients - this is mandatory if `log_gradients` is `True`.
validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`. A set of data
for wandb to visualize. If this is set, every epoch, wandb will
make a small number of predictions and save the results for later visualization. In case
you are working with image data, please also set `input_type` and `output_type` in order
to log correctly.
        generator: (generator) a generator that returns validation data for wandb to visualize. This
            generator should return tuples `(X,y)`. Either `validation_data` or `generator` should
            be set for wandb to visualize specific data examples. In case you are working with image data,
            please also set `input_type` and `output_type` in order to log correctly.
validation_steps: (int) if `validation_data` is a generator, how many
steps to run the generator for the full validation set.
        labels: (list) If you are visualizing your data with wandb this list of labels
            will convert numeric outputs to understandable strings if you are building a
            multiclass classifier. If you are making a binary classifier you can pass in
            a list of two labels ["label for false", "label for true"]. If `validation_data`
            and `generator` are both unset, this won't do anything.
predictions: (int) the number of predictions to make for visualization each epoch, max
is 100.
        input_type: (string) type of the model input to help visualization. Can be one of:
            (`image`, `images`, `segmentation_mask`, `auto`).
        output_type: (string) type of the model output to help visualization. Can be one of:
            (`image`, `images`, `segmentation_mask`, `label`).
log_evaluation: (boolean) if True, save a Table containing validation data and the
model's predictions at each epoch. See `validation_indexes`,
`validation_row_processor`, and `output_row_processor` for additional details.
class_colors: ([float, float, float]) if the input or output is a segmentation mask,
an array containing an rgb tuple (range 0-1) for each class.
log_batch_frequency: (integer) if None, callback will log every epoch.
If set to integer, callback will log training metrics every `log_batch_frequency`
batches.
log_best_prefix: (string) if None, no extra summary metrics will be saved.
If set to a string, the monitored metric and epoch will be prepended with this value
and stored as summary metrics.
validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate
with each validation example. If log_evaluation is True and `validation_indexes` is provided,
then a Table of validation data will not be created and instead each prediction will
            be associated with the row represented by the `TableLinkMixin`. The most common way to obtain
            such keys is to use `Table.get_index()`, which returns a list of row keys.
validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data.
The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input,
then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the
input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else,
it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray,
but you wish to visualize the data as an Image, then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}`
as the processor. Ignored if log_evaluation is False or `validation_indexes` are present.
output_row_processor: (Callable) same as `validation_row_processor`, but applied to the model's output. `row["output"]` will contain
the results of the model output.
infer_missing_processors: (bool) Determines if `validation_row_processor` and `output_row_processor`
should be inferred if missing. Defaults to True. If `labels` are provided, we will attempt to infer classification-type
processors where appropriate.
        log_evaluation_frequency: (int) Determines how frequently evaluation results are logged. Defaults to 0
            (only at the end of training). Set to 1 to log every epoch, 2 to log every other epoch, and so on.
            Has no effect when log_evaluation is False.
        compute_flops: (bool) Compute the FLOPs of your Keras Sequential or Functional model, reported in GigaFLOPs (GFLOPs).
"""
def __init__(
self,
monitor="val_loss",
verbose=0,
mode="auto",
save_weights_only=False,
log_weights=False,
log_gradients=False,
save_model=True,
training_data=None,
validation_data=None,
labels=None,
predictions=36,
generator=None,
input_type=None,
output_type=None,
log_evaluation=False,
validation_steps=None,
class_colors=None,
log_batch_frequency=None,
log_best_prefix="best_",
save_graph=True,
validation_indexes=None,
validation_row_processor=None,
prediction_row_processor=None,
infer_missing_processors=True,
log_evaluation_frequency=0,
compute_flops=False,
**kwargs,
):
if wandb.run is None:
raise wandb.Error("You must call wandb.init() before WandbCallback()")
deprecate(
field_name=Deprecated.keras_callback,
warning_message=(
"WandbCallback is deprecated and will be removed in a future release. "
"Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
"callbacks instead. "
"See https://docs.wandb.ai/guides/integrations/keras for more information."
),
)
with wandb.wandb_lib.telemetry.context(run=wandb.run) as tel:
tel.feature.keras = True
self.validation_data = None
# This is kept around for legacy reasons
if validation_data is not None:
if is_generator_like(validation_data):
generator = validation_data
else:
self.validation_data = validation_data
if labels is None:
labels = []
self.labels = labels
self.predictions = min(predictions, 100)
self.monitor = monitor
self.verbose = verbose
self.save_weights_only = save_weights_only
self.save_graph = save_graph
wandb.save("model-best.h5")
self.filepath = os.path.join(wandb.run.dir, "model-best.h5")
self.save_model = save_model
if save_model:
deprecate(
field_name=Deprecated.keras_callback__save_model,
warning_message=(
"The save_model argument by default saves the model in the HDF5 format that cannot save "
"custom objects like subclassed models and custom layers. This behavior will be deprecated "
"in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
"as W&B files and the SavedModel as W&B Artifacts."
),
)
self.save_model_as_artifact = True
self.log_weights = log_weights
self.log_gradients = log_gradients
self.training_data = training_data
self.generator = generator
self._graph_rendered = False
data_type = kwargs.get("data_type", None)
if data_type is not None:
deprecate(
field_name=Deprecated.keras_callback__data_type,
warning_message=(
"The data_type argument of wandb.keras.WandbCallback is deprecated "
"and will be removed in a future release. Please use input_type instead.\n"
"Setting input_type = data_type."
),
)
input_type = data_type
self.input_type = input_type
self.output_type = output_type
self.log_evaluation = log_evaluation
self.validation_steps = validation_steps
self.class_colors = np.array(class_colors) if class_colors is not None else None
self.log_batch_frequency = log_batch_frequency
self.log_best_prefix = log_best_prefix
self.compute_flops = compute_flops
self._prediction_batch_size = None
if self.log_gradients:
if int(tf.__version__.split(".")[0]) < 2:
raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
if self.training_data is None:
raise ValueError(
"training_data argument is required for gradient logging."
)
if isinstance(self.training_data, (list, tuple)):
if len(self.training_data) != 2:
raise ValueError("training data must be a tuple of length two")
self._training_data_x, self._training_data_y = self.training_data
else:
self._training_data_x = (
self.training_data
) # generator, tf.data.Dataset etc
self._training_data_y = None
# From Keras
if mode not in ["auto", "min", "max"]:
wandb.termwarn(
f"WandbCallback mode {mode} is unknown, fallback to auto mode."
)
mode = "auto"
if mode == "min":
self.monitor_op = operator.lt
self.best = float("inf")
elif mode == "max":
self.monitor_op = operator.gt
self.best = float("-inf")
else:
if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
self.monitor_op = operator.gt
self.best = float("-inf")
else:
self.monitor_op = operator.lt
self.best = float("inf")
# Get the previous best metric for resumed runs
previous_best = wandb.run.summary.get(f"{self.log_best_prefix}{self.monitor}")
if previous_best is not None:
self.best = previous_best
self._validation_data_logger = None
self._validation_indexes = validation_indexes
self._validation_row_processor = validation_row_processor
self._prediction_row_processor = prediction_row_processor
self._infer_missing_processors = infer_missing_processors
self._log_evaluation_frequency = log_evaluation_frequency
self._model_trained_since_last_eval = False
def _build_grad_accumulator_model(self):
inputs = self.model.inputs
outputs = self.model(inputs)
grad_acc_model = tf.keras.models.Model(inputs, outputs)
grad_acc_model.compile(loss=self.model.loss, optimizer=_CustomOptimizer())
# make sure magic doesn't think this is a user model
grad_acc_model._wandb_internal_model = True
self._grad_accumulator_model = grad_acc_model
self._grad_accumulator_callback = _GradAccumulatorCallback()
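    # Keras consults the three hooks below to decide whether per-batch
    # callback methods need to be invoked at all; returning False when
    # log_batch_frequency is unset avoids per-batch overhead.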
def _implements_train_batch_hooks(self):
return self.log_batch_frequency is not None
def _implements_test_batch_hooks(self):
return self.log_batch_frequency is not None
def _implements_predict_batch_hooks(self):
return self.log_batch_frequency is not None
def set_params(self, params):
self.params = params
def set_model(self, model):
super().set_model(model)
if self.input_type == "auto" and len(model.inputs) == 1:
self.input_type = wandb.util.guess_data_type(
model.inputs[0].shape, risky=True
)
if self.input_type and self.output_type is None and len(model.outputs) == 1:
self.output_type = wandb.util.guess_data_type(model.outputs[0].shape)
if self.log_gradients:
self._build_grad_accumulator_model()
def _attempt_evaluation_log(self, commit=True):
if self.log_evaluation and self._validation_data_logger:
try:
if not self.model:
wandb.termwarn("WandbCallback unable to read model from trainer")
else:
self._validation_data_logger.log_predictions(
predictions=self._validation_data_logger.make_predictions(
self.model.predict
),
commit=commit,
)
self._model_trained_since_last_eval = False
except Exception as e:
wandb.termwarn("Error during prediction logging for epoch: " + str(e))
def on_epoch_end(self, epoch, logs=None):
if logs is None:
logs = {}
if self.log_weights:
wandb.log(self._log_weights(), commit=False)
if self.log_gradients:
wandb.log(self._log_gradients(), commit=False)
if self.input_type in (
"image",
"images",
"segmentation_mask",
) or self.output_type in ("image", "images", "segmentation_mask"):
if self.generator:
self.validation_data = next(self.generator)
if self.validation_data is None:
wandb.termwarn(
"No validation_data set, pass a generator to the callback."
)
elif self.validation_data and len(self.validation_data) > 0:
wandb.log(
{"examples": self._log_images(num_images=self.predictions)},
commit=False,
)
if (
self._log_evaluation_frequency > 0
and epoch % self._log_evaluation_frequency == 0
):
self._attempt_evaluation_log(commit=False)
wandb.log({"epoch": epoch}, commit=False)
wandb.log(logs, commit=True)
        self.current = logs.get(self.monitor)
        # Compare with `is not None`: a monitored value of 0.0 is valid and
        # would be dropped by a plain truthiness check.
        if self.current is not None and self.monitor_op(self.current, self.best):
            if self.log_best_prefix:
                wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
                    self.current
                )
                wandb.run.summary[f"{self.log_best_prefix}epoch"] = epoch
if self.verbose and not self.save_model:
wandb.termlog(
f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
)
if self.save_model:
self._save_model(epoch)
if self.save_model and self.save_model_as_artifact:
self._save_model_as_artifact(epoch)
self.best = self.current
# This is what keras used pre tensorflow.keras
def on_batch_begin(self, batch, logs=None):
pass
# This is what keras used pre tensorflow.keras
def on_batch_end(self, batch, logs=None):
if self.save_graph and not self._graph_rendered:
            # Couldn't do this in on_train_begin because the model may not be built yet
wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
self._graph_rendered = True
if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
wandb.log(logs, commit=True)
def on_train_batch_begin(self, batch, logs=None):
self._model_trained_since_last_eval = True
def on_train_batch_end(self, batch, logs=None):
if self.save_graph and not self._graph_rendered:
            # Couldn't do this in on_train_begin because the model may not be built yet
wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
self._graph_rendered = True
if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
wandb.log(logs, commit=True)
def on_test_begin(self, logs=None):
pass
def on_test_end(self, logs=None):
pass
def on_test_batch_begin(self, batch, logs=None):
pass
def on_test_batch_end(self, batch, logs=None):
pass
def on_train_begin(self, logs=None):
if self.log_evaluation:
try:
validation_data = None
if self.validation_data:
validation_data = self.validation_data
elif self.generator:
if not self.validation_steps:
wandb.termwarn(
"WandbCallback is unable to log validation data. "
"When using a generator for validation_data, you must pass validation_steps"
)
else:
x = None
y_true = None
for _ in range(self.validation_steps):
bx, by_true = next(self.generator)
if x is None:
x, y_true = bx, by_true
else:
x, y_true = (
np.append(x, bx, axis=0),
np.append(y_true, by_true, axis=0),
)
validation_data = (x, y_true)
else:
wandb.termwarn(
"WandbCallback is unable to read validation_data from trainer "
"and therefore cannot log validation data. Ensure Keras is properly "
"patched by calling `from wandb.keras import WandbCallback` at the top of your script."
)
if validation_data:
self._validation_data_logger = ValidationDataLogger(
inputs=validation_data[0],
targets=validation_data[1],
indexes=self._validation_indexes,
validation_row_processor=self._validation_row_processor,
prediction_row_processor=self._prediction_row_processor,
class_labels=self.labels,
infer_missing_processors=self._infer_missing_processors,
)
except Exception as e:
wandb.termwarn(
"Error initializing ValidationDataLogger in WandbCallback. "
f"Skipping logging validation data. Error: {str(e)}"
)
if self.compute_flops and _can_compute_flops():
try:
wandb.summary["GFLOPs"] = self.get_flops()
except Exception:
logger.exception("Error computing FLOPs")
wandb.termwarn("Unable to compute FLOPs for this model.")
def on_train_end(self, logs=None):
if self._model_trained_since_last_eval:
self._attempt_evaluation_log()
def on_predict_begin(self, logs=None):
pass
def on_predict_end(self, logs=None):
pass
def on_predict_batch_begin(self, batch, logs=None):
pass
def on_predict_batch_end(self, batch, logs=None):
pass
def _logits_to_captions(self, logits):
if logits[0].shape[-1] == 1:
# Scalar output from the model
# TODO: handle validation_y
if len(self.labels) == 2:
# User has named true and false
                captions = [
                    self.labels[1] if logit[0] > 0.5 else self.labels[0]
                    for logit in logits
                ]
else:
if len(self.labels) != 0:
wandb.termwarn(
"keras model is producing a single output, "
'so labels should be a length two array: ["False label", "True label"].'
)
captions = [logit[0] for logit in logits]
else:
# Vector output from the model
# TODO: handle validation_y
labels = np.argmax(np.stack(logits), axis=1)
if len(self.labels) > 0:
# User has named the categories in self.labels
captions = []
for label in labels:
try:
captions.append(self.labels[label])
except IndexError:
captions.append(label)
else:
captions = labels
return captions
def _masks_to_pixels(self, masks):
        # if it's a binary mask, just return it as grayscale instead of picking the argmax
if len(masks[0].shape) == 2 or masks[0].shape[-1] == 1:
return masks
class_colors = (
self.class_colors
if self.class_colors is not None
else np.array(wandb.util.class_colors(masks[0].shape[2]))
)
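        # Vectorized colorization: argmax yields a class-id map per pixel, and
        # fancy-indexing the (num_classes, 3) color table with it produces RGB
        # images in one step.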
imgs = class_colors[np.argmax(masks, axis=-1)]
return imgs
def _log_images(self, num_images=36):
validation_X = self.validation_data[0] # noqa: N806
validation_y = self.validation_data[1]
validation_length = len(validation_X)
if validation_length > num_images:
# pick some data at random
indices = np.random.choice(validation_length, num_images, replace=False)
else:
indices = range(validation_length)
test_data = []
test_output = []
for i in indices:
test_example = validation_X[i]
test_data.append(test_example)
test_output.append(validation_y[i])
if self.model.stateful:
predictions = self.model.predict(np.stack(test_data), batch_size=1)
self.model.reset_states()
else:
predictions = self.model.predict(
np.stack(test_data), batch_size=self._prediction_batch_size
)
if len(predictions) != len(test_data):
self._prediction_batch_size = 1
predictions = self.model.predict(
np.stack(test_data), batch_size=self._prediction_batch_size
)
if self.input_type == "label":
if self.output_type in ("image", "images", "segmentation_mask"):
captions = self._logits_to_captions(test_data)
output_image_data = (
self._masks_to_pixels(predictions)
if self.output_type == "segmentation_mask"
else predictions
)
reference_image_data = (
self._masks_to_pixels(test_output)
if self.output_type == "segmentation_mask"
else test_output
)
output_images = [
wandb.Image(data, caption=captions[i], grouping=2)
for i, data in enumerate(output_image_data)
]
reference_images = [
wandb.Image(data, caption=captions[i])
for i, data in enumerate(reference_image_data)
]
return list(chain.from_iterable(zip(output_images, reference_images)))
elif self.input_type in ("image", "images", "segmentation_mask"):
input_image_data = (
self._masks_to_pixels(test_data)
if self.input_type == "segmentation_mask"
else test_data
)
if self.output_type == "label":
# we just use the predicted label as the caption for now
captions = self._logits_to_captions(predictions)
return [
wandb.Image(data, caption=captions[i])
for i, data in enumerate(test_data)
]
elif self.output_type in ("image", "images", "segmentation_mask"):
output_image_data = (
self._masks_to_pixels(predictions)
if self.output_type == "segmentation_mask"
else predictions
)
reference_image_data = (
self._masks_to_pixels(test_output)
if self.output_type == "segmentation_mask"
else test_output
)
                input_images = [
                    wandb.Image(data, grouping=3) for data in input_image_data
                ]
                output_images = [wandb.Image(data) for data in output_image_data]
                reference_images = [
                    wandb.Image(data) for data in reference_image_data
                ]
return list(
chain.from_iterable(
zip(input_images, output_images, reference_images)
)
)
else:
# unknown output, just log the input images
return [wandb.Image(img) for img in test_data]
elif self.output_type in ("image", "images", "segmentation_mask"):
# unknown input, just log the predicted and reference outputs without captions
output_image_data = (
self._masks_to_pixels(predictions)
if self.output_type == "segmentation_mask"
else predictions
)
reference_image_data = (
self._masks_to_pixels(test_output)
if self.output_type == "segmentation_mask"
else test_output
)
            output_images = [
                wandb.Image(data, grouping=2) for data in output_image_data
            ]
            reference_images = [
                wandb.Image(data) for data in reference_image_data
            ]
return list(chain.from_iterable(zip(output_images, reference_images)))
def _log_weights(self):
metrics = {}
for layer in self.model.layers:
weights = layer.get_weights()
if len(weights) == 1:
_update_if_numeric(
metrics, "parameters/" + layer.name + ".weights", weights[0]
)
elif len(weights) == 2:
_update_if_numeric(
metrics, "parameters/" + layer.name + ".weights", weights[0]
)
_update_if_numeric(
metrics, "parameters/" + layer.name + ".bias", weights[1]
)
return metrics
def _log_gradients(self):
        # Suppress callback warnings while fitting the grad accumulator model
og_level = tf_logger.level
tf_logger.setLevel("ERROR")
self._grad_accumulator_model.fit(
self._training_data_x,
self._training_data_y,
verbose=0,
callbacks=[self._grad_accumulator_callback],
)
tf_logger.setLevel(og_level)
weights = self.model.trainable_weights
grads = self._grad_accumulator_callback.grads
metrics = {}
for weight, grad in zip(weights, grads):
metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
wandb.Histogram(grad)
)
return metrics
def _log_dataframe(self):
x, y_true, y_pred = None, None, None
if self.validation_data:
x, y_true = self.validation_data[0], self.validation_data[1]
y_pred = self.model.predict(x)
elif self.generator:
if not self.validation_steps:
wandb.termwarn(
"when using a generator for validation data with dataframes, "
"you must pass validation_steps. skipping"
)
return None
for _ in range(self.validation_steps):
bx, by_true = next(self.generator)
by_pred = self.model.predict(bx)
if x is None:
x, y_true, y_pred = bx, by_true, by_pred
else:
x, y_true, y_pred = (
np.append(x, bx, axis=0),
np.append(y_true, by_true, axis=0),
np.append(y_pred, by_pred, axis=0),
)
if self.input_type in ("image", "images") and self.output_type == "label":
return wandb.image_categorizer_dataframe(
x=x, y_true=y_true, y_pred=y_pred, labels=self.labels
)
elif (
self.input_type in ("image", "images")
and self.output_type == "segmentation_mask"
):
return wandb.image_segmentation_dataframe(
x=x,
y_true=y_true,
y_pred=y_pred,
labels=self.labels,
class_colors=self.class_colors,
)
else:
wandb.termwarn(
f"unknown dataframe type for input_type={self.input_type} and output_type={self.output_type}"
)
return None
def _save_model(self, epoch):
if wandb.run.disabled:
return
if self.verbose > 0:
wandb.termlog(
f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
f"saving model to {self.filepath}"
)
try:
if self.save_weights_only:
self.model.save_weights(self.filepath, overwrite=True)
else:
self.model.save(self.filepath, overwrite=True)
# Was getting `RuntimeError: Unable to create link` in TF 1.13.1
# also saw `TypeError: can't pickle _thread.RLock objects`
except (ImportError, RuntimeError, TypeError, AttributeError):
logger.exception("Error saving model in the h5py format")
            wandb.termerror(
                "Can't save model in the h5py format. The model will be saved "
                "as a W&B Artifact in the 'tf' format."
            )
def _save_model_as_artifact(self, epoch):
if wandb.run.disabled:
return
# Save the model in the SavedModel format.
# TODO: Replace this manual artifact creation with the `log_model` method
# after `log_model` is released from beta.
self.model.save(self.filepath[:-3], overwrite=True, save_format="tf")
# Log the model as artifact.
name = wandb.util.make_artifact_name_safe(f"model-{wandb.run.name}")
model_artifact = wandb.Artifact(name, type="model")
model_artifact.add_dir(self.filepath[:-3])
wandb.run.log_artifact(model_artifact, aliases=["latest", f"epoch_{epoch}"])
        # Remove the SavedModel from the wandb dir so it isn't also uploaded as a run file (saves space).
shutil.rmtree(self.filepath[:-3])
    def get_flops(self) -> float:
        """Calculate FLOPs (reported in GFLOPs) for a tf.keras.Model or tf.keras.Sequential model in inference mode.

        It uses tf.compat.v1.profiler under the hood.
        """
if not hasattr(self, "model"):
raise wandb.Error("self.model must be set before using this method.")
if not isinstance(
self.model, (tf.keras.models.Sequential, tf.keras.models.Model)
):
raise TypeError(
"Calculating FLOPS is only supported for "
"`tf.keras.Model` and `tf.keras.Sequential` instances."
)
from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2_as_graph,
)
# Compute FLOPs for one sample
batch_size = 1
inputs = [
tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype)
for inp in self.model.inputs
]
        # convert the tf.keras model into a frozen graph so the FLOPs count reflects the operations used at inference
real_model = tf.function(self.model).get_concrete_function(inputs)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)
# Calculate FLOPs with tf.profiler
run_meta = tf.compat.v1.RunMetadata()
opts = (
tf.compat.v1.profiler.ProfileOptionBuilder(
tf.compat.v1.profiler.ProfileOptionBuilder().float_operation()
)
.with_empty_output()
.build()
)
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
)
        # Convert to GFLOPs. The profiler counts a multiply and an add as two
        # separate float ops, so halve the total to report multiply-accumulate
        # FLOPs.
        return (flops.total_float_ops / 1e9) / 2
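# Hedged usage sketch for get_flops (illustrative only; it bypasses __init__
# because constructing a WandbCallback requires an active wandb run):
def _demo_get_flops():
    inp = tf.keras.Input(shape=(4,))
    out = tf.keras.layers.Dense(2)(inp)
    cbk = WandbCallback.__new__(WandbCallback)
    cbk.model = tf.keras.Model(inp, out)
    return cbk.get_flops()  # GFLOPs for a single-sample forward pass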