File size: 11,371 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 |
from typing import Any, Callable, Dict, List, Optional
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.engine.trainer import BaseTrainer
try:
from ultralytics.yolo.utils import RANK
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
except ModuleNotFoundError:
from ultralytics.utils import RANK
from ultralytics.utils.torch_utils import get_flops, get_num_params
from ultralytics.yolo.v8.classify.train import ClassificationTrainer
import wandb
from wandb.sdk.lib import telemetry
class WandbCallback:
    """An internal YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.

    The instance holds the configuration for `wandb.init()` and exposes one
    bound method per training event through the `callbacks` property.

    Usage:
    ```python
    from wandb.integration.yolov8.yolov8 import WandbCallback

    model = YOLO("yolov8n.pt")
    wandb_logger = WandbCallback(
        model,
    )
    for event, callback_fn in wandb_logger.callbacks.items():
        model.add_callback(event, callback_fn)
    ```
    """

    def __init__(
        self,
        yolo: YOLO,
        run_name: Optional[str] = None,
        project: Optional[str] = None,
        tags: Optional[List[str]] = None,
        resume: Optional[str] = None,
        **kwargs: Optional[Any],
    ) -> None:
        """A utility class to manage wandb run and various callbacks for the ultralytics YOLOv8 framework.

        Args:
            yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`
            run_name, str: The name of the Weights & Biases run, defaults to an auto generated run_name if `trainer.args.name` is not defined.
            project, str: The name of the Weights & Biases project, defaults to `"YOLOv8"` if `trainer.args.project` is not defined.
            tags, List[str]: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`.
            resume, str: Whether to resume a previous run on Weights & Biases, defaults to `None`.
            **kwargs: Additional arguments to be passed to `wandb.init()`.
        """
        self.yolo = yolo
        self.run_name = run_name
        self.project = project
        self.tags = tags
        self.resume = resume
        self.kwargs = kwargs
        # The run is created (or adopted) in `on_pretrain_routine_start`.
        # Defining the attribute up front means the `assert self.run is not None`
        # checks in the other callbacks fail with a clear AssertionError rather
        # than an AttributeError when a callback fires before the run exists.
        self.run: Optional[Any] = None

    def on_pretrain_routine_start(self, trainer: BaseTrainer) -> None:
        """Starts a new wandb run (or adopts the active one) to track the training process.

        When `wandb.run` is already set it is reused as-is; otherwise
        `wandb.init()` is called with the values given to the constructor,
        falling back to `trainer.args` and finally to the `"YOLOv8"` defaults.
        The metric groups are then bound to the hidden `epoch` step metric.

        Args:
            trainer: A task trainer that's inherited from `:class:ultralytics.yolo.engine.trainer.BaseTrainer`
                that contains the model training and optimization routine.
        """
        if wandb.run is None:
            self.run = wandb.init(
                name=self.run_name or trainer.args.name,
                project=self.project or trainer.args.project or "YOLOv8",
                tags=self.tags or ["YOLOv8"],
                config=vars(trainer.args),
                resume=self.resume or None,
                **self.kwargs,
            )
        else:
            self.run = wandb.run
        assert self.run is not None

        self.run.define_metric("epoch", hidden=True)
        # (metric glob, summary aggregation) pairs, all stepped by "epoch".
        metric_summaries = (
            ("train/*", "min"),
            ("val/*", "min"),
            ("metrics/*", "max"),
            ("lr/*", "last"),
        )
        for pattern, summary in metric_summaries:
            self.run.define_metric(
                pattern, step_metric="epoch", step_sync=True, summary=summary
            )

        with telemetry.context(run=wandb.run) as tel:
            tel.feature.ultralytics_yolov8 = True

    def on_pretrain_routine_end(self, trainer: BaseTrainer) -> None:
        """On pretrain routine end we record the model's parameter count and GFLOPs in the run summary."""
        assert self.run is not None
        self.run.summary.update(
            {
                "model/parameters": get_num_params(trainer.model),
                "model/GFLOPs": round(get_flops(trainer.model), 3),
            }
        )

    def on_train_epoch_start(self, trainer: BaseTrainer) -> None:
        """On train epoch start we only log epoch number to the Weights & Biases run."""
        # We log the (1-based) epoch number here to commit the previous step.
        assert self.run is not None
        self.run.log({"epoch": trainer.epoch + 1})

    def on_train_epoch_end(self, trainer: BaseTrainer) -> None:
        """On train epoch end we log all the metrics to the Weights & Biases run."""
        assert self.run is not None
        self.run.log(
            {
                **trainer.metrics,
                **trainer.label_loss_items(trainer.tloss, prefix="train"),
                **trainer.lr,
            },
        )
        # Currently only the detection and segmentation trainers save images to the save_dir
        if not isinstance(trainer, ClassificationTrainer):
            self.run.log(
                {
                    "train_batch_images": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.save_dir.glob("train_batch*.jpg")
                    ]
                }
            )

    def on_fit_epoch_end(self, trainer: BaseTrainer) -> None:
        """On fit epoch end we log all the best metrics and model detail to Weights & Biases run summary."""
        assert self.run is not None
        if trainer.epoch == 0:
            # The speed mapping is probed under two keys, `1` and "inference";
            # the first truthy value wins.
            # NOTE(review): presumably the key differs between ultralytics versions - confirm.
            speeds = [trainer.validator.speed.get(key) for key in (1, "inference")]
            speed = speeds[0] or speeds[1]
            if speed:
                self.run.summary.update(
                    {
                        "model/speed(ms/img)": round(speed, 3),
                    }
                )
        # Only refresh the "best/*" summary when this epoch holds the best fitness.
        if trainer.best_fitness == trainer.fitness:
            self.run.summary.update(
                {
                    "best/epoch": trainer.epoch + 1,
                    **{f"best/{key}": val for key, val in trainer.metrics.items()},
                }
            )

    def on_train_end(self, trainer: BaseTrainer) -> None:
        """On train end we log all the media, including plots, images and best model artifact to Weights & Biases."""
        assert self.run is not None
        # Currently only the detection and segmentation trainers save images to the save_dir
        if not isinstance(trainer, ClassificationTrainer):
            self.run.log(
                {
                    "plots": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.save_dir.glob("*.png")
                    ],
                    "val_images": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.validator.save_dir.glob("val*.jpg")
                    ],
                },
            )

        if trainer.best.exists():
            self.run.log_artifact(
                str(trainer.best),
                type="model",
                name=f"{self.run.name}_{trainer.args.task}.pt",
                aliases=["best", f"epoch_{trainer.epoch + 1}"],
            )

    def on_model_save(self, trainer: BaseTrainer) -> None:
        """On model save we log the model as an artifact to Weights & Biases."""
        assert self.run is not None
        self.run.log_artifact(
            str(trainer.last),
            type="model",
            name=f"{self.run.name}_{trainer.args.task}.pt",
            aliases=["last", f"epoch_{trainer.epoch + 1}"],
        )

    def teardown(self, _trainer: BaseTrainer) -> None:
        """On teardown, we finish the Weights & Biases run and set it to None.

        This is a no-op when there is no active run (teardown fired before the
        run was started, or fired a second time), so cleanup never raises.
        """
        if self.run is None:
            return
        self.run.finish()
        self.run = None

    @property
    def callbacks(
        self,
    ) -> Dict[str, Callable]:
        """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging.

        Returns:
            A mapping from event name to the bound method handling that event.
        """
        return {
            "on_pretrain_routine_start": self.on_pretrain_routine_start,
            "on_pretrain_routine_end": self.on_pretrain_routine_end,
            "on_train_epoch_start": self.on_train_epoch_start,
            "on_train_epoch_end": self.on_train_epoch_end,
            "on_fit_epoch_end": self.on_fit_epoch_end,
            "on_train_end": self.on_train_end,
            "on_model_save": self.on_model_save,
            "teardown": self.teardown,
        }
def add_callbacks(
    yolo: YOLO,
    run_name: Optional[str] = None,
    project: Optional[str] = None,
    tags: Optional[List[str]] = None,
    resume: Optional[str] = None,
    **kwargs: Optional[Any],
) -> YOLO:
    """A YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.

    Callbacks are only registered when `RANK` is -1 or 0; for any other rank an
    error is printed and the model is returned unchanged.

    Args:
        yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`
        run_name, str: The name of the Weights & Biases run, defaults to an auto generated name if `trainer.args.name` is not defined.
        project, str: The name of the Weights & Biases project, defaults to `"YOLOv8"` if `trainer.args.project` is not defined.
        tags, List[str]: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`.
        resume, str: Whether to resume a previous run on Weights & Biases, defaults to `None`.
        **kwargs: Additional arguments to be passed to `wandb.init()`.

    Returns:
        The same `yolo` object that was passed in.

    Usage:
    ```python
    from wandb.integration.yolov8 import add_callbacks as add_wandb_callbacks

    model = YOLO("yolov8n.pt")
    add_wandb_callbacks(
        model,
    )
    model.train(
        data="coco128.yaml",
        epochs=3,
        imgsz=640,
    )
    ```
    """
    wandb.termwarn(
        "The wandb callback is currently in beta and is subject to change based on updates to `ultralytics yolov8`.\n"
        "The callback is tested and supported for ultralytics v8.0.43 and above.\n"
        "Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.\n",
        repeat=False,
    )
    wandb.termwarn(
        "This wandb callback is no longer functional and would be deprecated in the near future.\n"
        "We recommend you to use the updated callback using `from wandb.integration.ultralytics import add_wandb_callback`.\n"
        "The updated callback is tested and supported for ultralytics 8.0.167 and above.\n"
        "You can refer to https://docs.wandb.ai/guides/integrations/ultralytics for the updated documentation.\n"
        "Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.\n",
        repeat=False,
    )

    # Guard clause: any rank other than -1 or 0 gets no callbacks.
    if RANK not in (-1, 0):
        wandb.termerror(
            # The trailing space keeps the two sentences from running together.
            "The RANK of the process to add the callbacks was neither 0 or -1. "
            "No Weights & Biases callbacks were added to this instance of the YOLO model."
        )
        return yolo

    wandb_logger = WandbCallback(
        yolo, run_name=run_name, project=project, tags=tags, resume=resume, **kwargs
    )
    for event, callback_fn in wandb_logger.callbacks.items():
        yolo.add_callback(event, callback_fn)
    return yolo
|