File size: 5,884 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
"""Shared utilities for the modules in wandb.sklearn."""
from collections.abc import Iterable, Sequence
import numpy as np
import pandas as pd
import scipy
import sklearn
import wandb
# Default maximum number of datapoints used for a chart; ``check_against_limit``
# falls back to this value when no explicit ``limit`` is passed.
chart_limit = 1000
def check_against_limit(count, chart, limit=None):
    """Report whether ``count`` exceeds the datapoint limit for ``chart``.

    Args:
        count: Number of datapoints about to be charted.
        chart: Name of the chart, used only in the warning message.
        limit: Maximum allowed datapoints; the module-level ``chart_limit``
            is used when this is ``None``.

    Returns:
        ``True`` (after emitting a warning) if ``count`` is strictly greater
        than the limit, otherwise ``False``.
    """
    effective_limit = chart_limit if limit is None else limit
    over_limit = bool(count > effective_limit)
    if over_limit:
        warn_chart_limit(effective_limit, chart)
    return over_limit
def warn_chart_limit(limit, chart):
    """Emit a wandb terminal warning that ``chart`` is truncated to ``limit`` points.

    Args:
        limit: Number of datapoints that will actually be used.
        chart: Name of the chart being created.
    """
    wandb.termwarn(
        f"using only the first {limit} datapoints to create chart {chart}"
    )
def encode_labels(df):
    """Label-encode, in place, every column of ``df`` that is not int/float typed.

    Columns whose dtype is not one of the listed integer/float dtypes are
    replaced by the output of ``LabelEncoder.fit_transform``. The frame is
    mutated and nothing is returned.

    Args:
        df: The pandas DataFrame to modify.
    """
    encoder = sklearn.preprocessing.LabelEncoder()
    numeric_dtypes = ["int", "float", "float64", "float32", "int32", "int64"]
    # Everything outside the numeric dtypes above is treated as categorical.
    non_numeric_cols = df.select_dtypes(exclude=numeric_dtypes).columns
    # The encoder is re-fitted for each column by ``fit_transform``.
    df[non_numeric_cols] = df[non_numeric_cols].apply(encoder.fit_transform)
def test_types(**kwargs):
    """Validate the types of well-known plotting arguments.

    Recognised keyword names:

    * ``X``, ``X_test``, ``y``, ``y_test``, ``y_true``, ``y_probas`` must be
      array-like (sequence, iterable, NumPy array/scalar or pandas object).
    * ``model`` must be a scikit-learn classifier or regressor.
    * ``clf`` / ``binary_clf`` must be a classifier.
    * ``regressor`` must be a regressor.
    * ``clusterer`` must have ``_estimator_type == "clusterer"``.

    Any other keyword is ignored. Each failed check is reported through
    ``wandb.termerror``.

    Returns:
        ``True`` if every recognised argument passed its check, else ``False``.
    """
    array_arg_names = {"X", "X_test", "y", "y_test", "y_true", "y_probas"}
    # FIXME: do this individually
    array_like_types = (
        Sequence,
        Iterable,
        np.ndarray,
        np.generic,
        pd.DataFrame,
        pd.Series,
        list,
    )

    all_ok = True
    for name, value in kwargs.items():
        # Array-like arguments.
        if name in array_arg_names and not isinstance(value, array_like_types):
            wandb.termerror(f"{name} is not an array. Please try again.")
            all_ok = False

        # Estimator arguments.
        if name == "model":
            if not (
                sklearn.base.is_classifier(value)
                or sklearn.base.is_regressor(value)
            ):
                wandb.termerror(
                    f"{name} is not a classifier or regressor. Please try again."
                )
                all_ok = False
        elif name in ("clf", "binary_clf"):
            if not sklearn.base.is_classifier(value):
                wandb.termerror(f"{name} is not a classifier. Please try again.")
                all_ok = False
        elif name == "regressor":
            if not sklearn.base.is_regressor(value):
                wandb.termerror(f"{name} is not a regressor. Please try again.")
                all_ok = False
        elif name == "clusterer":
            if getattr(value, "_estimator_type", None) != "clusterer":
                wandb.termerror(f"{name} is not a clusterer. Please try again.")
                all_ok = False
    return all_ok
def test_fitted(model):
try:
model.predict(np.zeros((7, 3)))
except sklearn.exceptions.NotFittedError:
wandb.termerror("Please fit the model before passing it in.")
return False
except AttributeError:
# Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
try:
sklearn.utils.validation.check_is_fitted(
model,
[
"coef_",
"estimator_",
"labels_",
"n_clusters_",
"children_",
"components_",
"n_components_",
"n_iter_",
"n_batch_iter_",
"explained_variance_",
"singular_values_",
"mean_",
],
all_or_any=any,
)
except sklearn.exceptions.NotFittedError:
wandb.termerror("Please fit the model before passing it in.")
return False
else:
return True
except Exception:
# Assume it's fitted, since ``NotFittedError`` wasn't raised
return True
# Test assumptions for plotting parameters and datasets
def _count_non_numeric(values):
    """Return how many entries of the array ``values`` are not numbers."""
    # A 1-D/2-D array with a numeric dtype only ever yields NumPy numeric
    # scalars, so the element-by-element scan below would always count zero.
    if values.ndim in (1, 2) and np.issubdtype(values.dtype, np.number):
        return 0
    rows = (values,) if values.ndim == 1 else values
    return sum(
        1
        for row in rows
        for item in row
        if not isinstance(item, (int, float, complex, np.number))
    )


def test_missing(**kwargs):
    """Check plotting arguments for ``None``, missing values and non-numbers.

    Every keyword argument is checked for being ``None``. Arguments named
    ``X`` or ``X_test`` are additionally converted to a NumPy array (sparse
    matrices are densified, pandas objects go through ``to_numpy``, anything
    else through ``np.asarray``) and then checked for:

    * missing values (as reported by ``pd.isnull``), logged as a warning;
    * entries that are not numbers, logged as an error.

    Returns:
        ``True`` if no check failed, ``False`` otherwise.
    """
    test_passed = True
    for k, v in kwargs.items():
        # Missing/empty params/datapoint arrays
        if v is None:
            wandb.termerror(f"{k} is None. Please try again.")
            test_passed = False
            # Nothing else can be inspected on a missing argument.
            continue
        if k not in ("X", "X_test"):
            continue

        if scipy.sparse.issparse(v):
            data = v.toarray()
        elif isinstance(v, (pd.DataFrame, pd.Series)):
            data = v.to_numpy()
        else:
            # Lists, tuples and other array-likes; ndarrays pass through as-is.
            data = np.asarray(v)

        # Warn the user about missing values
        missing = np.count_nonzero(pd.isnull(data))
        if missing > 0:
            wandb.termwarn(f"{k} contains {missing} missing values. ")
            test_passed = False

        # Ensure the dataset contains only numbers
        if _count_non_numeric(data) > 0:
            wandb.termerror(
                f"{k} contains values that are not numbers. Please vectorize, label encode or one hot encode {k} "
                "and call the plotting function again."
            )
            test_passed = False
    return test_passed
def round_3(n):
    """Return ``n`` rounded to three decimal places."""
    decimals = 3
    return round(n, decimals)
def round_2(n):
    """Return ``n`` rounded to two decimal places."""
    decimals = 2
    return round(n, decimals)
|