File size: 6,769 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
from typing import Iterable, Sequence
import wandb
from wandb import util
def test_missing(**kwargs):
    """Check keyword arguments for ``None`` and validate feature matrices.

    Every keyword argument is checked for ``None``. Arguments named ``X`` or
    ``X_test`` are additionally converted to a numpy array (from a scipy
    sparse matrix, a pandas ``DataFrame``/``Series``, a list or a tuple) and
    scanned for missing values and for elements that are not numbers.
    Problems are reported through ``wandb.termerror`` / ``wandb.termwarn``
    instead of being raised.

    Args:
        **kwargs: Named values to validate.

    Returns:
        bool: ``True`` if no problem was reported, ``False`` otherwise.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
    scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")

    # Python numeric types plus every numpy numeric scalar type.
    numeric_types = (int, float, complex, np.number)

    test_passed = True
    for k, v in kwargs.items():
        # Missing/empty params/datapoint arrays
        if v is None:
            wandb.termerror(f"{k} is None. Please try again.")
            test_passed = False
            # Nothing else can be checked on a missing value; falling through
            # would reach ``v.ndim`` below and raise AttributeError.
            continue

        if k not in ("X", "X_test"):
            continue

        # Normalise the supported containers to a dense numpy array.
        if scipy.sparse.issparse(v):
            # Covers every sparse format, not just CSR, and avoids the
            # deprecated ``scipy.sparse.csr`` namespace.
            v = v.toarray()
        elif isinstance(v, (pd.DataFrame, pd.Series)):
            v = v.to_numpy()
        elif isinstance(v, (list, tuple)):
            v = np.asarray(v)

        # Warn the user about missing values
        missing = np.count_nonzero(pd.isnull(v))
        if missing > 0:
            wandb.termwarn(f"{k} contains {missing} missing values. ")
            test_passed = False

        # Ensure the dataset contains only numbers. A 1-D array is scanned
        # element by element; otherwise each element of each row is scanned.
        if v.ndim == 1:
            elements = v
        else:
            elements = (val for sl in v for val in sl)
        non_nums = sum(1 for val in elements if not isinstance(val, numeric_types))
        if non_nums > 0:
            wandb.termerror(
                f"{k} contains values that are not numbers. Please vectorize, "
                f"label encode or one hot encode {k} and call the plotting function again."
            )
            test_passed = False
    return test_passed
def test_fitted(model):
    """Report whether ``model`` appears to have been fitted.

    The model's ``predict`` method is probed with a ``(7, 3)`` array of
    zeros. If ``predict`` does not exist (``AttributeError``), scikit-learn's
    ``check_is_fitted`` is used instead, looking for any one of a list of
    fitted attributes. A not-fitted model is reported through
    ``wandb.termerror``.

    Args:
        model: The estimator to inspect.

    Returns:
        bool: ``False`` if ``NotFittedError`` was raised by either probe,
        ``True`` in every other case (including when ``predict`` fails for
        an unrelated reason, e.g. a feature-count mismatch with the probe).
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    # Return values unused; presumably these calls only verify that the
    # optional dependencies are installed -- confirm against util.get_module.
    _ = util.get_module("pandas", required="Logging dataframes requires pandas")
    _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
    scikit_utils = util.get_module(
        "sklearn.utils",
        required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
    )
    scikit_exceptions = util.get_module(
        "sklearn.exceptions",
        "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
    )

    try:
        model.predict(np.zeros((7, 3)))
    except scikit_exceptions.NotFittedError:
        wandb.termerror("Please fit the model before passing it in.")
        return False
    except AttributeError:
        # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
        try:
            scikit_utils.validation.check_is_fitted(
                model,
                [
                    "coef_",
                    "estimator_",
                    "labels_",
                    "n_clusters_",
                    "children_",
                    "components_",
                    "n_components_",
                    "n_iter_",
                    "n_batch_iter_",
                    "explained_variance_",
                    "singular_values_",
                    "mean_",
                ],
                all_or_any=any,
            )
        except scikit_exceptions.NotFittedError:
            wandb.termerror("Please fit the model before passing it in.")
            return False
        else:
            return True
    except Exception:
        # Assume it's fitted, since ``NotFittedError`` wasn't raised
        return True

    # ``predict`` ran without raising, so the model is fitted. Without this
    # explicit return the function fell off the end and returned ``None``,
    # which callers treat as "not fitted".
    return True
def encode_labels(df):
    """Label-encode, in place, the columns of ``df`` that are not numeric.

    Columns whose dtype is not one of the listed int/float dtypes are each
    replaced by the output of ``LabelEncoder.fit_transform`` on that column.

    Args:
        df: The pandas ``DataFrame`` to modify.

    Returns:
        None. ``df`` is mutated in place.
    """
    _ = util.get_module("pandas", required="Logging dataframes requires pandas")
    preprocessing = util.get_module(
        "sklearn.preprocessing",
        "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
    )
    encoder = preprocessing.LabelEncoder()

    def _encode(column):
        # The single encoder is re-fitted on every column it is applied to.
        return encoder.fit_transform(column)

    # Everything that is not one of these dtypes is treated as categorical.
    numeric_dtypes = ["int", "float", "float64", "float32", "int32", "int64"]
    target_columns = df.select_dtypes(exclude=numeric_dtypes).columns
    encoded = df[target_columns].apply(_encode)
    df[target_columns] = encoded
def test_types(**kwargs):
    """Validate keyword arguments, by name, against the kind of object expected.

    Array-style arguments (``X``, ``X_test``, ``y``, ``y_test``, ``y_true``,
    ``y_probas``, ``x_labels``, ``y_labels``, ``matrix_values``) must be
    instances of one of the accepted array-like types. Estimator arguments
    (``model``, ``clf``, ``binary_clf``, ``regressor``, ``clusterer``) must be
    the matching kind of estimator. Each failure is reported through
    ``wandb.termerror``; other argument names are ignored.

    Args:
        **kwargs: Named values to validate.

    Returns:
        bool: ``True`` when no check failed, ``False`` otherwise.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
    _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
    base = util.get_module(
        "sklearn.base",
        "roc requires the scikit base submodule, install with `pip install scikit-learn`",
    )

    array_arg_names = {
        "X",
        "X_test",
        "y",
        "y_test",
        "y_true",
        "y_probas",
        "x_labels",
        "y_labels",
        "matrix_values",
    }
    # FIXME: do this individually
    array_like_types = (
        Sequence,
        Iterable,
        np.ndarray,
        np.generic,
        pd.DataFrame,
        pd.Series,
        list,
    )

    all_ok = True
    for name, value in kwargs.items():
        # check for incorrect types
        if name in array_arg_names and not isinstance(value, array_like_types):
            wandb.termerror(f"{name} is not an array. Please try again.")
            all_ok = False

        # check for classifier types
        if name == "model":
            role = "classifier or regressor"
            matches = base.is_classifier(value) or base.is_regressor(value)
        elif name in ("clf", "binary_clf"):
            role = "classifier"
            matches = base.is_classifier(value)
        elif name == "regressor":
            role = "regressor"
            matches = base.is_regressor(value)
        elif name == "clusterer":
            role = "clusterer"
            matches = getattr(value, "_estimator_type", None) == "clusterer"
        else:
            # Not an estimator argument; nothing further to check.
            continue

        if not matches:
            wandb.termerror(f"{name} is not a {role}. Please try again.")
            all_ok = False
    return all_ok
|