Create app.py
app.py
ADDED
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs
from sklearn.covariance import OAS
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def generate_data(n_samples, n_features):
    """Generate `n_samples` samples where only the first feature is discriminative."""
    # Two Gaussian blobs separated along the single informative dimension.
    X, y = make_blobs(n_samples=n_samples, n_features=1, centers=[[-2], [2]])

    # Pad with pure-noise features that carry no class information.
    if n_features > 1:
        X = np.hstack([X, np.random.randn(n_samples, n_features - 1)])
    return X, y


def classify(n_train, n_test, n_averages, n_features_max, step):
    # Gradio sliders can deliver floats; the loops below need ints.
    n_train, n_test, n_averages, n_features_max, step = (
        int(n_train), int(n_test), int(n_averages), int(n_features_max), int(step)
    )

    acc_clf1, acc_clf2, acc_clf3 = [], [], []
    n_features_range = range(1, n_features_max + 1, step)

    for n_features in n_features_range:
        score_clf1, score_clf2, score_clf3 = 0, 0, 0
        for _ in range(n_averages):
            X, y = generate_data(n_train, n_features)

            # Plain LDA, LDA with Ledoit-Wolf shrinkage, and LDA with an
            # OAS covariance estimator, all using the lsqr solver.
            clf1 = LinearDiscriminantAnalysis(solver="lsqr", shrinkage=None).fit(X, y)
            clf2 = LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto").fit(X, y)
            oa = OAS(store_precision=False, assume_centered=False)
            clf3 = LinearDiscriminantAnalysis(solver="lsqr", covariance_estimator=oa).fit(X, y)

            # Score on a freshly generated test set.
            X, y = generate_data(n_test, n_features)
            score_clf1 += clf1.score(X, y)
            score_clf2 += clf2.score(X, y)
            score_clf3 += clf3.score(X, y)

        acc_clf1.append(score_clf1 / n_averages)
        acc_clf2.append(score_clf2 / n_averages)
        acc_clf3.append(score_clf3 / n_averages)

    features_samples_ratio = np.array(n_features_range) / n_train

    # Start a fresh figure so repeated calls do not accumulate lines.
    plt.figure()
    plt.plot(
        features_samples_ratio,
        acc_clf1,
        linewidth=2,
        label="LDA",
        color="gold",
        linestyle="solid",
    )
    plt.plot(
        features_samples_ratio,
        acc_clf2,
        linewidth=2,
        label="LDA with Ledoit-Wolf",
        color="navy",
        linestyle="dashed",
    )
    plt.plot(
        features_samples_ratio,
        acc_clf3,
        linewidth=2,
        label="LDA with OAS",
        color="red",
        linestyle="dotted",
    )

    plt.xlabel("n_features / n_samples")
    plt.ylabel("Classification accuracy")
    plt.legend(loc="lower left")
    plt.ylim((0.65, 1.0))
    plt.suptitle(
        "LDA (Linear Discriminant Analysis) vs.\n"
        "LDA with Ledoit-Wolf vs.\n"
        "LDA with OAS (1 discriminative feature)"
    )

    # Save the plot and hand the file path to Gradio.
    plt.tight_layout()
    plt.savefig("plot.png")
    plt.close()
    return "plot.png"


# Define the input and output components.
inputs = [
    gr.Slider(minimum=1, maximum=100, step=1, label="n_train", value=20),
    gr.Slider(minimum=1, maximum=500, step=1, label="n_test", value=200),
    gr.Slider(minimum=1, maximum=100, step=1, label="n_averages", value=50),
    gr.Slider(minimum=1, maximum=100, step=1, label="n_features_max", value=75),
    gr.Slider(minimum=1, maximum=20, step=1, label="step", value=4),
]
output = gr.Image(type="filepath")
examples = [
    [20, 200, 50, 75, 4],
    [30, 250, 60, 80, 5],
    [40, 300, 70, 90, 6],
]

# Create the Gradio app.
title = "Normal, Ledoit-Wolf and OAS Linear Discriminant Analysis for classification"
description = (
    "This example illustrates how the Ledoit-Wolf and Oracle Approximating "
    "Shrinkage (OAS) estimators of covariance can improve classification. "
    "See the original example: "
    "https://scikit-learn.org/stable/auto_examples/classification/plot_lda.html"
)
gr.Interface(
    classify, inputs, output, examples=examples, title=title, description=description
).launch()
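A quick way to sanity-check the three estimators without the Gradio layer is to fit them once on a single synthetic dataset. The sketch below mirrors the logic of classify for one setting; the sample and feature counts are arbitrary choices for illustration, and it assumes scikit-learn >= 0.24 (the first release with the covariance_estimator parameter). The app itself additionally needs gradio, numpy and matplotlib installed.

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.covariance import OAS
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# One discriminative feature plus 19 noise features, as in generate_data.
X, y = make_blobs(n_samples=40, n_features=1, centers=[[-2], [2]])
X = np.hstack([X, np.random.randn(40, 19)])

for name, kwargs in [
    ("plain LDA", {"shrinkage": None}),
    ("Ledoit-Wolf", {"shrinkage": "auto"}),
    ("OAS", {"covariance_estimator": OAS()}),
]:
    clf = LinearDiscriminantAnalysis(solver="lsqr", **kwargs).fit(X, y)
    print(f"{name}: train accuracy {clf.score(X, y):.3f}")

The printed training-set scores are only a smoke test; the averaged test-set curves produced by the app are the meaningful comparison.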