|
from typing import Any, Callable, Dict, List, Optional |
|
|
|
from ultralytics.yolo.engine.model import YOLO |
|
from ultralytics.yolo.engine.trainer import BaseTrainer |
|
|
|
try: |
|
from ultralytics.yolo.utils import RANK |
|
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params |
|
except ModuleNotFoundError: |
|
from ultralytics.utils import RANK |
|
from ultralytics.utils.torch_utils import get_flops, get_num_params |
|
from ultralytics.yolo.v8.classify.train import ClassificationTrainer |
|
|
|
import wandb |
|
from wandb.sdk.lib import telemetry |
|
|
|
|
|
class WandbCallback:
    """An internal YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.

    Usage:
    ```python
    from wandb.integration.yolov8.yolov8 import WandbCallback

    model = YOLO("yolov8n.pt")
    wandb_logger = WandbCallback(
        model,
    )
    for event, callback_fn in wandb_logger.callbacks.items():
        model.add_callback(event, callback_fn)
    ```
    """

    def __init__(
        self,
        yolo: YOLO,
        run_name: Optional[str] = None,
        project: Optional[str] = None,
        tags: Optional[List[str]] = None,
        resume: Optional[str] = None,
        **kwargs: Optional[Any],
    ) -> None:
        """A utility class to manage wandb run and various callbacks for the ultralytics YOLOv8 framework.

        Args:
            yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`
            run_name: The name of the Weights & Biases run; when falsy, `trainer.args.name` is used instead.
            project: The name of the Weights & Biases project; when falsy, `trainer.args.project` is used,
                and `"YOLOv8"` if that is falsy as well.
            tags: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`.
            resume: Value forwarded as `resume` to `wandb.init()`, defaults to `None`.
            **kwargs: Additional arguments to be passed to `wandb.init()`.
        """
        self.yolo = yolo
        self.run_name = run_name
        self.project = project
        self.tags = tags
        self.resume = resume
        self.kwargs = kwargs
        # Define the attribute up front so that a callback fired before
        # `on_pretrain_routine_start` trips the `assert self.run is not None`
        # guard instead of raising an AttributeError.
        self.run: Optional[Any] = None

    @staticmethod
    def _collect_images(directory: Any, pattern: str) -> List[Any]:
        """Wrap every file in `directory` matching the glob `pattern` as a `wandb.Image` captioned with its stem."""
        return [
            wandb.Image(str(image_path), caption=image_path.stem)
            for image_path in directory.glob(pattern)
        ]

    def on_pretrain_routine_start(self, trainer: BaseTrainer) -> None:
        """Starts a new wandb run to track the training process and log to Weights & Biases.

        If a run is already active (`wandb.run` is not `None`) it is reused instead of
        calling `wandb.init()`.

        Args:
            trainer: A task trainer that's inherited from `:class:ultralytics.yolo.engine.trainer.BaseTrainer`
                that contains the model training and optimization routine.
        """
        if wandb.run is None:
            self.run = wandb.init(
                name=self.run_name or trainer.args.name,
                project=self.project or trainer.args.project or "YOLOv8",
                tags=self.tags or ["YOLOv8"],
                config=vars(trainer.args),
                resume=self.resume or None,
                **self.kwargs,
            )
        else:
            self.run = wandb.run
        assert self.run is not None

        # Plot every metric family against the custom "epoch" axis.
        self.run.define_metric("epoch", hidden=True)
        for pattern, summary in (
            ("train/*", "min"),
            ("val/*", "min"),
            ("metrics/*", "max"),
            ("lr/*", "last"),
        ):
            self.run.define_metric(
                pattern, step_metric="epoch", step_sync=True, summary=summary
            )

        with telemetry.context(run=self.run) as tel:
            tel.feature.ultralytics_yolov8 = True

    def on_pretrain_routine_end(self, trainer: BaseTrainer) -> None:
        """On pretrain routine end we record the model's parameter count and GFLOPs in the run summary."""
        assert self.run is not None
        self.run.summary.update(
            {
                "model/parameters": get_num_params(trainer.model),
                "model/GFLOPs": round(get_flops(trainer.model), 3),
            }
        )

    def on_train_epoch_start(self, trainer: BaseTrainer) -> None:
        """On train epoch start we only log epoch number to the Weights & Biases run."""
        assert self.run is not None
        # `trainer.epoch` is logged 1-based.
        self.run.log({"epoch": trainer.epoch + 1})

    def on_train_epoch_end(self, trainer: BaseTrainer) -> None:
        """On train epoch end we log all the metrics to the Weights & Biases run."""
        assert self.run is not None
        self.run.log(
            {
                **trainer.metrics,
                **trainer.label_loss_items(trainer.tloss, prefix="train"),
                **trainer.lr,
            },
        )

        # Training batch previews are skipped for classification trainers.
        if not isinstance(trainer, ClassificationTrainer):
            self.run.log(
                {
                    "train_batch_images": self._collect_images(
                        trainer.save_dir, "train_batch*.jpg"
                    )
                }
            )

    def on_fit_epoch_end(self, trainer: BaseTrainer) -> None:
        """On fit epoch end we log all the best metrics and model detail to Weights & Biases run summary."""
        assert self.run is not None
        if trainer.epoch == 0:
            # The speed entry is looked up under key `1` first, then "inference".
            validator_speed = trainer.validator.speed
            speed = validator_speed.get(1) or validator_speed.get("inference")
            if speed:
                self.run.summary.update(
                    {
                        "model/speed(ms/img)": round(speed, 3),
                    }
                )
        if trainer.best_fitness == trainer.fitness:
            self.run.summary.update(
                {
                    "best/epoch": trainer.epoch + 1,
                    **{f"best/{key}": val for key, val in trainer.metrics.items()},
                }
            )

    def on_train_end(self, trainer: BaseTrainer) -> None:
        """On train end we log all the media, including plots, images and best model artifact to Weights & Biases."""
        assert self.run is not None
        if not isinstance(trainer, ClassificationTrainer):
            self.run.log(
                {
                    "plots": self._collect_images(trainer.save_dir, "*.png"),
                    "val_images": self._collect_images(
                        trainer.validator.save_dir, "val*.jpg"
                    ),
                },
            )

        if trainer.best.exists():
            self.run.log_artifact(
                str(trainer.best),
                type="model",
                name=f"{self.run.name}_{trainer.args.task}.pt",
                aliases=["best", f"epoch_{trainer.epoch + 1}"],
            )

    def on_model_save(self, trainer: BaseTrainer) -> None:
        """On model save we log the model as an artifact to Weights & Biases."""
        assert self.run is not None
        self.run.log_artifact(
            str(trainer.last),
            type="model",
            name=f"{self.run.name}_{trainer.args.task}.pt",
            aliases=["last", f"epoch_{trainer.epoch + 1}"],
        )

    def teardown(self, _trainer: BaseTrainer) -> None:
        """On teardown, we finish the Weights & Biases run and set it to None."""
        assert self.run is not None
        self.run.finish()
        self.run = None

    @property
    def callbacks(
        self,
    ) -> Dict[str, Callable]:
        """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
        return {
            "on_pretrain_routine_start": self.on_pretrain_routine_start,
            "on_pretrain_routine_end": self.on_pretrain_routine_end,
            "on_train_epoch_start": self.on_train_epoch_start,
            "on_train_epoch_end": self.on_train_epoch_end,
            "on_fit_epoch_end": self.on_fit_epoch_end,
            "on_train_end": self.on_train_end,
            "on_model_save": self.on_model_save,
            "teardown": self.teardown,
        }
|
|
|
|
|
def add_callbacks(
    yolo: YOLO,
    run_name: Optional[str] = None,
    project: Optional[str] = None,
    tags: Optional[List[str]] = None,
    resume: Optional[str] = None,
    **kwargs: Optional[Any],
) -> YOLO:
    """A YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.

    Callbacks are only registered when `RANK` is `-1` or `0`; for any other rank an
    error is printed and the model is returned without callbacks.

    Args:
        yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`
        run_name: The name of the Weights & Biases run; when falsy, `trainer.args.name` is used instead.
        project: The name of the Weights & Biases project; when falsy, `trainer.args.project` is used,
            and `"YOLOv8"` if that is falsy as well.
        tags: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`.
        resume: Value forwarded as `resume` to `wandb.init()`, defaults to `None`.
        **kwargs: Additional arguments to be passed to `wandb.init()`.

    Returns:
        The same `yolo` object that was passed in.

    Usage:
    ```python
    from wandb.integration.yolov8 import add_callbacks as add_wandb_callbacks

    model = YOLO("yolov8n.pt")
    add_wandb_callbacks(
        model,
    )
    model.train(
        data="coco128.yaml",
        epochs=3,
        imgsz=640,
    )
    ```
    """
    wandb.termwarn(
        "The wandb callback is currently in beta and is subject to change based on updates to `ultralytics yolov8`.\n"
        "The callback is tested and supported for ultralytics v8.0.43 and above.\n"
        "Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.\n",
        repeat=False,
    )
    wandb.termwarn(
        "This wandb callback is no longer functional and would be deprecated in the near future.\n"
        "We recommend you to use the updated callback using `from wandb.integration.ultralytics import add_wandb_callback`.\n"
        "The updated callback is tested and supported for ultralytics 8.0.167 and above.\n"
        "You can refer to https://docs.wandb.ai/guides/integrations/ultralytics for the updated documentation.\n"
        "Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.\n",
        repeat=False,
    )

    if RANK not in (-1, 0):
        # The two literals are concatenated, so the separating space must be
        # explicit; it was missing and the sentences ran together.
        wandb.termerror(
            "The RANK of the process to add the callbacks was neither 0 nor -1. "
            "No Weights & Biases callbacks were added to this instance of the YOLO model."
        )
        return yolo

    wandb_logger = WandbCallback(
        yolo, run_name=run_name, project=project, tags=tags, resume=resume, **kwargs
    )
    for event, callback_fn in wandb_logger.callbacks.items():
        yolo.add_callback(event, callback_fn)
    return yolo
|
|