File size: 5,942 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""catboost init."""

from pathlib import Path
from types import SimpleNamespace
from typing import List, Union

from catboost import CatBoostClassifier, CatBoostRegressor  # type: ignore

import wandb
from wandb.sdk.lib import telemetry as wb_telemetry


class WandbCallback:
    """`WandbCallback` automatically integrates CatBoost with wandb.

    Args:
        metric_period: (int) number of iterations between two logging steps. If
            you are passing `metric_period` to your CatBoost model please pass
            the same value here (default=1). Must be at least 1.

    Raises:
        ValueError: if `metric_period` is smaller than 1.
        wandb.Error: if `wandb.init()` has not been called before the callback
            is created.

    Passing `WandbCallback` to CatBoost will:
    - log training and validation metrics at every `metric_period`
    - log iteration at every `metric_period`

    Example:
        ```python
        train_pool = Pool(train[features], label=train["label"], cat_features=cat_features)
        test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)

        model = CatBoostRegressor(
            iterations=100,
            loss_function="Cox",
            eval_metric="Cox",
        )

        model.fit(
            train_pool,
            eval_set=test_pool,
            callbacks=[WandbCallback()],
        )
        ```
    """

    def __init__(self, metric_period: int = 1) -> None:
        # Fail fast on an unusable period: a value of 0 would otherwise only
        # surface as a ZeroDivisionError inside `after_iteration`, i.e. in the
        # middle of training.
        if metric_period < 1:
            raise ValueError(
                f"`metric_period` must be a positive integer, got {metric_period!r}"
            )

        if wandb.run is None:
            raise wandb.Error("You must call `wandb.init()` before `WandbCallback()`")

        with wb_telemetry.context() as tel:
            tel.feature.catboost_wandb_callback = True

        self.metric_period: int = metric_period

    def after_iteration(self, info: SimpleNamespace) -> bool:
        """Log the latest value of every metric when the iteration is due.

        Args:
            info: object exposing `iteration` (the current iteration number) and
                `metrics`, a two-level mapping whose innermost values are
                sequences; only the last entry of each sequence is logged.

        Returns:
            Always `True`.
            NOTE(review): presumably CatBoost treats a truthy return as
            "continue training" - confirm against the CatBoost callback docs.
        """
        if info.iteration % self.metric_period != 0:
            return True

        # Gather everything for this step into one payload so that a single
        # `wandb.log` call records (and commits) the whole step.
        payload = {
            f"{data}-{metric_name}": history[-1]
            for data, metric in info.metrics.items()
            for metric_name, history in metric.items()
        }
        payload[f"iteration@metric-period-{self.metric_period}"] = info.iteration

        # todo: replace with wandb.run._log once available
        wandb.log(payload)

        return True


def _checkpoint_artifact(
    model: Union[CatBoostClassifier, CatBoostRegressor], aliases: List[str]
) -> None:
    """Save the model into the run directory and upload it as a W&B artifact.

    Args:
        model: the CatBoost model whose `save_model` is used to write the file.
        aliases: aliases attached to the logged artifact.

    Raises:
        wandb.Error: if there is no active run (`wandb.run` is None).
    """
    run = wandb.run
    if run is None:
        raise wandb.Error(
            "You must call `wandb.init()` before `_checkpoint_artifact()`"
        )

    # The file is written without an explicit format argument, i.e. in
    # CatBoost's default `cbm` format, inside the run's own directory.
    checkpoint_file = Path(run.dir) / "model"
    model.save_model(checkpoint_file)

    # The artifact is named after the run id and typed as a model.
    artifact = wandb.Artifact(name=f"model_{run.id}", type="model")
    artifact.add_file(str(checkpoint_file))
    wandb.log_artifact(artifact, aliases=aliases)


def _log_feature_importance(
    model: Union[CatBoostClassifier, CatBoostRegressor],
) -> None:
    """Log feature importance with default settings.

    The importances come from `model.get_feature_importance(prettified=True)`;
    its "Feature Id" and "Importances" columns are copied into a W&B table and
    logged as a bar chart under the key "Feature Importance". The log call uses
    `commit=False`, so it does not commit the step by itself.

    Args:
        model: the CatBoost model to read feature importances from.

    Raises:
        wandb.Error: if there is no active run (`wandb.run` is None).
    """
    if wandb.run is None:
        # Name this function in the message (it previously pointed at
        # `_checkpoint_artifact()`, which was misleading when debugging).
        raise wandb.Error(
            "You must call `wandb.init()` before `_log_feature_importance()`"
        )

    feat_df = model.get_feature_importance(prettified=True)

    # One [feature, importance] row per feature, in the order returned above.
    fi_data = [
        [feat, feat_imp]
        for feat, feat_imp in zip(feat_df["Feature Id"], feat_df["Importances"])
    ]
    table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
    # todo: replace with wandb.run._log once available
    wandb.log(
        {
            "Feature Importance": wandb.plot.bar(
                table, "Feature", "Importance", title="Feature Importance"
            )
        },
        commit=False,
    )


def log_summary(
    model: Union[CatBoostClassifier, CatBoostRegressor],
    log_all_params: bool = True,
    save_model_checkpoint: bool = False,
    log_feature_importance: bool = True,
) -> None:
    """Log a post-training summary of a CatBoost model to Weights & Biases.

    Args:
        model: a fitted CatBoostClassifier or CatBoostRegressor.
        log_all_params: (boolean) if True (default), store the result of
            `model.get_all_params()` in the W&B config.
        save_model_checkpoint: (boolean) if True, save the model and upload it
            as a W&B artifact. The artifact alias is "best" when the model's
            `use_best_model` parameter is truthy, otherwise "last".
        log_feature_importance: (boolean) if True (default), log a feature
            importance bar chart built from the default settings of
            `get_feature_importance`.

    Raises:
        wandb.Error: if `wandb.init()` has not been called, or if `model` is
            neither a CatBoostClassifier nor a CatBoostRegressor.

    Used together with `WandbCallback`, this will:

    - save the hyperparameters as W&B config,
    - record `best_iteration` and `best_score` in the run summary,
    - save and upload the trained model to Weights & Biases Artifacts
      (when `save_model_checkpoint = True`),
    - log a feature importance plot.

    Example:
        ```python
        train_pool = Pool(train[features], label=train["label"], cat_features=cat_features)
        test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)

        model = CatBoostRegressor(
            iterations=100,
            loss_function="Cox",
            eval_metric="Cox",
        )

        model.fit(
            train_pool,
            eval_set=test_pool,
            callbacks=[WandbCallback()],
        )

        log_summary(model)
        ```
    """
    run = wandb.run
    if run is None:
        raise wandb.Error("You must call `wandb.init()` before `log_summary()`")

    supported_models = (CatBoostClassifier, CatBoostRegressor)
    if not isinstance(model, supported_models):
        raise wandb.Error(
            "Model should be an instance of CatBoostClassifier or CatBoostRegressor"
        )

    with wb_telemetry.context() as tel:
        tel.feature.catboost_log_summary = True

    # Hyperparameters: always fetched (the checkpoint alias below needs them),
    # but only pushed to the config on request.
    hyperparams = model.get_all_params()
    if log_all_params:
        wandb.config.update(hyperparams)

    # Best iteration and best score go into the run summary.
    run.summary["best_iteration"] = model.get_best_iteration()
    run.summary["best_score"] = model.get_best_score()

    # Optional model checkpoint, tagged according to `use_best_model`.
    if save_model_checkpoint:
        alias = "best" if hyperparams["use_best_model"] else "last"
        _checkpoint_artifact(model, aliases=[alias])

    # Optional feature importance chart.
    if log_feature_importance:
        _log_feature_importance(model)