|
"""catboost init.""" |
|
|
|
from pathlib import Path |
|
from types import SimpleNamespace |
|
from typing import List, Union |
|
|
|
from catboost import CatBoostClassifier, CatBoostRegressor |
|
|
|
import wandb |
|
from wandb.sdk.lib import telemetry as wb_telemetry |
|
|
|
|
|
class WandbCallback:
    """`WandbCallback` automatically integrates CatBoost with wandb.

    Args:
    - metric_period: (int) if you are passing `metric_period` to your CatBoost model please pass the same value here (default=1).

    Passing `WandbCallback` to CatBoost will:
    - log training and validation metrics at every `metric_period`
    - log iteration at every `metric_period`

    Example:
    ```
    train_pool = Pool(train[features], label=train["label"], cat_features=cat_features)
    test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)

    model = CatBoostRegressor(
        iterations=100,
        loss_function="Cox",
        eval_metric="Cox",
    )

    model.fit(
        train_pool,
        eval_set=test_pool,
        callbacks=[WandbCallback()],
    )
    ```
    """

    def __init__(self, metric_period: int = 1):
        # A live run is required before any logging can happen.
        if wandb.run is None:
            raise wandb.Error("You must call `wandb.init()` before `WandbCallback()`")

        # Record integration usage for W&B telemetry.
        with wb_telemetry.context() as tel:
            tel.feature.catboost_wandb_callback = True

        self.metric_period: int = metric_period

    def after_iteration(self, info: SimpleNamespace) -> bool:
        """Log metrics for this iteration; return True so training continues."""
        # Only log on every `metric_period`-th iteration.
        if info.iteration % self.metric_period != 0:
            return True

        for eval_set_name, metric_map in info.metrics.items():
            for metric_name, history in metric_map.items():
                # `history` accumulates one value per iteration; take the latest.
                wandb.log({f"{eval_set_name}-{metric_name}": history[-1]}, commit=False)

        # Final call commits the metrics buffered above along with the iteration.
        wandb.log({f"iteration@metric-period-{self.metric_period}": info.iteration})

        return True
|
|
|
|
|
def _checkpoint_artifact(
    model: Union[CatBoostClassifier, CatBoostRegressor], aliases: List[str]
) -> None:
    """Upload model checkpoint as W&B artifact."""
    # Logging an artifact is impossible without an active run.
    if wandb.run is None:
        raise wandb.Error(
            "You must call `wandb.init()` before `_checkpoint_artifact()`"
        )

    checkpoint_name = f"model_{wandb.run.id}"
    # Save into the run directory so the file sits alongside other run files.
    checkpoint_path = Path(wandb.run.dir) / "model"
    model.save_model(checkpoint_path)

    artifact = wandb.Artifact(name=checkpoint_name, type="model")
    artifact.add_file(str(checkpoint_path))
    wandb.log_artifact(artifact, aliases=aliases)
|
|
|
|
|
def _log_feature_importance(
    model: Union[CatBoostClassifier, CatBoostRegressor],
) -> None:
    """Log feature importance as a W&B bar chart using default settings.

    Args:
        model: a trained CatBoostClassifier or CatBoostRegressor.

    Raises:
        wandb.Error: if no W&B run is active.
    """
    if wandb.run is None:
        # Fixed copy-paste bug: message previously referenced `_checkpoint_artifact()`.
        raise wandb.Error(
            "You must call `wandb.init()` before `_log_feature_importance()`"
        )

    # `prettified=True` yields a DataFrame with "Feature Id" and "Importances" columns.
    feat_df = model.get_feature_importance(prettified=True)

    fi_data = [
        [feat, feat_imp]
        for feat, feat_imp in zip(feat_df["Feature Id"], feat_df["Importances"])
    ]
    table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])

    # commit=False so the chart is committed with the caller's next logged step.
    wandb.log(
        {
            "Feature Importance": wandb.plot.bar(
                table, "Feature", "Importance", title="Feature Importance"
            )
        },
        commit=False,
    )
|
|
|
|
|
def log_summary(
    model: Union[CatBoostClassifier, CatBoostRegressor],
    log_all_params: bool = True,
    save_model_checkpoint: bool = False,
    log_feature_importance: bool = True,
) -> None:
    """`log_summary` logs useful metrics about catboost model after training is done.

    Args:
        model: it can be CatBoostClassifier or CatBoostRegressor.
        log_all_params: (boolean) if True (default) log the model hyperparameters as W&B config.
        save_model_checkpoint: (boolean) if True saves the model upload as W&B artifacts.
        log_feature_importance: (boolean) if True (default) logs feature importance as W&B bar chart using the default setting of `get_feature_importance`.

    Using this along with `wandb_callback` will:

    - save the hyperparameters as W&B config,
    - log `best_iteration` and `best_score` as `wandb.summary`,
    - save and upload your trained model to Weights & Biases Artifacts (when `save_model_checkpoint = True`)
    - log feature importance plot.

    Example:
    ```python
    train_pool = Pool(train[features], label=train["label"], cat_features=cat_features)
    test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)

    model = CatBoostRegressor(
        iterations=100,
        loss_function="Cox",
        eval_metric="Cox",
    )

    model.fit(
        train_pool,
        eval_set=test_pool,
        callbacks=[WandbCallback()],
    )

    log_summary(model)
    ```

    Raises:
        wandb.Error: if no W&B run is active, or `model` is not a CatBoost model.
    """
    if wandb.run is None:
        raise wandb.Error("You must call `wandb.init()` before `log_summary()`")

    if not isinstance(model, (CatBoostClassifier, CatBoostRegressor)):
        raise wandb.Error(
            "Model should be an instance of CatBoostClassifier or CatBoostRegressor"
        )

    # Record integration usage for W&B telemetry.
    with wb_telemetry.context() as tel:
        tel.feature.catboost_log_summary = True

    params = model.get_all_params()
    if log_all_params:
        wandb.config.update(params)

    wandb.run.summary["best_iteration"] = model.get_best_iteration()
    wandb.run.summary["best_score"] = model.get_best_score()

    if save_model_checkpoint:
        # Use .get() so a missing "use_best_model" key falls back to the
        # "last" alias instead of raising KeyError.
        aliases = ["best"] if params.get("use_best_model") else ["last"]
        _checkpoint_artifact(model, aliases=aliases)

    if log_feature_importance:
        _log_feature_importance(model)
|
|