|
import gradio as gr
import matplotlib

matplotlib.use("Agg")  # non-interactive backend: render figures off-screen (Gradio handlers run outside the main thread)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

def create_dataset(num_samples, num_informative):
    # Slider values may arrive as floats; scikit-learn expects integer counts.
    X, y = make_classification(
        n_samples=int(num_samples),
        n_features=10,
        n_informative=int(num_informative),
        n_redundant=0,
        n_repeated=0,
        n_classes=2,
        random_state=0,
        shuffle=False,  # keep the informative features first (feature 0, 1, ...)
    )

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=42
    )
    return X_train, X_test, y_train, y_test

def plot_mean_decrease(clf, feature_names):
    # Impurity-based importances (MDI), averaged over the trees; the error
    # bars show the standard deviation across the individual trees.
    importances = clf.feature_importances_
    std = np.std([tree.feature_importances_ for tree in clf.estimators_], axis=0)
    forest_importances = pd.Series(importances, index=feature_names)

    fig, ax = plt.subplots()
    forest_importances.plot.bar(yerr=std, ax=ax)
    ax.set_title("Feature importances using MDI")
    ax.set_ylabel("Mean decrease in impurity")
    fig.tight_layout()
    return fig

def plot_feature_perm(clf, feature_names, X_test, y_test):
    # Permutation importance on the held-out test set: the mean drop in
    # accuracy when a feature's values are shuffled, over n_repeats shuffles.
    result = permutation_importance(
        clf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
    )
    forest_importances = pd.Series(result.importances_mean, index=feature_names)

    fig, ax = plt.subplots()
    forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
    ax.set_title("Feature importances using permutation on full model")
    ax.set_ylabel("Mean accuracy decrease")
    fig.tight_layout()
    return fig

def train_model(num_samples, num_info):
    X_train, X_test, y_train, y_test = create_dataset(num_samples, num_info)

    feature_names = [f"feature {i}" for i in range(X_train.shape[1])]
    forest = RandomForestClassifier(random_state=0)
    forest.fit(X_train, y_train)

    fig = plot_mean_decrease(forest, feature_names)
    fig2 = plot_feature_perm(forest, feature_names, X_test, y_test)
    return fig, fig2

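# Example (assumed invocation, not executed by the app): both figures can be
# built directly, bypassing the UI, e.g. for a quick smoke test:
#     fig_mdi, fig_perm = train_model(num_samples=1000, num_info=3)
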
title = "Feature importances with a forest of trees 🌳"
description = """This example shows the use of a forest of trees to evaluate the importance of features on an artificial classification task.
The blue bars are the feature importances of the forest, along with their inter-tree variability, represented by the error bars.

The model is trained on simulated data.
"""

|
with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    num_samples = gr.Slider(
        minimum=1000, maximum=5000, step=500, value=1000,
        label="Number of samples",
    )
    num_info = gr.Slider(
        minimum=2, maximum=10, step=1, value=3,
        label="Number of informative features",
    )

    with gr.Row():
        plot = gr.Plot()
        plot2 = gr.Plot()

    # Retrain the forest and redraw both plots whenever either slider changes.
    num_samples.change(fn=train_model, inputs=[num_samples, num_info], outputs=[plot, plot2])
    num_info.change(fn=train_model, inputs=[num_samples, num_info], outputs=[plot, plot2])
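    # Assumed nicety: also render the plots once on page load (Blocks.load
    # event), so the demo is not blank before the first slider interaction.
    demo.load(fn=train_model, inputs=[num_samples, num_info], outputs=[plot, plot2])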
|
# `enable_queue` in launch() is deprecated; enable queuing via .queue() instead.
demo.queue()
demo.launch()
|
|