|
"""keras init.""" |
|
|
|
import logging |
|
import operator |
|
import os |
|
import shutil |
|
import sys |
|
from itertools import chain |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
import tensorflow.keras.backend as K |
|
|
|
import wandb |
|
from wandb.proto.wandb_deprecated import Deprecated |
|
from wandb.sdk.integration_utils.data_logging import ValidationDataLogger |
|
from wandb.sdk.lib.deprecate import deprecate |
|
from wandb.util import add_import_hook |
|
|
|
|
|
def _check_keras_version():
    """Emit a warning when the installed Keras is older than the supported minimum."""
    from keras import __version__ as keras_version
    from packaging.version import parse

    minimum_supported = parse("2.4.0")
    if parse(keras_version) < minimum_supported:
        wandb.termwarn(
            f"Keras version {keras_version} is not fully supported. Required keras >= 2.4.0"
        )
|
|
|
|
|
def _can_compute_flops() -> bool:
    """FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1."""
    from packaging.version import parse

    return parse(tf.__version__) >= parse("2.0.0")
|
|
|
|
|
# Check the Keras version as soon as keras is available: immediately if the
# user already imported it, otherwise via an import hook that fires when
# `import keras` happens later.
if "keras" in sys.modules:
    _check_keras_version()
else:
    add_import_hook("keras", _check_keras_version)


# Module-level logger for this integration.
logger = logging.getLogger(__name__)
|
|
|
|
|
def is_dataset(data):
    """Return True when ``data`` is a tf.data Dataset (V2, or V1 when available)."""
    dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
    if not (dataset_ops and hasattr(dataset_ops, "DatasetV2")):
        # tf.data isn't importable (or is too old) -> cannot be a dataset.
        return False
    candidates = [dataset_ops.DatasetV2]
    if hasattr(dataset_ops, "DatasetV1"):
        candidates.append(dataset_ops.DatasetV1)
    return isinstance(data, tuple(candidates))
|
|
|
|
|
def is_generator_like(data):
    """Return True when ``data`` behaves like a batch generator.

    Accepts keras Sequences, tf.data iterators (across TF versions), and any
    object exposing a ``next``/``__next__`` method.
    """
    generator_types = [tf.keras.utils.Sequence]
    iterator_ops = wandb.util.get_module("tensorflow.python.data.ops.iterator_ops")
    if iterator_ops:
        generator_types.append(iterator_ops.Iterator)
        # The eager iterator class was renamed across TF releases.
        if hasattr(iterator_ops, "EagerIterator"):
            generator_types.append(iterator_ops.EagerIterator)
        elif hasattr(iterator_ops, "IteratorV2"):
            generator_types.append(iterator_ops.IteratorV2)
    return (
        hasattr(data, "next")
        or hasattr(data, "__next__")
        or isinstance(data, tuple(generator_types))
    )
|
|
|
|
|
def patch_tf_keras():
    """Monkey-patch the Keras fit entry points so WandbCallback gets validation data.

    Wraps ``fit_loop``, ``fit_generator`` and the v2 ``fit`` so that, before
    delegating to the original implementation, every ``WandbCallback`` found in
    the callback list has its ``generator``/``validation_data`` attribute set.
    Everything patched is recorded in ``wandb.patched["keras"]``.
    """
    from packaging.version import parse
    from tensorflow.python.eager import context

    # For TF 2.6-2.12 Keras is a separate package whose private engine lives
    # at `keras.engine`; otherwise use the copy bundled inside TensorFlow.
    if parse("2.6.0") <= parse(tf.__version__) < parse("2.13.0"):
        keras_engine = "keras.engine"
        try:
            from keras.engine import training
            from keras.engine import training_arrays_v1 as training_arrays
            from keras.engine import training_generator_v1 as training_generator
        except (ImportError, AttributeError):
            wandb.termerror("Unable to patch Tensorflow/Keras")
            logger.exception("exception while trying to patch_tf_keras")
            return
    else:
        keras_engine = "tensorflow.python.keras.engine"

        from tensorflow.python.keras.engine import training

        try:
            # Newer TF renamed the v1 training loops with a `_v1` suffix.
            from tensorflow.python.keras.engine import (
                training_arrays_v1 as training_arrays,
            )
            from tensorflow.python.keras.engine import (
                training_generator_v1 as training_generator,
            )
        except (ImportError, AttributeError):
            try:
                from tensorflow.python.keras.engine import (
                    training_arrays,
                    training_generator,
                )
            except (ImportError, AttributeError):
                wandb.termerror("Unable to patch Tensorflow/Keras")
                logger.exception("exception while trying to patch_tf_keras")
                return

    # TF 2.1-style Loop-based fit implementation, when present.
    training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")

    # NOTE(review): despite the name, this probes `training_v1` — presumably
    # the presence check for the Model.fit path on later versions; confirm
    # against TF release history.
    training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")

    # `old_v2` is only defined when one of the two probes succeeded; `new_v2`
    # below is only installed in exactly those cases, so it is never called
    # with `old_v2` unbound.
    if training_v2_1:
        old_v2 = training_v2_1.Loop.fit
    elif training_v2_2:
        old_v2 = training.Model.fit

    old_arrays = training_arrays.fit_loop
    old_generator = training_generator.fit_generator

    def set_wandb_attrs(cbk, val_data):
        # Hand validation data to a WandbCallback in whatever form it can
        # consume: generator, eager dataset iterator, tensor-backed generator,
        # or a plain (X, y) tuple.
        if isinstance(cbk, WandbCallback):
            if is_generator_like(val_data):
                cbk.generator = val_data
            elif is_dataset(val_data):
                if context.executing_eagerly():
                    cbk.generator = iter(val_data)
                else:
                    wandb.termwarn(
                        "Found a validation dataset in graph mode, can't patch Keras."
                    )
            elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
                # Graph-mode tensors: evaluate them lazily through the session.
                def gen():
                    while True:
                        yield K.get_session().run(val_data)

                cbk.generator = gen()
            else:
                cbk.validation_data = val_data

    def new_arrays(*args, **kwargs):
        # Wrapper for the array-based v1 fit loop.
        cbks = kwargs.get("callbacks", [])
        val_inputs = kwargs.get("val_inputs")
        val_targets = kwargs.get("val_targets")

        if val_inputs and val_targets:
            for cbk in cbks:
                set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
        return old_arrays(*args, **kwargs)

    def new_generator(*args, **kwargs):
        # Wrapper for the generator-based v1 fit loop.
        cbks = kwargs.get("callbacks", [])
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_generator(*args, **kwargs)

    def new_v2(*args, **kwargs):
        # Wrapper for the v2 fit entry point.
        cbks = kwargs.get("callbacks", [])
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_v2(*args, **kwargs)

    # Install the wrappers, keeping the originals reachable via orig_* names.
    training_arrays.orig_fit_loop = old_arrays
    training_arrays.fit_loop = new_arrays
    training_generator.orig_fit_generator = old_generator
    training_generator.fit_generator = new_generator
    wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
    wandb.patched["keras"].append(
        [f"{keras_engine}.training_generator", "fit_generator"]
    )

    if training_v2_1:
        training_v2_1.Loop.fit = new_v2
        wandb.patched["keras"].append(
            ["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
        )
    elif training_v2_2:
        training.Model.fit = new_v2
        wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])
|
|
|
|
|
def _array_has_dtype(array): |
|
return hasattr(array, "dtype") |
|
|
|
|
|
def _update_if_numeric(metrics, key, values):
    """Store a wandb.Histogram of ``values`` in ``metrics`` under ``key``.

    Skips (with a one-time warning) anything that isn't a numeric array.
    """
    if not _array_has_dtype(values):
        _warn_not_logging(key)
    elif not is_numeric_array(values):
        _warn_not_logging_non_numeric(key)
    else:
        metrics[key] = wandb.Histogram(values)
|
|
|
|
|
def is_numeric_array(array):
    """Return True when the array's dtype is a numeric (sub)type."""
    dtype = array.dtype
    return np.issubdtype(dtype, np.number)
|
|
|
|
|
def _warn_not_logging_non_numeric(name):
    """Warn (once per message) that a non-numeric layer is being skipped."""
    message = f"Non-numeric values found in layer: {name}, not logging this layer"
    wandb.termwarn(message, repeat=False)
|
|
|
|
|
def _warn_not_logging(name):
    """Warn (once per message) that a layer with an unknown datatype is skipped."""
    message = f"Layer {name} has undetermined datatype not logging this layer"
    wandb.termwarn(message, repeat=False)
|
|
|
|
|
# TF logger handle; _log_gradients temporarily raises its level to ERROR.
tf_logger = tf.get_logger()

# Patch the Keras fit loops at import time.
patch_tf_keras()
|
|
|
|
|
|
|
|
|
|
|
def _get_custom_optimizer_parent_class():
    """Return the optimizer base class; TF >= 2.9 moved the old API under `legacy`."""
    from packaging.version import parse

    if parse(tf.__version__) >= parse("2.9.0"):
        return tf.keras.optimizers.legacy.Optimizer
    return tf.keras.optimizers.Optimizer
|
|
|
|
|
# Resolved once at import time; base class for _CustomOptimizer below.
_custom_optimizer_parent_class = _get_custom_optimizer_parent_class()
|
|
|
|
|
class _CustomOptimizer(_custom_optimizer_parent_class):
    """No-op "optimizer" that writes each gradient into its variable.

    Used with _GradAccumulatorCallback: after a train step the model's
    weights hold the raw gradients, which the callback then reads out.
    """

    def __init__(self):
        super().__init__(name="CustomOptimizer")
        # Wrap the apply steps in tf.function so they run as graph ops.
        self._resource_apply_dense = tf.function(self._resource_apply_dense)
        self._resource_apply_sparse = tf.function(self._resource_apply_sparse)

    def _resource_apply_dense(self, grad, var):
        # Overwrite the variable with its gradient instead of updating it.
        var.assign(grad)

    def _resource_apply_sparse(self, grad, var, indices):
        # Sparse updates are intentionally ignored.
        pass

    def get_config(self):
        return super().get_config()
|
|
|
|
|
class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
    """Accumulates gradients during a fit() call when used in conjunction with the CustomOptimizer above."""

    def set_model(self, model):
        super().set_model(model)
        # Snapshot the real weights so they can be restored after each batch.
        self.og_weights = model.get_weights()
        # One accumulator per trainable weight, matching its shape.
        self.grads = [np.zeros(tuple(w.shape)) for w in model.trainable_weights]

    def on_batch_end(self, batch, logs=None):
        # _CustomOptimizer assigned each gradient into its variable, so the
        # trainable weights currently *are* the gradients: accumulate them,
        # then restore the original weights.
        for g, w in zip(self.grads, self.model.trainable_weights):
            g += w.numpy()
        self.model.set_weights(self.og_weights)

    def get_grads(self):
        # Return copies so callers can't mutate the running accumulators.
        return [g.copy() for g in self.grads]
|
|
|
|
|
|
|
|
|
|
|
class WandbCallback(tf.keras.callbacks.Callback): |
|
"""`WandbCallback` automatically integrates keras with wandb. |
|
|
|
Example: |
|
```python |
|
model.fit( |
|
X_train, |
|
y_train, |
|
validation_data=(X_test, y_test), |
|
callbacks=[WandbCallback()], |
|
) |
|
``` |
|
|
|
`WandbCallback` will automatically log history data from any |
|
metrics collected by keras: loss and anything passed into `keras_model.compile()`. |
|
|
|
`WandbCallback` will set summary metrics for the run associated with the "best" training |
|
step, where "best" is defined by the `monitor` and `mode` attributes. This defaults |
|
to the epoch with the minimum `val_loss`. `WandbCallback` will by default save the model |
|
associated with the best `epoch`. |
|
|
|
`WandbCallback` can optionally log gradient and parameter histograms. |
|
|
|
`WandbCallback` can optionally save training and validation data for wandb to visualize. |
|
|
|
Args: |
|
monitor: (str) name of metric to monitor. Defaults to `val_loss`. |
|
mode: (str) one of {`auto`, `min`, `max`}. |
|
`min` - save model when monitor is minimized |
|
`max` - save model when monitor is maximized |
|
`auto` - try to guess when to save the model (default). |
|
save_model: |
|
True - save a model when monitor beats all previous epochs |
|
False - don't save models |
|
save_graph: (boolean) if True save model graph to wandb (default to True). |
|
save_weights_only: (boolean) if True, then only the model's weights will be |
|
saved (`model.save_weights(filepath)`), else the full model |
|
is saved (`model.save(filepath)`). |
|
log_weights: (boolean) if True save histograms of the model's layer's weights. |
|
log_gradients: (boolean) if True log histograms of the training gradients |
|
training_data: (tuple) Same format `(X,y)` as passed to `model.fit`. This is needed |
|
for calculating gradients - this is mandatory if `log_gradients` is `True`. |
|
validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`. A set of data |
|
for wandb to visualize. If this is set, every epoch, wandb will |
|
make a small number of predictions and save the results for later visualization. In case |
|
you are working with image data, please also set `input_type` and `output_type` in order |
|
to log correctly. |
|
generator: (generator) a generator that returns validation data for wandb to visualize. This |
|
generator should return tuples `(X,y)`. Either `validate_data` or generator should |
|
be set for wandb to visualize specific data examples. In case you are working with image data, |
|
please also set `input_type` and `output_type` in order to log correctly. |
|
validation_steps: (int) if `validation_data` is a generator, how many |
|
steps to run the generator for the full validation set. |
|
labels: (list) If you are visualizing your data with wandb this list of labels |
|
will convert numeric output to understandable string if you are building a |
|
multiclass classifier. If you are making a binary classifier you can pass in |
|
a list of two labels ["label for false", "label for true"]. If `validate_data` |
|
and generator are both false, this won't do anything. |
|
predictions: (int) the number of predictions to make for visualization each epoch, max |
|
is 100. |
|
input_type: (string) type of the model input to help visualization. can be one of: |
|
(`image`, `images`, `segmentation_mask`, `auto`). |
|
output_type: (string) type of the model output to help visualization. can be one of: |
|
(`image`, `images`, `segmentation_mask`, `label`). |
|
log_evaluation: (boolean) if True, save a Table containing validation data and the |
|
model's predictions at each epoch. See `validation_indexes`, |
|
`validation_row_processor`, and `output_row_processor` for additional details. |
|
class_colors: ([float, float, float]) if the input or output is a segmentation mask, |
|
an array containing an rgb tuple (range 0-1) for each class. |
|
log_batch_frequency: (integer) if None, callback will log every epoch. |
|
If set to integer, callback will log training metrics every `log_batch_frequency` |
|
batches. |
|
log_best_prefix: (string) if None, no extra summary metrics will be saved. |
|
If set to a string, the monitored metric and epoch will be prepended with this value |
|
and stored as summary metrics. |
|
validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate |
|
with each validation example. If log_evaluation is True and `validation_indexes` is provided, |
|
then a Table of validation data will not be created and instead each prediction will |
|
be associated with the row represented by the `TableLinkMixin`. The most common way to obtain |
|
such keys are is use `Table.get_index()` which will return a list of row keys. |
|
validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data. |
|
The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input, |
|
then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the |
|
input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else, |
|
it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray, |
|
but you wish to visualize the data as an Image, then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}` |
|
as the processor. Ignored if log_evaluation is False or `validation_indexes` are present. |
|
output_row_processor: (Callable) same as `validation_row_processor`, but applied to the model's output. `row["output"]` will contain |
|
the results of the model output. |
|
infer_missing_processors: (bool) Determines if `validation_row_processor` and `output_row_processor` |
|
should be inferred if missing. Defaults to True. If `labels` are provided, we will attempt to infer classification-type |
|
processors where appropriate. |
|
log_evaluation_frequency: (int) Determines the frequency which evaluation results will be logged. Default 0 (only at the end of training). |
|
Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False. |
|
compute_flops: (bool) Compute the FLOPs of your Keras Sequential or Functional model in GigaFLOPs unit. |
|
""" |
|
|
|
    def __init__(
        self,
        monitor="val_loss",
        verbose=0,
        mode="auto",
        save_weights_only=False,
        log_weights=False,
        log_gradients=False,
        save_model=True,
        training_data=None,
        validation_data=None,
        labels=None,
        predictions=36,
        generator=None,
        input_type=None,
        output_type=None,
        log_evaluation=False,
        validation_steps=None,
        class_colors=None,
        log_batch_frequency=None,
        log_best_prefix="best_",
        save_graph=True,
        validation_indexes=None,
        validation_row_processor=None,
        prediction_row_processor=None,
        infer_missing_processors=True,
        log_evaluation_frequency=0,
        compute_flops=False,
        **kwargs,
    ):
        """Initialize the callback; requires an active wandb run.

        Raises:
            wandb.Error: if `wandb.init()` has not been called.
            Exception: if `log_gradients` is requested on TF < 2.
            ValueError: if `log_gradients` is requested without `training_data`,
                or `training_data` is a sequence whose length is not 2.
        """
        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")

        # The whole callback is deprecated in favor of the newer, focused callbacks.
        deprecate(
            field_name=Deprecated.keras_callback,
            warning_message=(
                "WandbCallback is deprecated and will be removed in a future release. "
                "Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
                "callbacks instead. "
                "See https://docs.wandb.ai/guides/integrations/keras for more information."
            ),
        )

        # Record feature usage for telemetry.
        with wandb.wandb_lib.telemetry.context(run=wandb.run) as tel:
            tel.feature.keras = True
        self.validation_data = None

        # A generator passed as validation_data is treated as `generator`.
        if validation_data is not None:
            if is_generator_like(validation_data):
                generator = validation_data
            else:
                self.validation_data = validation_data
        if labels is None:
            labels = []
        self.labels = labels
        # Visualization is capped at 100 predictions per epoch.
        self.predictions = min(predictions, 100)

        self.monitor = monitor
        self.verbose = verbose
        self.save_weights_only = save_weights_only
        self.save_graph = save_graph

        # Live-sync the best-model file from the run directory.
        wandb.save("model-best.h5")
        self.filepath = os.path.join(wandb.run.dir, "model-best.h5")
        self.save_model = save_model
        if save_model:
            deprecate(
                field_name=Deprecated.keras_callback__save_model,
                warning_message=(
                    "The save_model argument by default saves the model in the HDF5 format that cannot save "
                    "custom objects like subclassed models and custom layers. This behavior will be deprecated "
                    "in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
                    "as W&B files and the SavedModel as W&B Artifacts."
                ),
            )

        self.save_model_as_artifact = True
        self.log_weights = log_weights
        self.log_gradients = log_gradients
        self.training_data = training_data
        self.generator = generator
        # The model graph is logged once, on the first batch.
        self._graph_rendered = False

        # Legacy alias: the deprecated data_type kwarg maps onto input_type.
        data_type = kwargs.get("data_type", None)
        if data_type is not None:
            deprecate(
                field_name=Deprecated.keras_callback__data_type,
                warning_message=(
                    "The data_type argument of wandb.keras.WandbCallback is deprecated "
                    "and will be removed in a future release. Please use input_type instead.\n"
                    "Setting input_type = data_type."
                ),
            )
            input_type = data_type
        self.input_type = input_type
        self.output_type = output_type
        self.log_evaluation = log_evaluation
        self.validation_steps = validation_steps
        self.class_colors = np.array(class_colors) if class_colors is not None else None
        self.log_batch_frequency = log_batch_frequency
        self.log_best_prefix = log_best_prefix
        self.compute_flops = compute_flops

        self._prediction_batch_size = None

        # Gradient logging needs TF 2.x and explicit training data.
        if self.log_gradients:
            if int(tf.__version__.split(".")[0]) < 2:
                raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
            if self.training_data is None:
                raise ValueError(
                    "training_data argument is required for gradient logging."
                )
            if isinstance(self.training_data, (list, tuple)):
                if len(self.training_data) != 2:
                    raise ValueError("training data must be a tuple of length two")
                self._training_data_x, self._training_data_y = self.training_data
            else:
                # Single object: treat as inputs only, no targets.
                self._training_data_x = (
                    self.training_data
                )
                self._training_data_y = None

        # Resolve how "improved" is decided for the monitored metric.
        if mode not in ["auto", "min", "max"]:
            wandb.termwarn(
                f"WandbCallback mode {mode} is unknown, fallback to auto mode."
            )
            mode = "auto"

        if mode == "min":
            self.monitor_op = operator.lt
            self.best = float("inf")
        elif mode == "max":
            self.monitor_op = operator.gt
            self.best = float("-inf")
        else:
            # auto: accuracy-like metrics are maximized, everything else minimized.
            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.monitor_op = operator.gt
                self.best = float("-inf")
            else:
                self.monitor_op = operator.lt
                self.best = float("inf")

        # Resume best-metric tracking across restarts from the run summary.
        previous_best = wandb.run.summary.get(f"{self.log_best_prefix}{self.monitor}")
        if previous_best is not None:
            self.best = previous_best

        self._validation_data_logger = None
        self._validation_indexes = validation_indexes
        self._validation_row_processor = validation_row_processor
        self._prediction_row_processor = prediction_row_processor
        self._infer_missing_processors = infer_missing_processors
        self._log_evaluation_frequency = log_evaluation_frequency
        self._model_trained_since_last_eval = False
|
|
|
def _build_grad_accumulator_model(self): |
|
inputs = self.model.inputs |
|
outputs = self.model(inputs) |
|
grad_acc_model = tf.keras.models.Model(inputs, outputs) |
|
grad_acc_model.compile(loss=self.model.loss, optimizer=_CustomOptimizer()) |
|
|
|
|
|
grad_acc_model._wandb_internal_model = True |
|
|
|
self._grad_accumulator_model = grad_acc_model |
|
self._grad_accumulator_callback = _GradAccumulatorCallback() |
|
|
|
def _implements_train_batch_hooks(self): |
|
return self.log_batch_frequency is not None |
|
|
|
def _implements_test_batch_hooks(self): |
|
return self.log_batch_frequency is not None |
|
|
|
def _implements_predict_batch_hooks(self): |
|
return self.log_batch_frequency is not None |
|
|
|
    def set_params(self, params):
        # Keras hands over the fit() parameters (epochs, steps, etc.).
        self.params = params
|
|
|
def set_model(self, model): |
|
super().set_model(model) |
|
if self.input_type == "auto" and len(model.inputs) == 1: |
|
self.input_type = wandb.util.guess_data_type( |
|
model.inputs[0].shape, risky=True |
|
) |
|
if self.input_type and self.output_type is None and len(model.outputs) == 1: |
|
self.output_type = wandb.util.guess_data_type(model.outputs[0].shape) |
|
if self.log_gradients: |
|
self._build_grad_accumulator_model() |
|
|
|
def _attempt_evaluation_log(self, commit=True): |
|
if self.log_evaluation and self._validation_data_logger: |
|
try: |
|
if not self.model: |
|
wandb.termwarn("WandbCallback unable to read model from trainer") |
|
else: |
|
self._validation_data_logger.log_predictions( |
|
predictions=self._validation_data_logger.make_predictions( |
|
self.model.predict |
|
), |
|
commit=commit, |
|
) |
|
self._model_trained_since_last_eval = False |
|
except Exception as e: |
|
wandb.termwarn("Error during prediction logging for epoch: " + str(e)) |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
if logs is None: |
|
logs = {} |
|
if self.log_weights: |
|
wandb.log(self._log_weights(), commit=False) |
|
|
|
if self.log_gradients: |
|
wandb.log(self._log_gradients(), commit=False) |
|
|
|
if self.input_type in ( |
|
"image", |
|
"images", |
|
"segmentation_mask", |
|
) or self.output_type in ("image", "images", "segmentation_mask"): |
|
if self.generator: |
|
self.validation_data = next(self.generator) |
|
if self.validation_data is None: |
|
wandb.termwarn( |
|
"No validation_data set, pass a generator to the callback." |
|
) |
|
elif self.validation_data and len(self.validation_data) > 0: |
|
wandb.log( |
|
{"examples": self._log_images(num_images=self.predictions)}, |
|
commit=False, |
|
) |
|
|
|
if ( |
|
self._log_evaluation_frequency > 0 |
|
and epoch % self._log_evaluation_frequency == 0 |
|
): |
|
self._attempt_evaluation_log(commit=False) |
|
|
|
wandb.log({"epoch": epoch}, commit=False) |
|
wandb.log(logs, commit=True) |
|
|
|
self.current = logs.get(self.monitor) |
|
if self.current and self.monitor_op(self.current, self.best): |
|
if self.log_best_prefix: |
|
wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = ( |
|
self.current |
|
) |
|
wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch |
|
if self.verbose and not self.save_model: |
|
wandb.termlog( |
|
f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}" |
|
) |
|
if self.save_model: |
|
self._save_model(epoch) |
|
|
|
if self.save_model and self.save_model_as_artifact: |
|
self._save_model_as_artifact(epoch) |
|
|
|
self.best = self.current |
|
|
|
|
|
    def on_batch_begin(self, batch, logs=None):
        # Intentionally a no-op; per-batch work happens in on_batch_end.
        pass
|
|
|
|
|
def on_batch_end(self, batch, logs=None): |
|
if self.save_graph and not self._graph_rendered: |
|
|
|
wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model) |
|
self._graph_rendered = True |
|
|
|
if self.log_batch_frequency and batch % self.log_batch_frequency == 0: |
|
wandb.log(logs, commit=True) |
|
|
|
    def on_train_batch_begin(self, batch, logs=None):
        # Any training step means the last logged evaluation may be stale.
        self._model_trained_since_last_eval = True
|
|
|
def on_train_batch_end(self, batch, logs=None): |
|
if self.save_graph and not self._graph_rendered: |
|
|
|
wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model) |
|
self._graph_rendered = True |
|
|
|
if self.log_batch_frequency and batch % self.log_batch_frequency == 0: |
|
wandb.log(logs, commit=True) |
|
|
|
    def on_test_begin(self, logs=None):
        # Intentionally a no-op.
        pass
|
|
|
    def on_test_end(self, logs=None):
        # Intentionally a no-op.
        pass
|
|
|
    def on_test_batch_begin(self, batch, logs=None):
        # Intentionally a no-op.
        pass
|
|
|
    def on_test_batch_end(self, batch, logs=None):
        # Intentionally a no-op.
        pass
|
|
|
    def on_train_begin(self, logs=None):
        """Set up validation-data logging and, if requested, compute model FLOPs."""
        if self.log_evaluation:
            try:
                validation_data = None
                if self.validation_data:
                    validation_data = self.validation_data
                elif self.generator:
                    if not self.validation_steps:
                        wandb.termwarn(
                            "WandbCallback is unable to log validation data. "
                            "When using a generator for validation_data, you must pass validation_steps"
                        )
                    else:
                        # Materialize validation_steps batches into flat arrays.
                        x = None
                        y_true = None
                        for _ in range(self.validation_steps):
                            bx, by_true = next(self.generator)
                            if x is None:
                                x, y_true = bx, by_true
                            else:
                                x, y_true = (
                                    np.append(x, bx, axis=0),
                                    np.append(y_true, by_true, axis=0),
                                )
                        validation_data = (x, y_true)
                else:
                    wandb.termwarn(
                        "WandbCallback is unable to read validation_data from trainer "
                        "and therefore cannot log validation data. Ensure Keras is properly "
                        "patched by calling `from wandb.keras import WandbCallback` at the top of your script."
                    )
                if validation_data:
                    self._validation_data_logger = ValidationDataLogger(
                        inputs=validation_data[0],
                        targets=validation_data[1],
                        indexes=self._validation_indexes,
                        validation_row_processor=self._validation_row_processor,
                        prediction_row_processor=self._prediction_row_processor,
                        class_labels=self.labels,
                        infer_missing_processors=self._infer_missing_processors,
                    )
            except Exception as e:
                # Best-effort: never abort training because logging setup failed.
                wandb.termwarn(
                    "Error initializing ValidationDataLogger in WandbCallback. "
                    f"Skipping logging validation data. Error: {str(e)}"
                )

        if self.compute_flops and _can_compute_flops():
            try:
                # NOTE(review): get_flops is defined outside this chunk — confirm.
                wandb.summary["GFLOPs"] = self.get_flops()
            except Exception:
                logger.exception("Error computing FLOPs")
                wandb.termwarn("Unable to compute FLOPs for this model.")
|
|
|
def on_train_end(self, logs=None): |
|
if self._model_trained_since_last_eval: |
|
self._attempt_evaluation_log() |
|
|
|
    def on_predict_begin(self, logs=None):
        # Intentionally a no-op.
        pass
|
|
|
    def on_predict_end(self, logs=None):
        # Intentionally a no-op.
        pass
|
|
|
    def on_predict_batch_begin(self, batch, logs=None):
        # Intentionally a no-op.
        pass
|
|
|
    def on_predict_batch_end(self, batch, logs=None):
        # Intentionally a no-op.
        pass
|
|
|
def _logits_to_captions(self, logits): |
|
if logits[0].shape[-1] == 1: |
|
|
|
|
|
if len(self.labels) == 2: |
|
|
|
captions = [ |
|
self.labels[1] if logits[0] > 0.5 else self.labels[0] |
|
for logit in logits |
|
] |
|
else: |
|
if len(self.labels) != 0: |
|
wandb.termwarn( |
|
"keras model is producing a single output, " |
|
'so labels should be a length two array: ["False label", "True label"].' |
|
) |
|
captions = [logit[0] for logit in logits] |
|
else: |
|
|
|
|
|
labels = np.argmax(np.stack(logits), axis=1) |
|
|
|
if len(self.labels) > 0: |
|
|
|
captions = [] |
|
for label in labels: |
|
try: |
|
captions.append(self.labels[label]) |
|
except IndexError: |
|
captions.append(label) |
|
else: |
|
captions = labels |
|
return captions |
|
|
|
def _masks_to_pixels(self, masks): |
|
|
|
if len(masks[0].shape) == 2 or masks[0].shape[-1] == 1: |
|
return masks |
|
class_colors = ( |
|
self.class_colors |
|
if self.class_colors is not None |
|
else np.array(wandb.util.class_colors(masks[0].shape[2])) |
|
) |
|
imgs = class_colors[np.argmax(masks, axis=-1)] |
|
return imgs |
|
|
|
    def _log_images(self, num_images=36):
        """Build a list of wandb.Image examples from validation predictions.

        Samples up to ``num_images`` validation examples, runs the model on
        them, and returns images interleaved per example (input / output /
        reference) according to ``input_type``/``output_type``.
        """
        validation_X = self.validation_data[0]
        validation_y = self.validation_data[1]

        validation_length = len(validation_X)

        if validation_length > num_images:
            # Sample a random subset without replacement.
            indices = np.random.choice(validation_length, num_images, replace=False)
        else:
            indices = range(validation_length)

        test_data = []
        test_output = []
        for i in indices:
            test_example = validation_X[i]
            test_data.append(test_example)
            test_output.append(validation_y[i])

        if self.model.stateful:
            # Stateful models require batch_size=1 and a state reset afterwards.
            predictions = self.model.predict(np.stack(test_data), batch_size=1)
            self.model.reset_states()
        else:
            predictions = self.model.predict(
                np.stack(test_data), batch_size=self._prediction_batch_size
            )
            # Mismatched prediction count: retry one example at a time.
            if len(predictions) != len(test_data):
                self._prediction_batch_size = 1
                predictions = self.model.predict(
                    np.stack(test_data), batch_size=self._prediction_batch_size
                )

        if self.input_type == "label":
            if self.output_type in ("image", "images", "segmentation_mask"):
                # Label in, image out: caption outputs with the input labels.
                captions = self._logits_to_captions(test_data)
                output_image_data = (
                    self._masks_to_pixels(predictions)
                    if self.output_type == "segmentation_mask"
                    else predictions
                )
                reference_image_data = (
                    self._masks_to_pixels(test_output)
                    if self.output_type == "segmentation_mask"
                    else test_output
                )
                output_images = [
                    wandb.Image(data, caption=captions[i], grouping=2)
                    for i, data in enumerate(output_image_data)
                ]
                reference_images = [
                    wandb.Image(data, caption=captions[i])
                    for i, data in enumerate(reference_image_data)
                ]
                return list(chain.from_iterable(zip(output_images, reference_images)))
        elif self.input_type in ("image", "images", "segmentation_mask"):
            input_image_data = (
                self._masks_to_pixels(test_data)
                if self.input_type == "segmentation_mask"
                else test_data
            )
            if self.output_type == "label":
                # Image in, label out: caption inputs with predicted labels.
                captions = self._logits_to_captions(predictions)
                return [
                    wandb.Image(data, caption=captions[i])
                    for i, data in enumerate(test_data)
                ]
            elif self.output_type in ("image", "images", "segmentation_mask"):
                # Image in, image out: interleave input/output/reference triples.
                output_image_data = (
                    self._masks_to_pixels(predictions)
                    if self.output_type == "segmentation_mask"
                    else predictions
                )
                reference_image_data = (
                    self._masks_to_pixels(test_output)
                    if self.output_type == "segmentation_mask"
                    else test_output
                )
                input_images = [
                    wandb.Image(data, grouping=3)
                    for i, data in enumerate(input_image_data)
                ]
                output_images = [
                    wandb.Image(data) for i, data in enumerate(output_image_data)
                ]
                reference_images = [
                    wandb.Image(data) for i, data in enumerate(reference_image_data)
                ]
                return list(
                    chain.from_iterable(
                        zip(input_images, output_images, reference_images)
                    )
                )
            else:
                # Image in, unknown out: just log the inputs.
                return [wandb.Image(img) for img in test_data]
        elif self.output_type in ("image", "images", "segmentation_mask"):
            # Unknown in, image out: interleave output/reference pairs.
            output_image_data = (
                self._masks_to_pixels(predictions)
                if self.output_type == "segmentation_mask"
                else predictions
            )
            reference_image_data = (
                self._masks_to_pixels(test_output)
                if self.output_type == "segmentation_mask"
                else test_output
            )
            output_images = [
                wandb.Image(data, grouping=2)
                for i, data in enumerate(output_image_data)
            ]
            reference_images = [
                wandb.Image(data) for i, data in enumerate(reference_image_data)
            ]
            return list(chain.from_iterable(zip(output_images, reference_images)))
|
|
|
def _log_weights(self): |
|
metrics = {} |
|
for layer in self.model.layers: |
|
weights = layer.get_weights() |
|
if len(weights) == 1: |
|
_update_if_numeric( |
|
metrics, "parameters/" + layer.name + ".weights", weights[0] |
|
) |
|
elif len(weights) == 2: |
|
_update_if_numeric( |
|
metrics, "parameters/" + layer.name + ".weights", weights[0] |
|
) |
|
_update_if_numeric( |
|
metrics, "parameters/" + layer.name + ".bias", weights[1] |
|
) |
|
return metrics |
|
|
|
    def _log_gradients(self):
        """Compute gradient histograms by running one pass of the shadow model.

        _CustomOptimizer writes each gradient into its variable and
        _GradAccumulatorCallback accumulates/restores them per batch, so after
        the fit the callback holds the summed gradients over the training data.
        """
        # Silence TF chatter during the pseudo-fit, restoring the level after.
        og_level = tf_logger.level
        tf_logger.setLevel("ERROR")

        self._grad_accumulator_model.fit(
            self._training_data_x,
            self._training_data_y,
            verbose=0,
            callbacks=[self._grad_accumulator_callback],
        )
        tf_logger.setLevel(og_level)
        weights = self.model.trainable_weights
        grads = self._grad_accumulator_callback.grads
        metrics = {}
        for weight, grad in zip(weights, grads):
            # Strip the ":0" suffix from the variable name for the metric key.
            metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
                wandb.Histogram(grad)
            )
        return metrics
|
|
|
def _log_dataframe(self): |
|
x, y_true, y_pred = None, None, None |
|
|
|
if self.validation_data: |
|
x, y_true = self.validation_data[0], self.validation_data[1] |
|
y_pred = self.model.predict(x) |
|
elif self.generator: |
|
if not self.validation_steps: |
|
wandb.termwarn( |
|
"when using a generator for validation data with dataframes, " |
|
"you must pass validation_steps. skipping" |
|
) |
|
return None |
|
|
|
for _ in range(self.validation_steps): |
|
bx, by_true = next(self.generator) |
|
by_pred = self.model.predict(bx) |
|
if x is None: |
|
x, y_true, y_pred = bx, by_true, by_pred |
|
else: |
|
x, y_true, y_pred = ( |
|
np.append(x, bx, axis=0), |
|
np.append(y_true, by_true, axis=0), |
|
np.append(y_pred, by_pred, axis=0), |
|
) |
|
|
|
if self.input_type in ("image", "images") and self.output_type == "label": |
|
return wandb.image_categorizer_dataframe( |
|
x=x, y_true=y_true, y_pred=y_pred, labels=self.labels |
|
) |
|
elif ( |
|
self.input_type in ("image", "images") |
|
and self.output_type == "segmentation_mask" |
|
): |
|
return wandb.image_segmentation_dataframe( |
|
x=x, |
|
y_true=y_true, |
|
y_pred=y_pred, |
|
labels=self.labels, |
|
class_colors=self.class_colors, |
|
) |
|
else: |
|
wandb.termwarn( |
|
f"unknown dataframe type for input_type={self.input_type} and output_type={self.output_type}" |
|
) |
|
return None |
|
|
|
def _save_model(self, epoch): |
|
if wandb.run.disabled: |
|
return |
|
if self.verbose > 0: |
|
wandb.termlog( |
|
f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, " |
|
f"saving model to {self.filepath}" |
|
) |
|
|
|
try: |
|
if self.save_weights_only: |
|
self.model.save_weights(self.filepath, overwrite=True) |
|
else: |
|
self.model.save(self.filepath, overwrite=True) |
|
|
|
|
|
except (ImportError, RuntimeError, TypeError, AttributeError): |
|
logger.exception("Error saving model in the h5py format") |
|
wandb.termerror( |
|
"Can't save model in the h5py format. The model will be saved as " |
|
"as an W&B Artifact in the 'tf' format." |
|
) |
|
|
|
def _save_model_as_artifact(self, epoch): |
|
if wandb.run.disabled: |
|
return |
|
|
|
|
|
|
|
|
|
self.model.save(self.filepath[:-3], overwrite=True, save_format="tf") |
|
|
|
|
|
name = wandb.util.make_artifact_name_safe(f"model-{wandb.run.name}") |
|
model_artifact = wandb.Artifact(name, type="model") |
|
model_artifact.add_dir(self.filepath[:-3]) |
|
wandb.run.log_artifact(model_artifact, aliases=["latest", f"epoch_{epoch}"]) |
|
|
|
|
|
shutil.rmtree(self.filepath[:-3]) |
|
|
|
def get_flops(self) -> float: |
|
"""Calculate FLOPS [GFLOPs] for a tf.keras.Model or tf.keras.Sequential model in inference mode. |
|
|
|
It uses tf.compat.v1.profiler under the hood. |
|
""" |
|
if not hasattr(self, "model"): |
|
raise wandb.Error("self.model must be set before using this method.") |
|
|
|
if not isinstance( |
|
self.model, (tf.keras.models.Sequential, tf.keras.models.Model) |
|
): |
|
raise TypeError( |
|
"Calculating FLOPS is only supported for " |
|
"`tf.keras.Model` and `tf.keras.Sequential` instances." |
|
) |
|
|
|
from tensorflow.python.framework.convert_to_constants import ( |
|
convert_variables_to_constants_v2_as_graph, |
|
) |
|
|
|
|
|
batch_size = 1 |
|
inputs = [ |
|
tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype) |
|
for inp in self.model.inputs |
|
] |
|
|
|
|
|
real_model = tf.function(self.model).get_concrete_function(inputs) |
|
frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model) |
|
|
|
|
|
run_meta = tf.compat.v1.RunMetadata() |
|
opts = ( |
|
tf.compat.v1.profiler.ProfileOptionBuilder( |
|
tf.compat.v1.profiler.ProfileOptionBuilder().float_operation() |
|
) |
|
.with_empty_output() |
|
.build() |
|
) |
|
|
|
flops = tf.compat.v1.profiler.profile( |
|
graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts |
|
) |
|
|
|
|
|
return (flops.total_float_ops / 1e9) / 2 |
|
|