File size: 2,371 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from warnings import simplefilter

from sklearn import model_selection

import wandb
from wandb.integration.sklearn import utils

# Ignore all FutureWarning messages. This runs at import time and installs a
# filter in the interpreter-wide warnings configuration, so it is not limited
# to warnings raised from this module.
simplefilter(action="ignore", category=FutureWarning)


def residuals(regressor, X, y):  # noqa: N803
    """Fit ``regressor`` on a train split and build a residuals chart.

    The data is split with ``model_selection.train_test_split`` using
    ``test_size=0.2``. No ``random_state`` is passed, so the split is not
    pinned by this function. ``regressor.fit`` is called on the training
    portion, which mutates the estimator passed in.

    Residuals are computed as ``prediction - target`` for both portions.

    Args:
        regressor: Object exposing ``fit``, ``score`` and ``predict``.
        X: Feature data, forwarded to ``train_test_split``.
        y: Target data, forwarded to ``train_test_split``.

    Returns:
        The value returned by ``wandb.visualize`` for the
        ``"wandb/residuals_plot/v1"`` chart.
    """
    split = model_selection.train_test_split(X, y, test_size=0.2)
    features_train, features_test, target_train, target_test = split

    # Fit first: scoring and predicting below rely on the fitted estimator.
    regressor.fit(features_train, target_train)
    score_on_train = regressor.score(features_train, target_train)
    score_on_test = regressor.score(features_test, target_test)

    predicted_train = regressor.predict(features_train)
    error_train = predicted_train - target_train

    predicted_test = regressor.predict(features_test)
    error_test = predicted_test - target_test

    return wandb.visualize(
        "wandb/residuals_plot/v1",
        make_table(
            predicted_train,
            error_train,
            predicted_test,
            error_test,
            score_on_train,
            score_on_test,
        ),
    )


def make_table(
    y_pred_train,
    residuals_train,
    y_pred_test,
    residuals_test,
    train_score_,
    test_score_,
):
    """Build the ``wandb.Table`` backing the residuals plot.

    One row is emitted per (prediction, residual) pair, training rows first
    and test rows after them. Every row repeats the two scores so that each
    record carries them alongside the point data.

    Args:
        y_pred_train: Iterable of predictions for the training split.
        residuals_train: Iterable of residuals paired with ``y_pred_train``.
        y_pred_test: Iterable of predictions for the test split.
        residuals_test: Iterable of residuals paired with ``y_pred_test``.
        train_score_: Score on the training split, copied into every row.
        test_score_: Score on the test split, copied into every row.

    Returns:
        A ``wandb.Table`` with columns ``dataset``, ``y_pred``, ``residuals``,
        ``train_score`` and ``test_score``.
    """
    # The same limit is applied to each split independently; the running
    # count restarts at 1 for the test split.
    max_datapoints = 100

    splits = (
        ("train", y_pred_train, residuals_train),
        ("test", y_pred_test, residuals_test),
    )

    data = []
    for dataset, predictions, split_residuals in splits:
        # zip() stops at the shorter of the two iterables.
        pairs = zip(predictions, split_residuals)
        for datapoints, (pred, residual) in enumerate(pairs, start=1):
            data.append([dataset, pred, residual, train_score_, test_score_])
            # The row is appended before the check, so the row that trips the
            # limit is still kept. NOTE(review): the exact cut-off depends on
            # how utils.check_against_limit compares count and limit.
            if utils.check_against_limit(datapoints, "residuals", max_datapoints):
                break

    columns = ["dataset", "y_pred", "residuals", "train_score", "test_score"]
    return wandb.Table(columns=columns, data=data)