|
"""W&B callback for sb3. |
|
|
|
A minimal callback that adds Weights & Biases logging to each Stable Baselines 3 training run
|
|
|
Example usage: |
|
|
|
```python |
|
import gym |
|
from stable_baselines3 import PPO |
|
from stable_baselines3.common.monitor import Monitor |
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder |
|
import wandb |
|
from wandb.integration.sb3 import WandbCallback |
|
|
|
|
|
config = { |
|
"policy_type": "MlpPolicy", |
|
"total_timesteps": 25000, |
|
"env_name": "CartPole-v1", |
|
} |
|
run = wandb.init( |
|
project="sb3", |
|
config=config, |
|
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics |
|
monitor_gym=True, # auto-upload the videos of agents playing the game |
|
save_code=True, # optional |
|
) |
|
|
|
|
|
def make_env(): |
|
env = gym.make(config["env_name"]) |
|
env = Monitor(env) # record stats such as returns |
|
return env |
|
|
|
|
|
env = DummyVecEnv([make_env]) |
|
env = VecVideoRecorder( |
|
env, "videos", record_video_trigger=lambda x: x % 2000 == 0, video_length=200 |
|
) |
|
model = PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
|
model.learn( |
|
total_timesteps=config["total_timesteps"], |
|
callback=WandbCallback( |
|
model_save_path=f"models/{run.id}", |
|
gradient_save_freq=100, |
|
log="all", |
|
), |
|
) |
|
``` |
|
""" |
|
|
|
import logging |
|
import os |
|
from typing import Literal, Optional |
|
|
|
from stable_baselines3.common.callbacks import BaseCallback |
|
|
|
import wandb |
|
from wandb.sdk.lib import telemetry as wb_telemetry |
|
|
|
# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
class WandbCallback(BaseCallback):
    """Callback for logging experiments to Weights and Biases.

    Log SB3 experiments to Weights and Biases
    - Added model tracking and uploading
    - Added complete hyperparameters recording
    - Added gradient logging
    - Note that `wandb.init(...)` must be called before the WandbCallback can be used.

    Args:
        verbose: The verbosity of sb3 output
        model_save_path: Path to the folder where the model will be saved. The default value is `None` so the model is not logged
        model_save_freq: Frequency to save the model
        gradient_save_freq: Frequency to log gradient. The default value is 0 so the gradients are not logged
        log: What to log. One of "gradients", "parameters", or "all".
    """

    def __init__(
        self,
        verbose: int = 0,
        model_save_path: Optional[str] = None,
        model_save_freq: int = 0,
        gradient_save_freq: int = 0,
        log: Optional[Literal["gradients", "parameters", "all"]] = "all",
    ) -> None:
        super().__init__(verbose)
        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")
        # Record that the sb3 integration is in use, for wandb usage telemetry.
        with wb_telemetry.context() as tel:
            tel.feature.sb3 = True
        self.model_save_freq = model_save_freq
        self.model_save_path = model_save_path
        self.gradient_save_freq = gradient_save_freq
        if log not in ["gradients", "parameters", "all", None]:
            # Warn rather than raise: an invalid `log` value is recoverable.
            wandb.termwarn(
                "`log` must be one of `None`, 'gradients', 'parameters', or 'all', "
                "falling back to 'all'"
            )
            log = "all"
        self.log = log

        if self.model_save_path is not None:
            os.makedirs(self.model_save_path, exist_ok=True)
            self.path = os.path.join(self.model_save_path, "model.zip")
        else:
            # Periodic saving makes no sense without a destination folder.
            assert (
                self.model_save_freq == 0
            ), "to use the `model_save_freq` you have to set the `model_save_path` parameter"

    def _init_callback(self) -> None:
        """Record the algorithm name and hyperparameters into ``wandb.config``.

        Scalar attributes of the SB3 model are stored as-is; everything else
        is stringified so the config stays serializable. Existing config keys
        are never overwritten (``setdefaults`` + explicit skip), so values the
        user passed to ``wandb.init(config=...)`` take precedence.
        """
        d = {"algo": type(self.model).__name__}
        for key, value in self.model.__dict__.items():
            if key in wandb.config:
                # Never clobber values the user set explicitly.
                continue
            # bool is an int subclass, so bools are stored natively too.
            if isinstance(value, (float, int, str)):
                d[key] = value
            else:
                d[key] = str(value)
        if self.gradient_save_freq > 0:
            # Hook wandb into the policy network to log gradients/parameters.
            wandb.watch(
                self.model.policy,
                log_freq=self.gradient_save_freq,
                log=self.log,
            )
        wandb.config.setdefaults(d)

    def _on_step(self) -> bool:
        """Save a checkpoint every ``model_save_freq`` calls.

        Returns:
            Always ``True`` so SB3 continues training.
        """
        if (
            self.model_save_freq > 0
            and self.model_save_path is not None
            and self.n_calls % self.model_save_freq == 0
        ):
            self.save_model()
        return True

    def _on_training_end(self) -> None:
        """Save the final model once training finishes, if a save path was given."""
        if self.model_save_path is not None:
            self.save_model()

    def save_model(self) -> None:
        """Serialize the current model to ``self.path`` and upload it to the W&B run."""
        self.model.save(self.path)
        wandb.save(self.path, base_path=self.model_save_path)
        if self.verbose > 1:
            # Lazy %-formatting: the message is only built if the level is enabled.
            logger.info("Saving model checkpoint to %s", self.path)