jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
from warnings import simplefilter
from sklearn import model_selection
import wandb
from wandb.integration.sklearn import utils
# ignore all future warnings
simplefilter(action="ignore", category=FutureWarning)
def residuals(regressor, X, y): # noqa: N803
# Create the train and test splits
x_train, x_test, y_train, y_test = model_selection.train_test_split(
X, y, test_size=0.2
)
# Store labels and colors for the legend ordered by call
regressor.fit(x_train, y_train)
train_score_ = regressor.score(x_train, y_train)
test_score_ = regressor.score(x_test, y_test)
y_pred_train = regressor.predict(x_train)
residuals_train = y_pred_train - y_train
y_pred_test = regressor.predict(x_test)
residuals_test = y_pred_test - y_test
table = make_table(
y_pred_train,
residuals_train,
y_pred_test,
residuals_test,
train_score_,
test_score_,
)
chart = wandb.visualize("wandb/residuals_plot/v1", table)
return chart
def make_table(
y_pred_train,
residuals_train,
y_pred_test,
residuals_test,
train_score_,
test_score_,
):
y_pred_column, dataset_column, residuals_column = [], [], []
datapoints, max_datapoints_train = 0, 100
for pred, residual in zip(y_pred_train, residuals_train):
# add class counts from training set
y_pred_column.append(pred)
dataset_column.append("train")
residuals_column.append(residual)
datapoints += 1
if utils.check_against_limit(datapoints, "residuals", max_datapoints_train):
break
datapoints = 0
for pred, residual in zip(y_pred_test, residuals_test):
# add class counts from training set
y_pred_column.append(pred)
dataset_column.append("test")
residuals_column.append(residual)
datapoints += 1
if utils.check_against_limit(datapoints, "residuals", max_datapoints_train):
break
columns = ["dataset", "y_pred", "residuals", "train_score", "test_score"]
data = [
[
dataset_column[i],
y_pred_column[i],
residuals_column[i],
train_score_,
test_score_,
]
for i in range(len(y_pred_column))
]
table = wandb.Table(columns=columns, data=data)
return table