import copy
from datetime import datetime
from typing import Callable, Dict, Optional, Union
from packaging import version
try:
import dill as pickle
except ImportError:
import pickle
import wandb
from wandb.sdk.lib import telemetry
try:
import torch
import ultralytics
from tqdm.auto import tqdm
if version.parse(ultralytics.__version__) > version.parse("8.0.238"):
wandb.termwarn(
"""This integration is tested and supported for ultralytics v8.0.238 and below.
Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.""",
repeat=False,
)
from ultralytics.models import YOLO
from ultralytics.models.sam.predict import Predictor as SAMPredictor
from ultralytics.models.yolo.classify import (
ClassificationPredictor,
ClassificationTrainer,
ClassificationValidator,
)
from ultralytics.models.yolo.detect import (
DetectionPredictor,
DetectionTrainer,
DetectionValidator,
)
from ultralytics.models.yolo.pose import PosePredictor, PoseTrainer, PoseValidator
from ultralytics.models.yolo.segment import (
SegmentationPredictor,
SegmentationTrainer,
SegmentationValidator,
)
from ultralytics.utils.torch_utils import de_parallel
try:
from ultralytics.yolo.utils import RANK, __version__
except ModuleNotFoundError:
from ultralytics.utils import RANK, __version__
from wandb.integration.ultralytics.bbox_utils import (
plot_bbox_predictions,
plot_detection_validation_results,
)
from wandb.integration.ultralytics.classification_utils import (
plot_classification_predictions,
plot_classification_validation_results,
)
from wandb.integration.ultralytics.mask_utils import (
plot_mask_predictions,
plot_sam_predictions,
plot_segmentation_validation_results,
)
from wandb.integration.ultralytics.pose_utils import (
plot_pose_predictions,
plot_pose_validation_results,
)
except Exception as e:
    raise wandb.Error(str(e)) from e
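# Union aliases covering the trainer, validator, and predictor classes for each supported task.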
TRAINER_TYPE = Union[
ClassificationTrainer, DetectionTrainer, SegmentationTrainer, PoseTrainer
]
VALIDATOR_TYPE = Union[
ClassificationValidator, DetectionValidator, SegmentationValidator, PoseValidator
]
PREDICTOR_TYPE = Union[
ClassificationPredictor,
DetectionPredictor,
SegmentationPredictor,
PosePredictor,
SAMPredictor,
]
class WandBUltralyticsCallback:
"""Stateful callback for logging to W&B.
In particular, it will log model checkpoints, predictions, and
ground-truth annotations with interactive overlays for bounding boxes
    to Weights & Biases Tables during training, validation, and prediction
    for an `ultralytics` workflow.
Example:
```python
        from ultralytics import YOLO
        from wandb.integration.ultralytics import add_wandb_callback
# initialize YOLO model
model = YOLO("yolov8n.pt")
# add wandb callback
add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)
# train
model.train(data="coco128.yaml", epochs=5, imgsz=640)
# validate
model.val()
# perform inference
model(["img1.jpeg", "img2.jpeg"])
```
Args:
        model: (ultralytics.models.YOLO) the YOLO model to attach the
            callback to.
        epoch_logging_interval: (int) interval, in epochs, at which prediction
            visualizations are logged during training.
max_validation_batches: (int) maximum number of validation batches to log to
a table per epoch.
        enable_model_checkpointing: (bool) enable logging model checkpoints as
            artifacts at the end of every epoch if set to `True`.
visualize_skeleton: (bool) visualize pose skeleton by drawing lines connecting
keypoints for human pose.
"""
def __init__(
self,
model: YOLO,
epoch_logging_interval: int = 1,
max_validation_batches: int = 1,
enable_model_checkpointing: bool = False,
visualize_skeleton: bool = False,
) -> None:
self.epoch_logging_interval = epoch_logging_interval
self.max_validation_batches = max_validation_batches
self.enable_model_checkpointing = enable_model_checkpointing
self.visualize_skeleton = visualize_skeleton
self.task = model.task
self.task_map = model.task_map
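        # Derive a short model name (e.g. "yolov8n") from the checkpoint path in the overrides.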
self.model_name = (
model.overrides["model"].split(".")[0]
if "model" in model.overrides
else None
)
self._make_tables()
self._make_predictor(model)
self.supported_tasks = ["detect", "segment", "pose", "classify"]
self.prompts = None
self.run_id = None
self.train_epoch = None
def _make_tables(self):
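        """Initialize the `wandb.Table` objects used for logging.

        The column schema depends on `self.task`; separate tables are kept for
        train-time validation results, standalone validation results, and
        prediction results.
        """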
if self.task in ["detect", "segment"]:
validation_columns = [
"Data-Index",
"Batch-Index",
"Image",
"Mean-Confidence",
"Speed",
]
train_columns = ["Epoch"] + validation_columns
self.train_validation_table = wandb.Table(
columns=["Model-Name"] + train_columns
)
self.validation_table = wandb.Table(
columns=["Model-Name"] + validation_columns
)
self.prediction_table = wandb.Table(
columns=[
"Model-Name",
"Image",
"Num-Objects",
"Mean-Confidence",
"Speed",
]
)
elif self.task == "classify":
            classification_columns = [
                "Image",
                "Predicted-Category",
                "Prediction-Confidence",
                "Top-5-Prediction-Categories",
                "Top-5-Prediction-Confidence",
                "Probabilities",
                "Speed",
            ]
validation_columns = ["Data-Index", "Batch-Index"] + classification_columns
validation_columns.insert(3, "Ground-Truth-Category")
self.train_validation_table = wandb.Table(
columns=["Model-Name", "Epoch"] + validation_columns
)
self.validation_table = wandb.Table(
columns=["Model-Name"] + validation_columns
)
self.prediction_table = wandb.Table(
columns=["Model-Name"] + classification_columns
)
elif self.task == "pose":
validation_columns = [
"Data-Index",
"Batch-Index",
"Image-Ground-Truth",
"Image-Prediction",
"Num-Instances",
"Mean-Confidence",
"Speed",
]
train_columns = ["Epoch"] + validation_columns
self.train_validation_table = wandb.Table(
columns=["Model-Name"] + train_columns
)
self.validation_table = wandb.Table(
columns=["Model-Name"] + validation_columns
)
self.prediction_table = wandb.Table(
columns=[
"Model-Name",
"Image-Prediction",
"Num-Instances",
"Mean-Confidence",
"Speed",
]
)
def _make_predictor(self, model: YOLO):
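        """Build a task-specific predictor (confidence threshold 0.1) used solely
        for generating the logged visualizations; saving and callbacks are disabled.
        """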
overrides = copy.deepcopy(model.overrides)
overrides["conf"] = 0.1
self.predictor = self.task_map[self.task]["predictor"](overrides=overrides)
self.predictor.callbacks = {}
self.predictor.args.save = False
self.predictor.args.save_txt = False
self.predictor.args.save_crop = False
self.predictor.args.verbose = None
def _save_model(self, trainer: TRAINER_TYPE):
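        """Save the current checkpoint to disk and log it to W&B as a model
        artifact aliased with the current epoch.
        """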
model_checkpoint_artifact = wandb.Artifact(f"run_{wandb.run.id}_model", "model")
checkpoint_dict = {
"epoch": trainer.epoch,
"best_fitness": trainer.best_fitness,
"model": copy.deepcopy(de_parallel(self.model)).half(),
"ema": copy.deepcopy(trainer.ema.ema).half(),
"updates": trainer.ema.updates,
"optimizer": trainer.optimizer.state_dict(),
"train_args": vars(trainer.args),
"date": datetime.now().isoformat(),
"version": __version__,
}
checkpoint_path = trainer.wdir / f"epoch{trainer.epoch}.pt"
torch.save(checkpoint_dict, checkpoint_path, pickle_module=pickle)
model_checkpoint_artifact.add_file(checkpoint_path)
wandb.log_artifact(
model_checkpoint_artifact, aliases=[f"epoch_{trainer.epoch}"]
)
def on_train_start(self, trainer: TRAINER_TYPE):
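        """Record integration telemetry, store the training arguments in
        `wandb.config`, and remember the active run id.
        """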
with telemetry.context(run=wandb.run) as tel:
tel.feature.ultralytics_yolov8 = True
wandb.config.train = vars(trainer.args)
self.run_id = wandb.run.id
@torch.no_grad()
    def on_fit_epoch_end(self, trainer: TRAINER_TYPE):
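        """At every `epoch_logging_interval`-th epoch, run the logging predictor
        over the validation dataloader, append the visualizations to the
        train-validation table, and optionally log a model checkpoint artifact.
        """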
if self.task in self.supported_tasks and self.train_epoch != trainer.epoch:
self.train_epoch = trainer.epoch
if (self.train_epoch + 1) % self.epoch_logging_interval == 0:
validator = trainer.validator
dataloader = validator.dataloader
class_label_map = validator.names
self.device = next(trainer.model.parameters()).device
if isinstance(trainer.model, torch.nn.parallel.DistributedDataParallel):
model = trainer.model.module
else:
model = trainer.model
self.model = copy.deepcopy(model).eval().to(self.device)
self.predictor.setup_model(model=self.model, verbose=False)
if self.task == "pose":
self.train_validation_table = plot_pose_validation_results(
dataloader=dataloader,
class_label_map=class_label_map,
model_name=self.model_name,
predictor=self.predictor,
visualize_skeleton=self.visualize_skeleton,
table=self.train_validation_table,
max_validation_batches=self.max_validation_batches,
epoch=trainer.epoch,
)
elif self.task == "segment":
self.train_validation_table = plot_segmentation_validation_results(
dataloader=dataloader,
class_label_map=class_label_map,
model_name=self.model_name,
predictor=self.predictor,
table=self.train_validation_table,
max_validation_batches=self.max_validation_batches,
epoch=trainer.epoch,
)
elif self.task == "detect":
self.train_validation_table = plot_detection_validation_results(
dataloader=dataloader,
class_label_map=class_label_map,
model_name=self.model_name,
predictor=self.predictor,
table=self.train_validation_table,
max_validation_batches=self.max_validation_batches,
epoch=trainer.epoch,
)
elif self.task == "classify":
self.train_validation_table = (
plot_classification_validation_results(
dataloader=dataloader,
model_name=self.model_name,
predictor=self.predictor,
table=self.train_validation_table,
max_validation_batches=self.max_validation_batches,
epoch=trainer.epoch,
)
)
if self.enable_model_checkpointing:
self._save_model(trainer)
trainer.model.to(self.device)
def on_train_end(self, trainer: TRAINER_TYPE):
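        """Log the accumulated train-validation table at the end of training."""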
if self.task in self.supported_tasks:
wandb.log({"Train-Table": self.train_validation_table}, commit=False)
def on_val_start(self, validator: VALIDATOR_TYPE):
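        """Ensure a W&B run exists before standalone validation starts."""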
        if not wandb.run:
            wandb.init(
                project=validator.args.project or "YOLOv8",
                job_type="validation_" + validator.args.task,
            )
@torch.no_grad()
def on_val_end(self, trainer: VALIDATOR_TYPE):
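        """Plot the validation results for the current task and log the
        validation table.
        """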
if self.task in self.supported_tasks:
validator = trainer
dataloader = validator.dataloader
class_label_map = validator.names
if self.task == "pose":
self.validation_table = plot_pose_validation_results(
dataloader=dataloader,
class_label_map=class_label_map,
model_name=self.model_name,
predictor=self.predictor,
visualize_skeleton=self.visualize_skeleton,
table=self.validation_table,
max_validation_batches=self.max_validation_batches,
)
elif self.task == "segment":
self.validation_table = plot_segmentation_validation_results(
dataloader=dataloader,
class_label_map=class_label_map,
model_name=self.model_name,
predictor=self.predictor,
table=self.validation_table,
max_validation_batches=self.max_validation_batches,
)
elif self.task == "detect":
self.validation_table = plot_detection_validation_results(
dataloader=dataloader,
class_label_map=class_label_map,
model_name=self.model_name,
predictor=self.predictor,
table=self.validation_table,
max_validation_batches=self.max_validation_batches,
)
elif self.task == "classify":
self.validation_table = plot_classification_validation_results(
dataloader=dataloader,
model_name=self.model_name,
predictor=self.predictor,
table=self.validation_table,
max_validation_batches=self.max_validation_batches,
)
wandb.log({"Validation-Table": self.validation_table}, commit=False)
def on_predict_start(self, predictor: PREDICTOR_TYPE):
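        """Ensure a W&B run exists before prediction starts and, for SAM,
        capture the prompts and reset the prediction table to an image-only schema.
        """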
        if not wandb.run:
            wandb.init(
                project=predictor.args.project or "YOLOv8",
                config=vars(predictor.args),
                job_type="prediction_" + predictor.args.task,
            )
if isinstance(predictor, SAMPredictor):
self.prompts = copy.deepcopy(predictor.prompts)
self.prediction_table = wandb.Table(columns=["Image"])
def on_predict_end(self, predictor: PREDICTOR_TYPE):
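        """Plot each prediction result for the current task and log the
        prediction table.
        """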
wandb.config.prediction_configs = vars(predictor.args)
if self.task in self.supported_tasks:
for result in tqdm(predictor.results):
if self.task == "pose":
self.prediction_table = plot_pose_predictions(
result,
self.model_name,
self.visualize_skeleton,
self.prediction_table,
)
elif self.task == "segment":
if isinstance(predictor, SegmentationPredictor):
self.prediction_table = plot_mask_predictions(
result, self.model_name, self.prediction_table
)
elif isinstance(predictor, SAMPredictor):
self.prediction_table = plot_sam_predictions(
result, self.prompts, self.prediction_table
)
elif self.task == "detect":
self.prediction_table = plot_bbox_predictions(
result, self.model_name, self.prediction_table
)
elif self.task == "classify":
self.prediction_table = plot_classification_predictions(
result, self.model_name, self.prediction_table
)
wandb.log({"Prediction-Table": self.prediction_table}, commit=False)
@property
    def callbacks(self) -> Dict[str, Callable]:
        """Mapping of ultralytics callback events to the corresponding W&B logging handlers."""
return {
"on_train_start": self.on_train_start,
"on_fit_epoch_end": self.on_fit_epoch_end,
"on_train_end": self.on_train_end,
"on_val_start": self.on_val_start,
"on_val_end": self.on_val_end,
"on_predict_start": self.on_predict_start,
"on_predict_end": self.on_predict_end,
}
# TODO: Add epoch interval
def add_wandb_callback(
model: YOLO,
epoch_logging_interval: int = 1,
enable_model_checkpointing: bool = False,
enable_train_validation_logging: bool = True,
enable_validation_logging: bool = True,
enable_prediction_logging: bool = True,
max_validation_batches: Optional[int] = 1,
visualize_skeleton: Optional[bool] = True,
):
    """Add the `WandBUltralyticsCallback` to a `YOLO` model.
Example:
```python
        from ultralytics import YOLO
        from wandb.integration.ultralytics import add_wandb_callback
# initialize YOLO model
model = YOLO("yolov8n.pt")
# add wandb callback
add_wandb_callback(model, max_validation_batches=2, enable_model_checkpointing=True)
# train
model.train(data="coco128.yaml", epochs=5, imgsz=640)
# validate
model.val()
# perform inference
model(["img1.jpeg", "img2.jpeg"])
```
Args:
        model: (ultralytics.models.YOLO) the YOLO model to attach the
            callback to.
        epoch_logging_interval: (int) interval, in epochs, at which prediction
            visualizations are logged during training.
        enable_model_checkpointing: (bool) enable logging model checkpoints as
            artifacts at the end of every epoch if set to `True`.
enable_train_validation_logging: (bool) enable logging the predictions and
ground-truths as interactive image overlays on the images from
the validation dataloader to a `wandb.Table` along with
mean-confidence of the predictions per-class at the end of each
training epoch.
enable_validation_logging: (bool) enable logging the predictions and
ground-truths as interactive image overlays on the images from the
validation dataloader to a `wandb.Table` along with
mean-confidence of the predictions per-class at the end of
validation.
        enable_prediction_logging: (bool) enable logging the predictions as
            interactive image overlays on the input images to a `wandb.Table`,
            along with mean-confidence of the predictions per-class, at the
            end of prediction.
max_validation_batches: (Optional[int]) maximum number of validation batches to log to
a table per epoch.
visualize_skeleton: (Optional[bool]) visualize pose skeleton by drawing lines connecting
keypoints for human pose.
Returns:
        An instance of `ultralytics.models.YOLO` with the `WandBUltralyticsCallback` attached.
"""
if RANK in [-1, 0]:
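        # Callbacks are attached only on the main process (RANK 0 or -1 under DDP).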
wandb_callback = WandBUltralyticsCallback(
copy.deepcopy(model),
epoch_logging_interval,
max_validation_batches,
enable_model_checkpointing,
visualize_skeleton,
)
callbacks = wandb_callback.callbacks
if not enable_train_validation_logging:
_ = callbacks.pop("on_fit_epoch_end")
_ = callbacks.pop("on_train_end")
if not enable_validation_logging:
_ = callbacks.pop("on_val_start")
_ = callbacks.pop("on_val_end")
if not enable_prediction_logging:
_ = callbacks.pop("on_predict_start")
_ = callbacks.pop("on_predict_end")
for event, callback_fn in callbacks.items():
model.add_callback(event, callback_fn)
else:
        wandb.termerror(
            "The RANK of the process to add the callbacks was neither 0 nor "
            "-1. No Weights & Biases callbacks were added to this instance "
            "of the YOLO model."
        )
return model