File size: 11,371 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
from typing import Any, Callable, Dict, List, Optional

from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.engine.trainer import BaseTrainer

try:
    from ultralytics.yolo.utils import RANK
    from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
except ModuleNotFoundError:
    from ultralytics.utils import RANK
    from ultralytics.utils.torch_utils import get_flops, get_num_params
from ultralytics.yolo.v8.classify.train import ClassificationTrainer

import wandb
from wandb.sdk.lib import telemetry


class WandbCallback:
    """An internal YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.

    Usage:
    ```python
    from wandb.integration.yolov8.yolov8 import WandbCallback

    model = YOLO("yolov8n.pt")
    wandb_logger = WandbCallback(
        model,
    )
    for event, callback_fn in wandb_logger.callbacks.items():
        model.add_callback(event, callback_fn)
    ```
    """

    def __init__(
        self,
        yolo: YOLO,
        run_name: Optional[str] = None,
        project: Optional[str] = None,
        tags: Optional[List[str]] = None,
        resume: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """A utility class to manage wandb run and various callbacks for the ultralytics YOLOv8 framework.

        Args:
            yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`.
            run_name (str): The name of the Weights & Biases run; falls back to
                `trainer.args.name` when not given.
            project (str): The name of the Weights & Biases project; falls back to
                `trainer.args.project`, then to `"YOLOv8"`.
            tags (List[str]): Tags to add to the Weights & Biases run, defaults to `["YOLOv8"]`.
            resume (str): Resume mode forwarded to `wandb.init(resume=...)`, defaults to `None`.
            **kwargs: Additional arguments to be passed to `wandb.init()`.
        """
        self.yolo = yolo
        self.run_name = run_name
        self.project = project
        self.tags = tags
        self.resume = resume
        self.kwargs = kwargs
        # The active W&B run. It is set in `on_pretrain_routine_start` and cleared in
        # `teardown`; initialising it here guarantees the attribute always exists, so
        # the `assert self.run is not None` guards fire instead of an AttributeError.
        self.run: Optional[Any] = None

    def on_pretrain_routine_start(self, trainer: BaseTrainer) -> None:
        """Starts a new wandb run (or reuses the active one) to track the training process.

        Args:
            trainer: A task trainer that's inherited from `:class:ultralytics.yolo.engine.trainer.BaseTrainer`
                    that contains the model training and optimization routine.
        """
        if wandb.run is None:
            self.run = wandb.init(
                name=self.run_name or trainer.args.name,
                project=self.project or trainer.args.project or "YOLOv8",
                tags=self.tags or ["YOLOv8"],
                config=vars(trainer.args),
                resume=self.resume or None,
                **self.kwargs,
            )
        else:
            # A run was already started by the user; log into it instead of creating one.
            self.run = wandb.run
        assert self.run is not None

        # `epoch` is the x-axis for every metric group below and is hidden from the UI.
        self.run.define_metric("epoch", hidden=True)
        # train/* and val/* keep their minimum in the summary; presumably these are
        # losses (train/* comes from `label_loss_items(..., prefix="train")`).
        self.run.define_metric(
            "train/*", step_metric="epoch", step_sync=True, summary="min"
        )
        self.run.define_metric(
            "val/*", step_metric="epoch", step_sync=True, summary="min"
        )
        # metrics/* keep their maximum; lr/* keep the last value.
        self.run.define_metric(
            "metrics/*", step_metric="epoch", step_sync=True, summary="max"
        )
        self.run.define_metric(
            "lr/*", step_metric="epoch", step_sync=True, summary="last"
        )

        with telemetry.context(run=wandb.run) as tel:
            tel.feature.ultralytics_yolov8 = True

    def on_pretrain_routine_end(self, trainer: BaseTrainer) -> None:
        """Record the model's parameter count and GFLOPs in the run summary."""
        assert self.run is not None
        self.run.summary.update(
            {
                "model/parameters": get_num_params(trainer.model),
                "model/GFLOPs": round(get_flops(trainer.model), 3),
            }
        )

    def on_train_epoch_start(self, trainer: BaseTrainer) -> None:
        """On train epoch start we only log the (1-based) epoch number to the Weights & Biases run."""
        # Logging `epoch` here commits a step carrying the epoch value, which every
        # metric group defined in `on_pretrain_routine_start` uses as its step metric.
        assert self.run is not None
        self.run.log({"epoch": trainer.epoch + 1})

    def on_train_epoch_end(self, trainer: BaseTrainer) -> None:
        """On train epoch end we log all the metrics to the Weights & Biases run."""
        assert self.run is not None
        self.run.log(
            {
                **trainer.metrics,
                **trainer.label_loss_items(trainer.tloss, prefix="train"),
                **trainer.lr,
            },
        )
        # Currently only the detection and segmentation trainers save images to the save_dir
        if not isinstance(trainer, ClassificationTrainer):
            self.run.log(
                {
                    "train_batch_images": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.save_dir.glob("train_batch*.jpg")
                    ]
                }
            )

    def on_fit_epoch_end(self, trainer: BaseTrainer) -> None:
        """On fit epoch end we log the inference speed (first epoch only) and the best metrics to the run summary."""
        assert self.run is not None
        if trainer.epoch == 0:
            # The validator's speed dict is probed under two keys, `1` and
            # "inference"; presumably different ultralytics releases use different
            # keys. The first truthy value wins.
            speeds = [trainer.validator.speed.get(key) for key in (1, "inference")]
            speed = speeds[0] or speeds[1]
            if speed:
                self.run.summary.update(
                    {
                        "model/speed(ms/img)": round(speed, 3),
                    }
                )
        # The trainer reports a new best epoch when its current fitness equals the best seen.
        if trainer.best_fitness == trainer.fitness:
            self.run.summary.update(
                {
                    "best/epoch": trainer.epoch + 1,
                    **{f"best/{key}": val for key, val in trainer.metrics.items()},
                }
            )

    def on_train_end(self, trainer: BaseTrainer) -> None:
        """On train end we log all the media, including plots, images and best model artifact to Weights & Biases."""
        assert self.run is not None
        # Currently only the detection and segmentation trainers save images to the save_dir
        if not isinstance(trainer, ClassificationTrainer):
            self.run.log(
                {
                    "plots": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.save_dir.glob("*.png")
                    ],
                    "val_images": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.validator.save_dir.glob("val*.jpg")
                    ],
                },
            )

        if trainer.best.exists():
            self.run.log_artifact(
                str(trainer.best),
                type="model",
                name=f"{self.run.name}_{trainer.args.task}.pt",
                aliases=["best", f"epoch_{trainer.epoch + 1}"],
            )

    def on_model_save(self, trainer: BaseTrainer) -> None:
        """On model save we log the latest checkpoint (`trainer.last`) as a model artifact to Weights & Biases."""
        assert self.run is not None
        self.run.log_artifact(
            str(trainer.last),
            type="model",
            name=f"{self.run.name}_{trainer.args.task}.pt",
            aliases=["last", f"epoch_{trainer.epoch + 1}"],
        )

    def teardown(self, _trainer: BaseTrainer) -> None:
        """On teardown, we finish the Weights & Biases run and set it to None."""
        assert self.run is not None
        self.run.finish()
        self.run = None

    @property
    def callbacks(
        self,
    ) -> Dict[str, Callable]:
        """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
        return {
            "on_pretrain_routine_start": self.on_pretrain_routine_start,
            "on_pretrain_routine_end": self.on_pretrain_routine_end,
            "on_train_epoch_start": self.on_train_epoch_start,
            "on_train_epoch_end": self.on_train_epoch_end,
            "on_fit_epoch_end": self.on_fit_epoch_end,
            "on_train_end": self.on_train_end,
            "on_model_save": self.on_model_save,
            "teardown": self.teardown,
        }


def add_callbacks(
    yolo: YOLO,
    run_name: Optional[str] = None,
    project: Optional[str] = None,
    tags: Optional[List[str]] = None,
    resume: Optional[str] = None,
    **kwargs: Any,
) -> YOLO:
    """A YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.

    Args:
        yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`.
        run_name (str): The name of the Weights & Biases run; falls back to
            `trainer.args.name` when not given.
        project (str): The name of the Weights & Biases project; falls back to
            `trainer.args.project`, then to `"YOLOv8"`.
        tags (List[str]): Tags to add to the Weights & Biases run, defaults to `["YOLOv8"]`.
        resume (str): Resume mode forwarded to `wandb.init(resume=...)`, defaults to `None`.
        **kwargs: Additional arguments to be passed to `wandb.init()`.

    Returns:
        The same `yolo` object. Callbacks are registered on it only when `RANK` is
        -1 or 0; otherwise an error is printed and the model is returned unchanged.

    Usage:
    ```python
    from wandb.integration.yolov8 import add_callbacks as add_wandb_callbacks

    model = YOLO("yolov8n.pt")
    add_wandb_callbacks(
        model,
    )
    model.train(
        data="coco128.yaml",
        epochs=3,
        imgsz=640,
    )
    ```
    """
    wandb.termwarn(
        """The wandb callback is currently in beta and is subject to change based on updates to `ultralytics yolov8`.
        The callback is tested and supported for ultralytics v8.0.43 and above.
        Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.
        """,
        repeat=False,
    )
    wandb.termwarn(
        """This wandb callback is no longer functional and would be deprecated in the near future.
        We recommend you to use the updated callback using `from wandb.integration.ultralytics import add_wandb_callback`.
        The updated callback is tested and supported for ultralytics 8.0.167 and above.
        You can refer to https://docs.wandb.ai/guides/integrations/ultralytics for the updated documentation.
        Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.
        """,
        repeat=False,
    )

    # Only attach the logger in the process whose RANK is -1 or 0; presumably these
    # are the non-distributed case and the DDP main process, respectively.
    if RANK in (-1, 0):
        wandb_logger = WandbCallback(
            yolo, run_name=run_name, project=project, tags=tags, resume=resume, **kwargs
        )
        for event, callback_fn in wandb_logger.callbacks.items():
            yolo.add_callback(event, callback_fn)
    else:
        # Note the trailing space: the two literals are concatenated at compile time.
        wandb.termerror(
            "The RANK of the process to add the callbacks was neither 0 nor -1. "
            "No Weights & Biases callbacks were added to this instance of the YOLO model."
        )
    return yolo