File size: 6,495 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""xgboost init!"""

import json
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, cast

import xgboost as xgb  # type: ignore
from xgboost import Booster

import wandb
from wandb.sdk.lib import telemetry as wb_telemetry

# XGBoost eval metrics where lower is better; used by WandbCallback._define_metric
# to set summary="min" so wandb.summary captures the best (lowest) value.
MINIMIZE_METRICS = [
    "rmse",
    "rmsle",
    "mae",
    "mape",
    "mphe",
    "logloss",
    "error",
    "error@t",
    "merror",
]

# XGBoost eval metrics where higher is better; matched in _define_metric to set
# summary="max". NOTE(review): literal names "ndcg@n"/"map@n" are compared as-is,
# so concrete metrics like "ndcg@5" will not match — verify against callers.
MAXIMIZE_METRICS = ["auc", "aucpr", "ndcg", "map", "ndcg@n", "map@n"]


# Typing-only scaffolding: CallbackEnv models the env object the old-style
# xgboost callback receives. Imported lazily so there is no runtime cost.
if TYPE_CHECKING:
    from typing import Callable, List, NamedTuple

    class CallbackEnv(NamedTuple):
        # (metric_name, value) pairs produced by xgboost for the current iteration.
        evaluation_result_list: List


def wandb_callback() -> "Callable":
    """Return a legacy xgboost callback that logs eval results to W&B.

    Deprecated in favor of `WandbCallback`, which offers more features;
    a `UserWarning` is emitted on every call to steer users there.
    """
    warnings.warn(
        "wandb_callback will be deprecated in favor of WandbCallback. Please use WandbCallback for more features.",
        UserWarning,
        stacklevel=2,
    )

    # Record feature usage for telemetry.
    with wb_telemetry.context() as tel:
        tel.feature.xgboost_old_wandb_callback = True

    def _log_eval_results(env: "CallbackEnv") -> None:
        # Accumulate all metrics for this iteration without committing,
        # then commit the step with a single empty log call.
        results = env.evaluation_result_list
        for metric_name, value in results:
            wandb.log({metric_name: value}, commit=False)
        wandb.log({})

    return _log_eval_results


class WandbCallback(xgb.callback.TrainingCallback):
    """`WandbCallback` automatically integrates XGBoost with wandb.

    Args:
        log_model: (boolean) if True save and upload the model to Weights & Biases Artifacts
        log_feature_importance: (boolean) if True log a feature importance bar plot
        importance_type: (str) one of {weight, gain, cover, total_gain, total_cover} for tree model. weight for linear model.
        define_metric: (boolean) if True (default) capture model performance at the best step, instead of the last step, of training in your `wandb.summary`.

    Passing `WandbCallback` to XGBoost will:

    - log the booster model configuration to Weights & Biases
    - log evaluation metrics collected by XGBoost, such as rmse, accuracy etc. to Weights & Biases
    - log training metric collected by XGBoost (if you provide training data to eval_set)
    - log the best score and the best iteration
    - save and upload your trained model to Weights & Biases Artifacts (when `log_model = True`)
    - log feature importance plot when `log_feature_importance=True` (default).
    - Capture the best eval metric in `wandb.summary` when `define_metric=True` (default).

    Example:
        ```python
        bst_params = dict(
            objective="reg:squarederror",
            colsample_bytree=0.3,
            learning_rate=0.1,
            max_depth=5,
            alpha=10,
            n_estimators=10,
            tree_method="hist",
            callbacks=[WandbCallback()],
        )

        xg_reg = xgb.XGBRegressor(**bst_params)
        xg_reg.fit(
            X_train,
            y_train,
            eval_set=[(X_test, y_test)],
        )
        ```
    """

    def __init__(
        self,
        log_model: bool = False,
        log_feature_importance: bool = True,
        importance_type: str = "gain",
        define_metric: bool = True,
    ):
        self.log_model: bool = log_model
        self.log_feature_importance: bool = log_feature_importance
        self.importance_type: str = importance_type
        self.define_metric: bool = define_metric

        # A live run is required: every hook below logs through wandb.run.
        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")

        # Record feature usage for telemetry.
        with wb_telemetry.context() as tel:
            tel.feature.xgboost_wandb_callback = True

    def before_training(self, model: Booster) -> Booster:
        """Run before training starts: store the booster config in wandb.config."""
        config = model.save_config()
        wandb.config.update(json.loads(config))

        return model

    def after_training(self, model: Booster) -> Booster:
        """Run after training is finished.

        Optionally uploads the model artifact and feature-importance plot,
        and logs best_score/best_iteration when early stopping recorded them.
        """
        if self.log_model:
            self._log_model_as_artifact(model)

        if self.log_feature_importance:
            self._log_feature_importance(model)

        # Booster attributes are stored as strings; convert before logging.
        if model.attr("best_score") is not None:
            wandb.log(
                {
                    "best_score": float(cast(str, model.attr("best_score"))),
                    "best_iteration": int(cast(str, model.attr("best_iteration"))),
                }
            )

        return model

    def after_iteration(self, model: Booster, epoch: int, evals_log: dict) -> bool:
        """Run after each iteration. Return True when training should stop.

        Logs the latest value of every eval metric (uncommitted), then
        commits the step together with the epoch number.
        """
        for data, metric in evals_log.items():
            for metric_name, log in metric.items():
                # Register the summary direction once, on the first iteration.
                if self.define_metric:
                    self._define_metric(data, metric_name)
                wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False)

        wandb.log({"epoch": epoch})

        # define_metric only needs to run on the first iteration.
        self.define_metric = False

        return False

    def _log_model_as_artifact(self, model: Booster) -> None:
        """Save the booster to the run dir and upload it as a W&B model artifact."""
        model_name = f"{wandb.run.id}_model.json"  # type: ignore
        model_path = Path(wandb.run.dir) / model_name  # type: ignore
        model.save_model(str(model_path))

        model_artifact = wandb.Artifact(name=model_name, type="model")
        model_artifact.add_file(str(model_path))
        wandb.log_artifact(model_artifact)

    def _log_feature_importance(self, model: Booster) -> None:
        """Log a bar plot of per-feature importance scores."""
        fi = model.get_score(importance_type=self.importance_type)
        fi_data = [[k, fi[k]] for k in fi]
        table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
        wandb.log(
            {
                "Feature Importance": wandb.plot.bar(
                    table, "Feature", "Importance", title="Feature Importance"
                )
            }
        )

    def _define_metric(self, data: str, metric_name: str) -> None:
        """Tell wandb whether the metric's summary should track its min or max.

        Any metric containing "loss" or listed in MINIMIZE_METRICS is tracked
        with summary="min"; MAXIMIZE_METRICS with summary="max"; anything else
        is left with wandb's default summary behavior.
        """
        lowered = metric_name.lower()
        if "loss" in lowered or lowered in MINIMIZE_METRICS:
            wandb.define_metric(f"{data}-{metric_name}", summary="min")
        elif lowered in MAXIMIZE_METRICS:
            wandb.define_metric(f"{data}-{metric_name}", summary="max")