import copy
from datetime import datetime
from typing import Callable, Dict, Optional, Union

from packaging import version

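# Prefer `dill` over the standard-library `pickle` for checkpoint serialization
# when it is available, since it can serialize a wider range of objects.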
try:
    import dill as pickle
except ImportError:
    import pickle

import wandb
from wandb.sdk.lib import telemetry

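# The ultralytics imports (and the W&B plotting utilities that depend on them)
# are guarded by a try/except so that import problems with the optional
# ultralytics dependency are handled in one place.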
try:
    import torch
    import ultralytics
    from tqdm.auto import tqdm

    if version.parse(ultralytics.__version__) > version.parse("8.0.238"):
        wandb.termwarn(
            """This integration is tested and supported for ultralytics v8.0.238 and below.
            Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.""",
            repeat=False,
        )

    from ultralytics.models import YOLO
    from ultralytics.models.sam.predict import Predictor as SAMPredictor
    from ultralytics.models.yolo.classify import (
        ClassificationPredictor,
        ClassificationTrainer,
        ClassificationValidator,
    )
    from ultralytics.models.yolo.detect import (
        DetectionPredictor,
        DetectionTrainer,
        DetectionValidator,
    )
    from ultralytics.models.yolo.pose import PosePredictor, PoseTrainer, PoseValidator
    from ultralytics.models.yolo.segment import (
        SegmentationPredictor,
        SegmentationTrainer,
        SegmentationValidator,
    )
    from ultralytics.utils.torch_utils import de_parallel

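    # `RANK` and `__version__` moved from `ultralytics.yolo.utils` to
    # `ultralytics.utils` in newer releases, so try the old location first and
    # fall back to the new one.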
    try:
        from ultralytics.yolo.utils import RANK, __version__
    except ModuleNotFoundError:
        from ultralytics.utils import RANK, __version__

    from wandb.integration.ultralytics.bbox_utils import (
        plot_bbox_predictions,
        plot_detection_validation_results,
    )
    from wandb.integration.ultralytics.classification_utils import (
        plot_classification_predictions,
        plot_classification_validation_results,
    )
    from wandb.integration.ultralytics.mask_utils import (
        plot_mask_predictions,
        plot_sam_predictions,
        plot_segmentation_validation_results,
    )
    from wandb.integration.ultralytics.pose_utils import (
        plot_pose_predictions,
        plot_pose_validation_results,
    )
except Exception as e:
    raise wandb.Error(e) from e


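# Union type aliases covering the trainer, validator, and predictor classes of
# every task supported by this callback.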
TRAINER_TYPE = Union[
    ClassificationTrainer, DetectionTrainer, SegmentationTrainer, PoseTrainer
]
VALIDATOR_TYPE = Union[
    ClassificationValidator, DetectionValidator, SegmentationValidator, PoseValidator
]
PREDICTOR_TYPE = Union[
    ClassificationPredictor,
    DetectionPredictor,
    SegmentationPredictor,
    PosePredictor,
    SAMPredictor,
]


class WandBUltralyticsCallback:
    """Stateful callback for logging to W&B.

    In particular, it will log model checkpoints, predictions, and
    ground-truth annotations with interactive overlays for bounding boxes
    to Weights & Biases Tables during training, validation and prediction
    for an `ultralytics` workflow.

    Example:
        ```python
        from ultralytics import YOLO
        from wandb.integration.ultralytics import add_wandb_callback

        # initialize YOLO model
        model = YOLO("yolov8n.pt")

        # add wandb callback
        add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)

        # train
        model.train(data="coco128.yaml", epochs=5, imgsz=640)

        # validate
        model.val()

        # perform inference
        model(["img1.jpeg", "img2.jpeg"])
        ```

    Args:
        model: (ultralytics.models.YOLO) YOLO Model of type
            `ultralytics.models.YOLO`.
        epoch_logging_interval: (int) interval, in epochs, at which prediction
            visualizations are logged during training.
        max_validation_batches: (int) maximum number of validation batches to log to
            a table per epoch.
        enable_model_checkpointing: (bool) enable logging model checkpoints as
            artifacts at the end of every epoch if set to `True`.
        visualize_skeleton: (bool) visualize pose skeleton by drawing lines connecting
            keypoints for human pose.
    """

    def __init__(
        self,
        model: YOLO,
        epoch_logging_interval: int = 1,
        max_validation_batches: int = 1,
        enable_model_checkpointing: bool = False,
        visualize_skeleton: bool = False,
    ) -> None:
        self.epoch_logging_interval = epoch_logging_interval
        self.max_validation_batches = max_validation_batches
        self.enable_model_checkpointing = enable_model_checkpointing
        self.visualize_skeleton = visualize_skeleton
        self.task = model.task
        self.task_map = model.task_map
        self.model_name = (
            model.overrides["model"].split(".")[0]
            if "model" in model.overrides
            else None
        )
        self._make_tables()
        self._make_predictor(model)
        self.supported_tasks = ["detect", "segment", "pose", "classify"]
        self.prompts = None
        self.run_id = None
        self.train_epoch = None

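    # One `wandb.Table` is kept per logging stage (train-time validation,
    # standalone validation, and prediction); the columns depend on the task.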
    def _make_tables(self):
        if self.task in ["detect", "segment"]:
            validation_columns = [
                "Data-Index",
                "Batch-Index",
                "Image",
                "Mean-Confidence",
                "Speed",
            ]
            train_columns = ["Epoch"] + validation_columns
            self.train_validation_table = wandb.Table(
                columns=["Model-Name"] + train_columns
            )
            self.validation_table = wandb.Table(
                columns=["Model-Name"] + validation_columns
            )
            self.prediction_table = wandb.Table(
                columns=[
                    "Model-Name",
                    "Image",
                    "Num-Objects",
                    "Mean-Confidence",
                    "Speed",
                ]
            )
        elif self.task == "classify":
            classification_columns = [
                "Image",
                "Predicted-Category",
                "Prediction-Confidence",
                "Top-5-Prediction-Categories",
                "Top-5-Prediction-Confidence",
                "Probabilities",
                "Speed",
            ]
            validation_columns = ["Data-Index", "Batch-Index"] + classification_columns
            validation_columns.insert(3, "Ground-Truth-Category")
            self.train_validation_table = wandb.Table(
                columns=["Model-Name", "Epoch"] + validation_columns
            )
            self.validation_table = wandb.Table(
                columns=["Model-Name"] + validation_columns
            )
            self.prediction_table = wandb.Table(
                columns=["Model-Name"] + classification_columns
            )
        elif self.task == "pose":
            validation_columns = [
                "Data-Index",
                "Batch-Index",
                "Image-Ground-Truth",
                "Image-Prediction",
                "Num-Instances",
                "Mean-Confidence",
                "Speed",
            ]
            train_columns = ["Epoch"] + validation_columns
            self.train_validation_table = wandb.Table(
                columns=["Model-Name"] + train_columns
            )
            self.validation_table = wandb.Table(
                columns=["Model-Name"] + validation_columns
            )
            self.prediction_table = wandb.Table(
                columns=[
                    "Model-Name",
                    "Image-Prediction",
                    "Num-Instances",
                    "Mean-Confidence",
                    "Speed",
                ]
            )

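    # A separate predictor (with a low confidence threshold and all file-saving
    # side effects disabled) is used purely to generate the logged visualizations.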
    def _make_predictor(self, model: YOLO):
        overrides = copy.deepcopy(model.overrides)
        overrides["conf"] = 0.1
        self.predictor = self.task_map[self.task]["predictor"](overrides=overrides)
        self.predictor.callbacks = {}
        self.predictor.args.save = False
        self.predictor.args.save_txt = False
        self.predictor.args.save_crop = False
        self.predictor.args.verbose = None

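    # Saves an ultralytics-style checkpoint for the current epoch and logs it to
    # W&B as a model artifact aliased with the epoch number.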
    def _save_model(self, trainer: TRAINER_TYPE):
        model_checkpoint_artifact = wandb.Artifact(f"run_{wandb.run.id}_model", "model")
        checkpoint_dict = {
            "epoch": trainer.epoch,
            "best_fitness": trainer.best_fitness,
            "model": copy.deepcopy(de_parallel(self.model)).half(),
            "ema": copy.deepcopy(trainer.ema.ema).half(),
            "updates": trainer.ema.updates,
            "optimizer": trainer.optimizer.state_dict(),
            "train_args": vars(trainer.args),
            "date": datetime.now().isoformat(),
            "version": __version__,
        }
        checkpoint_path = trainer.wdir / f"epoch{trainer.epoch}.pt"
        torch.save(checkpoint_dict, checkpoint_path, pickle_module=pickle)
        model_checkpoint_artifact.add_file(checkpoint_path)
        wandb.log_artifact(
            model_checkpoint_artifact, aliases=[f"epoch_{trainer.epoch}"]
        )

    def on_train_start(self, trainer: TRAINER_TYPE):
        with telemetry.context(run=wandb.run) as tel:
            tel.feature.ultralytics_yolov8 = True
        wandb.config.train = vars(trainer.args)
        self.run_id = wandb.run.id

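    # At the configured epoch interval, run the validation dataloader through a
    # copy of the current model and append the visualizations to the train table.
    # Decorated with `torch.no_grad()` since this is inference-only work.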
    @torch.no_grad()
    def on_fit_epoch_end(self, trainer: DetectionTrainer):
        if self.task in self.supported_tasks and self.train_epoch != trainer.epoch:
            self.train_epoch = trainer.epoch
            if (self.train_epoch + 1) % self.epoch_logging_interval == 0:
                validator = trainer.validator
                dataloader = validator.dataloader
                class_label_map = validator.names
                self.device = next(trainer.model.parameters()).device
                if isinstance(trainer.model, torch.nn.parallel.DistributedDataParallel):
                    model = trainer.model.module
                else:
                    model = trainer.model
                self.model = copy.deepcopy(model).eval().to(self.device)
                self.predictor.setup_model(model=self.model, verbose=False)
                if self.task == "pose":
                    self.train_validation_table = plot_pose_validation_results(
                        dataloader=dataloader,
                        class_label_map=class_label_map,
                        model_name=self.model_name,
                        predictor=self.predictor,
                        visualize_skeleton=self.visualize_skeleton,
                        table=self.train_validation_table,
                        max_validation_batches=self.max_validation_batches,
                        epoch=trainer.epoch,
                    )
                elif self.task == "segment":
                    self.train_validation_table = plot_segmentation_validation_results(
                        dataloader=dataloader,
                        class_label_map=class_label_map,
                        model_name=self.model_name,
                        predictor=self.predictor,
                        table=self.train_validation_table,
                        max_validation_batches=self.max_validation_batches,
                        epoch=trainer.epoch,
                    )
                elif self.task == "detect":
                    self.train_validation_table = plot_detection_validation_results(
                        dataloader=dataloader,
                        class_label_map=class_label_map,
                        model_name=self.model_name,
                        predictor=self.predictor,
                        table=self.train_validation_table,
                        max_validation_batches=self.max_validation_batches,
                        epoch=trainer.epoch,
                    )
                elif self.task == "classify":
                    self.train_validation_table = (
                        plot_classification_validation_results(
                            dataloader=dataloader,
                            model_name=self.model_name,
                            predictor=self.predictor,
                            table=self.train_validation_table,
                            max_validation_batches=self.max_validation_batches,
                            epoch=trainer.epoch,
                        )
                    )
                if self.enable_model_checkpointing:
                    self._save_model(trainer)
                trainer.model.to(self.device)

    def on_train_end(self, trainer: TRAINER_TYPE):
        if self.task in self.supported_tasks:
            wandb.log({"Train-Table": self.train_validation_table}, commit=False)

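    # Standalone validation: reuse the active run if one exists, otherwise start
    # a new one before logging the validation table.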
    def on_val_start(self, validator: VALIDATOR_TYPE):
        if wandb.run is None:
            wandb.init(
                project=validator.args.project or "YOLOv8",
                job_type="validation_" + validator.args.task,
            )

    @torch.no_grad()
    def on_val_end(self, validator: VALIDATOR_TYPE):
        if self.task in self.supported_tasks:
            dataloader = validator.dataloader
            class_label_map = validator.names
            if self.task == "pose":
                self.validation_table = plot_pose_validation_results(
                    dataloader=dataloader,
                    class_label_map=class_label_map,
                    model_name=self.model_name,
                    predictor=self.predictor,
                    visualize_skeleton=self.visualize_skeleton,
                    table=self.validation_table,
                    max_validation_batches=self.max_validation_batches,
                )
            elif self.task == "segment":
                self.validation_table = plot_segmentation_validation_results(
                    dataloader=dataloader,
                    class_label_map=class_label_map,
                    model_name=self.model_name,
                    predictor=self.predictor,
                    table=self.validation_table,
                    max_validation_batches=self.max_validation_batches,
                )
            elif self.task == "detect":
                self.validation_table = plot_detection_validation_results(
                    dataloader=dataloader,
                    class_label_map=class_label_map,
                    model_name=self.model_name,
                    predictor=self.predictor,
                    table=self.validation_table,
                    max_validation_batches=self.max_validation_batches,
                )
            elif self.task == "classify":
                self.validation_table = plot_classification_validation_results(
                    dataloader=dataloader,
                    model_name=self.model_name,
                    predictor=self.predictor,
                    table=self.validation_table,
                    max_validation_batches=self.max_validation_batches,
                )
            wandb.log({"Validation-Table": self.validation_table}, commit=False)

    def on_predict_start(self, predictor: PREDICTOR_TYPE):
        if wandb.run is None:
            wandb.init(
                project=predictor.args.project or "YOLOv8",
                config=vars(predictor.args),
                job_type="prediction_" + predictor.args.task,
            )
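        # For SAM, keep a copy of the prompts for plotting and log a minimal
        # table containing only the images, since there is no fixed class set.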
        if isinstance(predictor, SAMPredictor):
            self.prompts = copy.deepcopy(predictor.prompts)
            self.prediction_table = wandb.Table(columns=["Image"])

    def on_predict_end(self, predictor: PREDICTOR_TYPE):
        wandb.config.prediction_configs = vars(predictor.args)
        if self.task in self.supported_tasks:
            for result in tqdm(predictor.results):
                if self.task == "pose":
                    self.prediction_table = plot_pose_predictions(
                        result,
                        self.model_name,
                        self.visualize_skeleton,
                        self.prediction_table,
                    )
                elif self.task == "segment":
                    if isinstance(predictor, SegmentationPredictor):
                        self.prediction_table = plot_mask_predictions(
                            result, self.model_name, self.prediction_table
                        )
                    elif isinstance(predictor, SAMPredictor):
                        self.prediction_table = plot_sam_predictions(
                            result, self.prompts, self.prediction_table
                        )
                elif self.task == "detect":
                    self.prediction_table = plot_bbox_predictions(
                        result, self.model_name, self.prediction_table
                    )
                elif self.task == "classify":
                    self.prediction_table = plot_classification_predictions(
                        result, self.model_name, self.prediction_table
                    )
            wandb.log({"Prediction-Table": self.prediction_table}, commit=False)

    @property
    def callbacks(self) -> Dict[str, Callable]:
        """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
        return {
            "on_train_start": self.on_train_start,
            "on_fit_epoch_end": self.on_fit_epoch_end,
            "on_train_end": self.on_train_end,
            "on_val_start": self.on_val_start,
            "on_val_end": self.on_val_end,
            "on_predict_start": self.on_predict_start,
            "on_predict_end": self.on_predict_end,
        }


def add_wandb_callback(
    model: YOLO,
    epoch_logging_interval: int = 1,
    enable_model_checkpointing: bool = False,
    enable_train_validation_logging: bool = True,
    enable_validation_logging: bool = True,
    enable_prediction_logging: bool = True,
    max_validation_batches: Optional[int] = 1,
    visualize_skeleton: Optional[bool] = True,
):
    """Function to add the `WandBUltralyticsCallback` callback to the `YOLO` model.

    Example:
        ```python
        from ultralytics import YOLO
        from wandb.integration.ultralytics import add_wandb_callback

        # initialize YOLO model
        model = YOLO("yolov8n.pt")

        # add wandb callback
        add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)

        # train
        model.train(data="coco128.yaml", epochs=5, imgsz=640)

        # validate
        model.val()

        # perform inference
        model(["img1.jpeg", "img2.jpeg"])
        ```
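
    Each logging stage can also be toggled individually via the `enable_*`
    flags below. A minimal sketch, using only the parameters of this function:

        ```python
        # keep train- and validation-time tables, but skip prediction logging
        # and model checkpointing
        add_wandb_callback(
            model,
            enable_prediction_logging=False,
            enable_model_checkpointing=False,
        )
        ```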

    Args:
        model: (ultralytics.models.YOLO) YOLO Model of type
            `ultralytics.models.YOLO`.
        epoch_logging_interval: (int) interval, in epochs, at which prediction
            visualizations are logged during training.
        enable_model_checkpointing: (bool) enable logging model checkpoints as
            artifacts at the end of every epoch if set to `True`.
        enable_train_validation_logging: (bool) enable logging the predictions and
            ground-truths as interactive image overlays on the images from
            the validation dataloader to a `wandb.Table` along with
            mean-confidence of the predictions per-class at the end of each
            training epoch.
        enable_validation_logging: (bool) enable logging the predictions and
            ground-truths as interactive image overlays on the images from the
            validation dataloader to a `wandb.Table` along with
            mean-confidence of the predictions per-class at the end of
            validation.
        enable_prediction_logging: (bool) enable logging the predictions as
            interactive image overlays on the input images to a `wandb.Table`
            along with mean-confidence of the predictions per-class at the end
            of each prediction run.
        max_validation_batches: (Optional[int]) maximum number of validation batches
            to log to a table per epoch.
        visualize_skeleton: (Optional[bool]) visualize pose skeleton by drawing lines
            connecting keypoints for human pose.

    Returns:
        An instance of `ultralytics.models.YOLO` with the `WandBUltralyticsCallback`
        callbacks attached.
    """
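    # Only attach the callbacks on the main process: RANK is -1 for single-GPU /
    # CPU runs and 0 for the primary process under DDP.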
    if RANK in [-1, 0]:
        wandb_callback = WandBUltralyticsCallback(
            copy.deepcopy(model),
            epoch_logging_interval,
            max_validation_batches,
            enable_model_checkpointing,
            visualize_skeleton,
        )
        callbacks = wandb_callback.callbacks
        if not enable_train_validation_logging:
            _ = callbacks.pop("on_fit_epoch_end")
            _ = callbacks.pop("on_train_end")
        if not enable_validation_logging:
            _ = callbacks.pop("on_val_start")
            _ = callbacks.pop("on_val_end")
        if not enable_prediction_logging:
            _ = callbacks.pop("on_predict_start")
            _ = callbacks.pop("on_predict_end")
        for event, callback_fn in callbacks.items():
            model.add_callback(event, callback_fn)
    else:
        wandb.termerror(
            "The RANK of the process to add the callbacks was neither 0 nor "
            "-1. No Weights & Biases callbacks were added to this instance "
            "of the YOLO model."
        )
    return model