sulpha commited on
Commit
f996981
·
1 Parent(s): 9ab3114

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ =======================================================================================
3
+ Gradio demo to plot the decision surface of decision trees trained on the iris dataset
4
+ =======================================================================================
5
+
6
+ Plot the decision surface of a decision tree trained on pairs
7
+ of features of the iris dataset.
8
+
9
+ For each pair of iris features, the decision tree learns decision
10
+ boundaries made of combinations of simple thresholding rules inferred from
11
+ the training samples.
12
+
13
+ We also show the tree structure of a model built on all of the features.
14
+
15
+ Gradio demo created by Syed Affan <saffand03@gmail.com>
16
+ """
17
+ from sklearn.datasets import load_iris
18
+ from sklearn.tree import plot_tree
19
+ import numpy as np
20
+ import matplotlib.pyplot as plt
21
+ import gradio as gr
22
+ from sklearn.tree import DecisionTreeClassifier
23
+ from sklearn.inspection import DecisionBoundaryDisplay
24
+
25
+
26
+ iris = load_iris()
27
+
28
+ def make_plot(criterion,max_depth,ccp_alpha):
29
+ # Parameters
30
+ n_classes = 3
31
+ plot_colors = "ryb"
32
+ plot_step = 0.02
33
+
34
+ fig_1 = plt.figure()
35
+
36
+ for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
37
+ # We only take the two corresponding features
38
+ X = iris.data[:, pair]
39
+ y = iris.target
40
+
41
+ # Train
42
+ clf = DecisionTreeClassifier(criterion=criterion,max_depth=max_depth,ccp_alpha=ccp_alpha)
43
+ clf.fit(X, y)
44
+
45
+ # Plot the decision boundary
46
+ ax = plt.subplot(2, 3, pairidx + 1)
47
+ plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
48
+ DecisionBoundaryDisplay.from_estimator(
49
+ clf,
50
+ X,
51
+ cmap=plt.cm.RdYlBu,
52
+ response_method="predict",
53
+ ax=ax,
54
+ xlabel=iris.feature_names[pair[0]],
55
+ ylabel=iris.feature_names[pair[1]],
56
+ )
57
+
58
+ # Plot the training points
59
+ for i, color in zip(range(n_classes), plot_colors):
60
+ idx = np.where(y == i)
61
+ plt.scatter(
62
+ X[idx, 0],
63
+ X[idx, 1],
64
+ c=color,
65
+ label=iris.target_names[i],
66
+ cmap=plt.cm.RdYlBu,
67
+ edgecolor="black",
68
+ s=15,
69
+ )
70
+
71
+ plt.suptitle("Decision surface of decision trees trained on pairs of features")
72
+ plt.legend(loc="lower right", borderpad=0, handletextpad=0)
73
+ _ = plt.axis("tight")
74
+
75
+ # %%
76
+ # Display the structure of a single decision tree trained on all the features
77
+ # together.
78
+
79
+ fig_2 = plt.figure()
80
+ clf = DecisionTreeClassifier(criterion=criterion,max_depth=max_depth,ccp_alpha=ccp_alpha).fit(iris.data, iris.target)
81
+ plot_tree(clf, filled=True)
82
+ plt.title("Decision tree trained on all the iris features")
83
+ return fig_1,fig_2
84
+
85
+ title = 'Plot the decision surface of decision trees trained on the iris dataset'
86
+
87
+ model_card = f"""
88
+ ## Description:
89
+ Plot the decision surface of a decision tree trained on pairs of features of the iris dataset.
90
+ For each pair of iris features, the decision tree learns decision boundaries made of combinations of simple thresholding rules inferred from the training samples.
91
+ We also show the tree structure of a model built on all of the features.
92
+ ## Dataset
93
+ Iris Dataset
94
+ """
95
+
96
+ with gr.Blocks(title=title) as demo:
97
+ gr.Markdown('''
98
+ <div>
99
+ <h1 style='text-align: center'>⚒ Plot the decision surface of decision trees trained on the iris dataset 🛠</h1>
100
+ </div>
101
+ ''')
102
+ gr.Markdown(model_card)
103
+ gr.Markdown("Author: <a href=\"https://huggingface.co/sulpha\">sulpha</a>")
104
+ with gr.Column():
105
+ d0 = gr.Radio(['gini', 'entropy', 'log_loss'],value='gini',label='Criterion')
106
+ d1 = gr.Slider(1,10,step=1,value=5,label = 'max_depth')
107
+ d2 = gr.Slider(0.0,1,step=0.001,value=0.0,label = 'ccp_alpha')
108
+
109
+ with gr.Row():
110
+ p_1 = gr.Plot()
111
+ p_2 = gr.Plot()
112
+
113
+ btn = gr.Button(value= 'Submit')
114
+ btn.click(make_plot,inputs=[d0,d1,d2],outputs=[p_1,p_2])
115
+
116
+ demo.launch()