File size: 4,803 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
"""watch."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Sequence
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
import wandb
from .lib import telemetry
if TYPE_CHECKING:
import torch # type: ignore [import-not-found]
from wandb.sdk.wandb_run import Run
# Shared logger; registered under the fixed "wandb" name rather than this
# module's own name.
logger = logging.getLogger("wandb")
# Module-wide counter of models passed through `_watch`. `_watch` uses it as
# the starting graph index when the caller does not supply `idx`, and bumps it
# once per model it processes.
_global_watch_idx = 0
def _watch(
    run: Run,
    models: torch.nn.Module | Sequence[torch.nn.Module],
    criterion: torch.F | None = None,
    log: Literal["gradients", "parameters", "all"] | None = "gradients",
    log_freq: int = 1000,
    idx: int | None = None,
    log_graph: bool = False,
):
    """Attach wandb logging hooks to one or more PyTorch models.

    Depending on `log`, parameter hooks, gradient hooks, or both are registered
    on every model through `run._torch`. When `log_graph` is set, each model is
    additionally hooked so that its computational graph gets recorded.

    Args:
        run (wandb.sdk.wandb_run.Run): The run whose `_torch` helper installs
            the hooks.
        models (Union[torch.nn.Module, Sequence[torch.nn.Module]]):
            A single model, or a tuple/list of models, to monitor.
        criterion (Optional[torch.F]):
            The loss function being optimized. Only forwarded to the graph
            hook when `log_graph` is True.
        log (Optional[Literal["gradients", "parameters", "all"]]):
            What to log: "gradients", "parameters", or "all". None installs
            neither kind of hook. (default="gradients")
        log_freq (int):
            Value forwarded to the parameter/gradient hooks as `log_freq`
            (logging frequency in batches). (default=1000)
        idx (Optional[int]):
            Index assigned to the first model. When None, the module-level
            watch counter is used. (default=None)
        log_graph (bool):
            Whether to hook each model for graph logging. (default=False)

    Returns:
        list: The objects returned by `run._torch.hook_torch`, one per model,
        in the order the models were given. Empty when `log_graph` is False.

    Raises:
        ValueError: If `log` is not "gradients", "parameters", "all" or None.
        TypeError: If any of the models is not a `torch.nn.Module`.
    """
    global _global_watch_idx

    # Flag in telemetry that the watch feature was used.
    with telemetry.context() as tel:
        tel.feature.watch = True

    logger.info("Watching")

    if log not in {"gradients", "parameters", "all", None}:
        raise ValueError("log must be one of 'gradients', 'parameters', 'all', or None")

    want_gradients = log in {"gradients", "all"}
    want_parameters = log in {"parameters", "all"}

    # Normalise to a sequence so a lone model takes the same path as a list.
    models = models if isinstance(models, (tuple, list)) else (models,)

    # torch is resolved lazily at call time rather than imported at module
    # level; presumably get_module fails with the `required` message when
    # torch is not installed (confirm against wandb.util).
    torch = wandb.util.get_module(
        "torch", required="wandb.watch only works with pytorch, couldn't import torch."
    )

    # Validate every model up front so nothing is hooked if any entry is bad.
    for candidate in models:
        if isinstance(candidate, torch.nn.Module):
            continue
        raise TypeError(
            f"Expected a pytorch model (torch.nn.Module). Received {type(candidate)}"
        )

    first_idx = _global_watch_idx if idx is None else idx

    graphs = []
    for offset, model in enumerate(models):
        global_idx = first_idx + offset
        _global_watch_idx += 1

        # Only models after the very first one get a prefix.
        # TODO: this makes ugly chart names like gradients/graph_1conv1d.bias
        prefix = f"graph_{global_idx}" if global_idx > 0 else ""

        if want_parameters:
            run._torch.add_log_parameters_hook(
                model,
                prefix=prefix,
                log_freq=log_freq,
            )

        if want_gradients:
            run._torch.add_log_gradients_hook(
                model,
                prefix=prefix,
                log_freq=log_freq,
            )

        if log_graph:
            # NOTE: the graph is set in run.summary by hook_torch on the backward pass
            graphs.append(
                run._torch.hook_torch(model, criterion, graph_idx=global_idx)
            )

    return graphs
def _unwatch(
    run: Run, models: torch.nn.Module | Sequence[torch.nn.Module] | None = None
) -> None:
    """Remove the wandb hooks previously installed on PyTorch models.

    With no models (None, or an empty tuple/list) every hook known to
    `run._torch` is removed. Otherwise only the hooks recorded on each given
    model (in its `_wandb_hook_names` attribute) are removed, and that
    attribute is deleted afterwards. A model without the attribute triggers a
    warning and is skipped.

    Args:
        run (wandb.sdk.wandb_run.Run):
            The run whose `_torch` helper owns the hooks.
        models (torch.nn.Module | Sequence[torch.nn.Module] | None):
            Optional model, or tuple/list of models, that have had watch
            called on them.
    """
    # Decide "unhook everything" explicitly rather than via `if models:`.
    # A single module that defines `__len__` (e.g. an empty container module)
    # is falsy, and a truthiness test would wrongly unhook *all* models when
    # the caller asked to unwatch just that one.
    if models is None or (isinstance(models, (tuple, list)) and not models):
        run._torch.unhook_all()
        return

    if not isinstance(models, (tuple, list)):
        models = (models,)

    for model in models:
        if not hasattr(model, "_wandb_hook_names"):
            wandb.termwarn(f"{model} model has not been watched")
            continue

        for name in model._wandb_hook_names:
            run._torch.unhook(name)
        # Drop the bookkeeping attribute so a repeated unwatch warns instead
        # of trying to remove the same hooks again.
        delattr(model, "_wandb_hook_names")
        # TODO: we should also remove recursively model._wandb_watch_called
|