"""Define plots used by multiple sklearn model classes.""" from warnings import simplefilter import numpy as np import wandb from wandb.integration.sklearn import calculate, utils # ignore all future warnings simplefilter(action="ignore", category=FutureWarning) def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None): # noqa: N803 """Logs a chart depicting summary metrics for a model. Should only be called with a fitted model (otherwise an error is thrown). Args: model: (clf or reg) Takes in a fitted regressor or classifier. X: (arr) Training set features. y: (arr) Training set labels. X_test: (arr) Test set features. y_test: (arr) Test set labels. Returns: None: To see plots, go to your W&B run page then expand the 'media' tab under 'auto visualizations'. Example: ```python wandb.sklearn.plot_summary_metrics(model, X_train, y_train, X_test, y_test) ``` """ not_missing = utils.test_missing( model=model, X=X, y=y, X_test=X_test, y_test=y_test ) correct_types = utils.test_types( model=model, X=X, y=y, X_test=X_test, y_test=y_test ) model_fitted = utils.test_fitted(model) if not_missing and correct_types and model_fitted: metrics_chart = calculate.summary_metrics(model, X, y, X_test, y_test) wandb.log({"summary_metrics": metrics_chart}) def learning_curve( model=None, X=None, # noqa: N803 y=None, cv=None, shuffle=False, random_state=None, train_sizes=None, n_jobs=1, scoring=None, ): """Logs a plot depicting model performance against dataset size. Please note this function fits the model to datasets of varying sizes when called. Args: model: (clf or reg) Takes in a fitted regressor or classifier. X: (arr) Dataset features. y: (arr) Dataset labels. For details on the other keyword arguments, see the documentation for `sklearn.model_selection.learning_curve`. Returns: None: To see plots, go to your W&B run page then expand the 'media' tab under 'auto visualizations'. Example: ```python wandb.sklearn.plot_learning_curve(model, X, y) ``` """ not_missing = utils.test_missing(model=model, X=X, y=y) correct_types = utils.test_types(model=model, X=X, y=y) if not_missing and correct_types: if train_sizes is None: train_sizes = np.linspace(0.1, 1.0, 5) y = np.asarray(y) learning_curve_chart = calculate.learning_curve( model, X, y, cv, shuffle, random_state, train_sizes, n_jobs, scoring ) wandb.log({"learning_curve": learning_curve_chart})