File size: 6,769 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from typing import Iterable, Sequence

import wandb
from wandb import util


def test_missing(**kwargs):
    """Check the given keyword arguments for missing or non-numeric data.

    Every argument is checked for ``None``. Arguments named ``X`` or
    ``X_test`` are additionally converted to a numpy array and scanned for
    null entries and for values that are not numbers. Each problem found is
    reported through ``wandb.termerror`` / ``wandb.termwarn``.

    Args:
        **kwargs: Named values to validate, e.g. ``X=...``, ``y=...``.

    Returns:
        bool: ``True`` if no problem was found, ``False`` otherwise.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
    scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")

    test_passed = True
    for k, v in kwargs.items():
        # Missing/empty params/datapoint arrays
        if v is None:
            wandb.termerror(f"{k} is None. Please try again.")
            test_passed = False
            # Nothing else can be inspected on a ``None`` value; carrying on
            # would fail on the array attribute access below.
            continue
        if k not in ("X", "X_test"):
            continue

        # Normalize the dataset to a numpy array before inspecting it.
        if scipy.sparse.issparse(v):
            # ``issparse`` covers every sparse format, not only CSR.
            v = v.toarray()
        elif isinstance(v, (pd.DataFrame, pd.Series)):
            v = v.to_numpy()
        else:
            # Lists, tuples and other array-likes; a no-op for ndarrays.
            v = np.asarray(v)

        # Warn the user about missing values
        missing = np.count_nonzero(pd.isnull(v))
        if missing > 0:
            wandb.termwarn(f"{k} contains {missing} missing values. ")
            test_passed = False

        # Ensure the dataset contains only numbers
        if np.issubdtype(v.dtype, np.number):
            # A numeric dtype guarantees every element is a number, so the
            # element-by-element scan can be skipped.
            non_nums = 0
        else:
            numeric_types = (int, float, complex, np.number)
            # ``flat`` visits every element regardless of dimensionality.
            non_nums = sum(1 for val in v.flat if not isinstance(val, numeric_types))
        if non_nums > 0:
            wandb.termerror(
                f"{k} contains values that are not numbers. Please vectorize, "
                f"label encode or one hot encode {k} and call the plotting function again."
            )
            test_passed = False
    return test_passed


def test_fitted(model):
    """Report whether ``model`` appears to have been fitted.

    The model's ``predict`` method is probed with a dummy ``(7, 3)`` array of
    zeros. If the model has no usable ``predict`` (an ``AttributeError`` is
    raised), scikit-learn's ``check_is_fitted`` is used instead, looking for
    any one of several fitted attributes.

    Args:
        model: The estimator to inspect.

    Returns:
        bool: ``False`` if scikit-learn raised ``NotFittedError`` (an error is
        also logged via ``wandb.termerror``); ``True`` in every other case.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    _ = util.get_module("pandas", required="Logging dataframes requires pandas")
    _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
    scikit_utils = util.get_module(
        "sklearn.utils",
        required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
    )
    scikit_exceptions = util.get_module(
        "sklearn.exceptions",
        "roc requires the scikit exceptions submodule, install with `pip install scikit-learn`",
    )

    try:
        model.predict(np.zeros((7, 3)))
    except scikit_exceptions.NotFittedError:
        # Must be handled before ``AttributeError`` so an unfitted model is
        # reported here rather than falling into the attribute fallback.
        wandb.termerror("Please fit the model before passing it in.")
        return False
    except AttributeError:
        # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
        try:
            scikit_utils.validation.check_is_fitted(
                model,
                [
                    "coef_",
                    "estimator_",
                    "labels_",
                    "n_clusters_",
                    "children_",
                    "components_",
                    "n_components_",
                    "n_iter_",
                    "n_batch_iter_",
                    "explained_variance_",
                    "singular_values_",
                    "mean_",
                ],
                all_or_any=any,
            )
        except scikit_exceptions.NotFittedError:
            wandb.termerror("Please fit the model before passing it in.")
            return False
        else:
            return True
    except Exception:
        # Assume it's fitted, since ``NotFittedError`` wasn't raised
        return True

    # ``predict`` succeeded on the probe input, so the model is fitted.
    # Without this the function would fall through and return ``None``.
    return True


def encode_labels(df):
    """Label-encode the non-numeric columns of ``df`` in place.

    Every column whose dtype is not numeric is replaced by the integer codes
    produced by scikit-learn's ``LabelEncoder``. Numeric columns are left
    untouched.

    Args:
        df: The pandas DataFrame to modify. It is mutated in place.

    Returns:
        None
    """
    _ = util.get_module("pandas", required="Logging dataframes requires pandas")
    preprocessing = util.get_module(
        "sklearn.preprocessing",
        "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
    )

    # "number" excludes every numeric dtype. A hand-written list of a few
    # int/float widths would treat columns such as int16 or uint8 as
    # categorical and overwrite their values with label codes.
    categorical_cols = df.select_dtypes(exclude=["number"]).columns
    if len(categorical_cols) == 0:
        # Nothing to encode; leave the frame as it is.
        return

    le = preprocessing.LabelEncoder()
    # ``fit_transform`` refits the encoder on each column it is applied to.
    df[categorical_cols] = df[categorical_cols].apply(le.fit_transform)


def test_types(**kwargs):
    """Validate the kinds of the named plotting arguments.

    Data arguments (``X``, ``X_test``, ``y``, ``y_test``, ``y_true``,
    ``y_probas``, ``x_labels``, ``y_labels``, ``matrix_values``) must be
    array-like. Estimator arguments must match their role: ``model`` is a
    classifier or regressor, ``clf`` / ``binary_clf`` a classifier,
    ``regressor`` a regressor and ``clusterer`` a clusterer. Every violation
    is reported through ``wandb.termerror``.

    Args:
        **kwargs: Named values to validate; other names are ignored.

    Returns:
        bool: ``True`` if every checked argument passed, ``False`` otherwise.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
    _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")

    base = util.get_module(
        "sklearn.base",
        "roc requires the scikit base submodule, install with `pip install scikit-learn`",
    )

    data_names = frozenset(
        {
            "X",
            "X_test",
            "y",
            "y_test",
            "y_true",
            "y_probas",
            "x_labels",
            "y_labels",
            "matrix_values",
        }
    )

    def looks_like_array(value):
        # FIXME: do this individually
        # The type tuple is built on demand so numpy/pandas attributes are
        # only touched when a data argument is actually being checked.
        accepted = (
            Sequence,
            Iterable,
            np.ndarray,
            np.generic,
            pd.DataFrame,
            pd.Series,
            list,
        )
        return isinstance(value, accepted)

    all_ok = True
    for name, value in kwargs.items():
        # check for incorrect types
        if name in data_names and not looks_like_array(value):
            wandb.termerror(f"{name} is not an array. Please try again.")
            all_ok = False

        # check for classifier types; ``expected`` stays None when the
        # argument is fine or is not an estimator argument at all
        expected = None
        if name == "model":
            if not (base.is_classifier(value) or base.is_regressor(value)):
                expected = "classifier or regressor"
        elif name in ("clf", "binary_clf"):
            if not base.is_classifier(value):
                expected = "classifier"
        elif name == "regressor":
            if not base.is_regressor(value):
                expected = "regressor"
        elif name == "clusterer":
            if getattr(value, "_estimator_type", None) != "clusterer":
                expected = "clusterer"

        if expected is not None:
            wandb.termerror(f"{name} is not a {expected}. Please try again.")
            all_ok = False
    return all_ok