File size: 4,797 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
"""W&B callback for sb3.
Really simple callback to get W&B logging for each training run
Example usage:
```python
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback
config = {
"policy_type": "MlpPolicy",
"total_timesteps": 25000,
"env_name": "CartPole-v1",
}
run = wandb.init(
project="sb3",
config=config,
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
monitor_gym=True, # auto-upload the videos of agents playing the game
save_code=True, # optional
)
def make_env():
env = gym.make(config["env_name"])
env = Monitor(env) # record stats such as returns
return env
env = DummyVecEnv([make_env])
env = VecVideoRecorder(
env, "videos", record_video_trigger=lambda x: x % 2000 == 0, video_length=200
)
model = PPO(config["policy_type"], env, verbose=1, tensorboard_log="runs")
model.learn(
total_timesteps=config["total_timesteps"],
callback=WandbCallback(
model_save_path=f"models/{run.id}",
gradient_save_freq=100,
log="all",
),
)
```
"""
import logging
import os
from typing import Literal, Optional
from stable_baselines3.common.callbacks import BaseCallback # type: ignore
import wandb
from wandb.sdk.lib import telemetry as wb_telemetry
logger = logging.getLogger(__name__)
class WandbCallback(BaseCallback):
    """Callback for logging experiments to Weights and Biases.

    Log SB3 experiments to Weights and Biases
        - Added model tracking and uploading
        - Added complete hyperparameters recording
        - Added gradient logging
        - Note that `wandb.init(...)` must be called before the WandbCallback can be used.

    Args:
        verbose: The verbosity of sb3 output
        model_save_path: Path to the folder where the model will be saved,
            The default value is `None` so the model is not logged
        model_save_freq: Frequency to save the model
        gradient_save_freq: Frequency to log gradient. The default value is 0
            so the gradients are not logged
        log: What to log. One of "gradients", "parameters", or "all".
    """

    def __init__(
        self,
        verbose: int = 0,
        model_save_path: Optional[str] = None,
        model_save_freq: int = 0,
        gradient_save_freq: int = 0,
        log: Optional[Literal["gradients", "parameters", "all"]] = "all",
    ) -> None:
        super().__init__(verbose)
        # A live run is required: all logging below goes through `wandb.run`.
        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")
        with wb_telemetry.context() as tel:
            tel.feature.sb3 = True
        self.model_save_freq = model_save_freq
        self.model_save_path = model_save_path
        self.gradient_save_freq = gradient_save_freq
        if log not in ["gradients", "parameters", "all", None]:
            wandb.termwarn(
                "`log` must be one of `None`, 'gradients', 'parameters', or 'all', "
                "falling back to 'all'"
            )
            log = "all"
        self.log = log
        # Create folder if needed
        if self.model_save_path is not None:
            os.makedirs(self.model_save_path, exist_ok=True)
            self.path = os.path.join(self.model_save_path, "model.zip")
        elif self.model_save_freq > 0:
            # Raise instead of `assert`: asserts are stripped under `python -O`,
            # which would defer the failure to an AttributeError on `self.path`
            # at the first checkpoint.
            raise ValueError(
                "to use the `model_save_freq` you have to set the "
                "`model_save_path` parameter"
            )

    def _init_callback(self) -> None:
        """Record the algorithm name and all scalar hyperparameters in wandb.config."""
        d = {"algo": type(self.model).__name__}
        for key, value in self.model.__dict__.items():
            if key in wandb.config:
                continue  # never overwrite values the user already configured
            # NOTE: `type(...) in (...)` (not isinstance) is deliberate — it keeps
            # the original behavior of stringifying `bool` and other int subclasses.
            d[key] = value if type(value) in (float, int, str) else str(value)
        if self.gradient_save_freq > 0:
            # Hook the policy network so gradients/parameters are logged every
            # `gradient_save_freq` optimizer steps.
            wandb.watch(
                self.model.policy,
                log_freq=self.gradient_save_freq,
                log=self.log,
            )
        wandb.config.setdefaults(d)

    def _on_step(self) -> bool:
        """Save a model checkpoint every `model_save_freq` calls; always continue training."""
        if (
            self.model_save_freq > 0
            and self.model_save_path is not None
            and self.n_calls % self.model_save_freq == 0
        ):
            self.save_model()
        return True

    def _on_training_end(self) -> None:
        """Persist a final checkpoint when training finishes."""
        if self.model_save_path is not None:
            self.save_model()

    def save_model(self) -> None:
        """Save the SB3 model to `self.path` and upload it to the active W&B run."""
        self.model.save(self.path)
        wandb.save(self.path, base_path=self.model_save_path)
        if self.verbose > 1:
            logger.info(f"Saving model checkpoint to {self.path}")
|