caliex committed
Commit ba0a8b8 · 1 Parent(s): 1e3bb19

Create app.py

Files changed (1)
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import gradio as gr
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ from sklearn.datasets import make_blobs
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+ from sklearn.covariance import OAS
+
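+ # Generate a 1-D two-blob classification problem and pad it with pure-noise
+ # features, so that only the first feature is discriminative.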
+ def generate_data(n_samples, n_features):
+     X, y = make_blobs(n_samples=n_samples, n_features=1, centers=[[-2], [2]])
+
+     if n_features > 1:
+         X = np.hstack([X, np.random.randn(n_samples, n_features - 1)])
+     return X, y
+
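+ # Sweep the number of features and, for each setting, average test accuracy
+ # over n_averages random train/test draws for three LDA variants.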
+ def classify(n_train, n_test, n_averages, n_features_max, step):
+     acc_clf1, acc_clf2, acc_clf3 = [], [], []
+     n_features_range = range(1, n_features_max + 1, step)
+
+     for n_features in n_features_range:
+         score_clf1, score_clf2, score_clf3 = 0, 0, 0
+         for _ in range(n_averages):
+             X, y = generate_data(n_train, n_features)
+
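+             # Plain LDA, LDA with Ledoit-Wolf shrinkage ("auto"), and LDA with
+             # an OAS covariance estimator, all using the lsqr solver.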
+             clf1 = LinearDiscriminantAnalysis(solver="lsqr", shrinkage=None).fit(X, y)
+             clf2 = LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto").fit(X, y)
+             oa = OAS(store_precision=False, assume_centered=False)
+             clf3 = LinearDiscriminantAnalysis(solver="lsqr", covariance_estimator=oa).fit(X, y)
+
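+             # Score each fitted classifier on a freshly drawn test set.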
+             X, y = generate_data(n_test, n_features)
+             score_clf1 += clf1.score(X, y)
+             score_clf2 += clf2.score(X, y)
+             score_clf3 += clf3.score(X, y)
+
+         acc_clf1.append(score_clf1 / n_averages)
+         acc_clf2.append(score_clf2 / n_averages)
+         acc_clf3.append(score_clf3 / n_averages)
+
+     features_samples_ratio = np.array(n_features_range) / n_train
+
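+     # Start a fresh figure so repeated Gradio calls do not stack curves,
+     # then plot accuracy against the features-to-samples ratio.
+     plt.figure()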
+     plt.plot(
+         features_samples_ratio,
+         acc_clf1,
+         linewidth=2,
+         label="LDA",
+         color="gold",
+         linestyle="solid",
+     )
+     plt.plot(
+         features_samples_ratio,
+         acc_clf2,
+         linewidth=2,
+         label="LDA with Ledoit Wolf",
+         color="navy",
+         linestyle="dashed",
+     )
+     plt.plot(
+         features_samples_ratio,
+         acc_clf3,
+         linewidth=2,
+         label="LDA with OAS",
+         color="red",
+         linestyle="dotted",
+     )
+
+     plt.xlabel("n_features / n_samples")
+     plt.ylabel("Classification accuracy")
+     plt.legend(loc="lower left")
+     plt.ylim((0.65, 1.0))
+     plt.suptitle(
+         "LDA (Linear Discriminant Analysis) vs. "
+         + "\n"
+         + "LDA with Ledoit Wolf vs. "
+         + "\n"
+         + "LDA with OAS (1 discriminative feature)"
+     )
+
+     # Save the figure to disk and return the file path for Gradio to display
+     plt.tight_layout()
+     plt.savefig("plot.png")
+     return "plot.png"
+
+ # Define the input and output interfaces
+ # (gr.Slider / gr.Image replace the deprecated gr.inputs / gr.outputs APIs,
+ # and "default" is now "value")
+ inputs = [
+     gr.Slider(minimum=1, maximum=100, step=1, label="n_train", value=20),
+     gr.Slider(minimum=1, maximum=500, step=1, label="n_test", value=200),
+     gr.Slider(minimum=1, maximum=100, step=1, label="n_averages", value=50),
+     gr.Slider(minimum=1, maximum=100, step=1, label="n_features_max", value=75),
+     gr.Slider(minimum=1, maximum=20, step=1, label="step", value=4),
+ ]
+ output = gr.Image(type="filepath")
+ examples = [
+     [20, 200, 50, 75, 4],
+     [30, 250, 60, 80, 5],
+     [40, 300, 70, 90, 6],
+ ]
+
+ # Create the Gradio app
+ title = "Normal, Ledoit-Wolf and OAS Linear Discriminant Analysis for classification"
+ description = "This example illustrates how the Ledoit-Wolf and Oracle Approximating Shrinkage (OAS) estimators of covariance can improve classification. See the original example: https://scikit-learn.org/stable/auto_examples/classification/plot_lda.html"
+ gr.Interface(fn=classify, inputs=inputs, outputs=output, examples=examples, title=title, description=description).launch()