Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
from scipy.ndimage import gaussian_filter
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from skimage.data import coins
|
7 |
+
from skimage.transform import rescale
|
8 |
+
from sklearn.feature_extraction import image
|
9 |
+
from sklearn.cluster import spectral_clustering
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
|
13 |
+
# function for making the clustering plot.
|
14 |
+
# input: one of the following algorithms: "kmeans", "discretize", "cluster_qr"
|
15 |
+
def getClusteringPlot(algorithm):
|
16 |
+
# load the coins as a numpy array
|
17 |
+
orig_coins = coins()
|
18 |
+
|
19 |
+
# Pre-processing the image
|
20 |
+
smoothened_coins = gaussian_filter(orig_coins, sigma=2)
|
21 |
+
rescaled_coins = rescale(smoothened_coins, 0.2, mode="reflect", anti_aliasing=False)
|
22 |
+
|
23 |
+
# Convert the image into a graph
|
24 |
+
graph = image.img_to_graph(rescaled_coins)
|
25 |
+
|
26 |
+
beta = 10
|
27 |
+
eps = 1e-6
|
28 |
+
graph.data = np.exp(-beta * graph.data / graph.data.std()) + eps
|
29 |
+
|
30 |
+
# The number of segmented regions to display needs to be chosen manually
|
31 |
+
n_regions = 26
|
32 |
+
|
33 |
+
# The spectral clustering quality may also benetif from requesting
|
34 |
+
# extra regions for segmentation.
|
35 |
+
n_regions_plus = 3
|
36 |
+
|
37 |
+
t0 = time.time()
|
38 |
+
labels = spectral_clustering(
|
39 |
+
graph,
|
40 |
+
n_clusters=(n_regions + n_regions_plus),
|
41 |
+
eigen_tol=1e-7,
|
42 |
+
assign_labels=algorithm,
|
43 |
+
random_state=42,
|
44 |
+
)
|
45 |
+
|
46 |
+
t1 = time.time()
|
47 |
+
labels = labels.reshape(rescaled_coins.shape)
|
48 |
+
plt.figure(figsize=(5, 5))
|
49 |
+
plt.imshow(rescaled_coins, cmap=plt.cm.gray)
|
50 |
+
|
51 |
+
plt.xticks(())
|
52 |
+
plt.yticks(())
|
53 |
+
title = "Spectral clustering: %s, %.2fs" % (algorithm, (t1 - t0))
|
54 |
+
print(title)
|
55 |
+
plt.title(title)
|
56 |
+
for l in range(n_regions):
|
57 |
+
colors = [plt.cm.nipy_spectral((l + 4) / float(n_regions + 4))]
|
58 |
+
plt.contour(labels == l, colors=colors)
|
59 |
+
# To view individual segments as appear comment in plt.pause(0.5)
|
60 |
+
return (plt, "%.3fs" % (t1 - t0))
|
61 |
+
|
62 |
+
|
63 |
+
# building the gradio interface
|
64 |
+
with gr.Blocks() as demo:
|
65 |
+
gr.Markdown("## Segmenting the picture of Greek coins in regions 🪙")
|
66 |
+
gr.Markdown("This demo is based on this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_segmentation.html#sphx-glr-auto-examples-cluster-plot-coin-segmentation-py).")
|
67 |
+
gr.Markdown("In this demo, we compare three strategies for performing segmentation-clustering and breaking the below image of Greek coins into multiple partly-homogeneous regions.")
|
68 |
+
gr.Image(coins(), label="An image of 24 Greek coins")
|
69 |
+
gr.Markdown("The image is retrieved from scikit-image's data [gallery](https://scikit-image.org/docs/stable/auto_examples/).")
|
70 |
+
inp = gr.Radio(["kmeans", "discretize", "cluster_qr"], label="Solver", info="Choose a clustering algorithm", value="kmeans")
|
71 |
+
with gr.Row():
|
72 |
+
plot = gr.Plot(label="Plot")
|
73 |
+
num = gr.Textbox(label="Running Time")
|
74 |
+
inp.change(getClusteringPlot, inputs=[inp], outputs=[plot, num])
|
75 |
+
demo.load(getClusteringPlot, inputs=[inp], outputs=[plot, num])
|
76 |
+
|
77 |
+
demo.launch()
|