File size: 2,701 Bytes
1cc0557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from time import time
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from sklearn import manifold, datasets
from sklearn.cluster import AgglomerativeClustering


SEED = 0

# scikit-learn's bundled handwritten-digits dataset: ``X`` is the feature
# matrix and ``y`` the integer class label of each row.
digits = datasets.load_digits()
X = digits.data
y = digits.target
n_samples, n_features = X.shape

# Seed NumPy's global RNG so estimators that fall back on it are reproducible
# (at least for the first call after import).
np.random.seed(SEED)

# Select matplotlib's non-interactive 'Agg' backend; the app only uses
# matplotlib for its colormaps, never to open a window.
import matplotlib
matplotlib.use('Agg')



# Spectral embeddings of ``X`` keyed by number of components. The embedding
# depends only on the fixed dataset and SEED, so it is computed at most once
# per dimensionality instead of on every request.
_EMBEDDING_CACHE = {}


def _spectral_embedding(n_components):
    """Return the (cached) spectral embedding of ``X`` with ``n_components`` columns."""
    embedding = _EMBEDDING_CACHE.get(n_components)
    if embedding is None:
        # An explicit random_state makes the embedding identical on every
        # call. Without it, the solver draws from the global NumPy RNG,
        # whose state advances between requests.
        embedding = manifold.SpectralEmbedding(
            n_components=n_components, random_state=SEED
        ).fit_transform(X)
        _EMBEDDING_CACHE[n_components] = embedding
    return embedding


def _to_plotly_rgba(rgba, alpha=0.8):
    """Convert a matplotlib RGBA row (floats in [0, 1]) to a Plotly ``rgba(...)`` string.

    Plotly colour strings follow CSS, where the red/green/blue channels are in
    0-255. Passing matplotlib's [0, 1] floats unscaled renders near-black.
    """
    r, g, b = (int(round(channel * 255)) for channel in rgba[:3])
    return f'rgba({r}, {g}, {b}, {alpha})'


def plot_clustering(linkage, dim):
    """Cluster a spectral embedding of the digits and plot it with Plotly.

    Parameters
    ----------
    linkage : str
        Linkage criterion forwarded to ``AgglomerativeClustering``
        (the UI offers "ward", "average", "complete" and "single").
    dim : str
        "3D" embeds into three dimensions and draws ``Scatter3d`` traces.
        "2D" embeds into two dimensions and draws ``Scatter`` traces. Any
        other value embeds in 2D but adds no traces.

    Returns
    -------
    plotly.graph_objects.Figure
        One text trace per digit class. Each sample is drawn as its true
        digit, coloured by the cluster it was assigned to.
    """
    n_components = 3 if dim == '3D' else 2
    X_red = _spectral_embedding(n_components)

    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)

    t0 = time()
    clustering.fit(X_red)
    print(f"{linkage} :\t{time() - t0:.2f}s")

    labels = clustering.labels_

    # Min-max scale each axis to [0, 1]. This builds a new array, so the
    # cached embedding is never modified.
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    X_red = (X_red - x_min) / (x_max - x_min)

    fig = go.Figure()

    for digit in digits.target_names:
        mask = y == digit
        subset = X_red[mask]
        # Cluster labels lie in 0..9 (n_clusters=10), so /10 maps them into
        # the colormap's input range.
        rgbas = plt.cm.nipy_spectral(labels[mask] / 10)
        textfont = {'size': 16, 'color': [_to_plotly_rgba(rgba) for rgba in rgbas]}
        if dim == '2D':
            fig.add_trace(go.Scatter(
                x=subset[:, 0], y=subset[:, 1],
                mode='text', text=str(digit), textfont=textfont,
            ))
        elif dim == '3D':
            fig.add_trace(go.Scatter3d(
                x=subset[:, 0], y=subset[:, 1], z=subset[:, 2],
                mode='text', text=str(digit), textfont=textfont,
            ))

    fig.update_traces(showlegend=False)

    return fig


# Markdown snippets rendered at the top of the Gradio page.
# NOTE(review): the data comes from sklearn's ``load_digits`` (8x8 digits),
# not MNIST. The "MNIST" wording in the title is kept as the author chose it.
title = '# Agglomerative Clustering on MNIST'

description = """
An illustration of various linkage options for [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) on the digits dataset.
"""

author = '''
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_linkage.html)
'''

with gr.Blocks(title=title, analytics_enabled=False) as demo:
    # Page header: title, short description and attribution as Markdown.
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(author)

    with gr.Row():
        # Left column: the clustering controls and the submit button.
        with gr.Column():
            linkage = gr.Radio(
                ["ward", "average", "complete", "single"],
                value="average",
                label="Linkage Method",
                interactive=True,
            )
            dim = gr.Radio(
                ['2D', '3D'],
                value='2D',
                label='Embedding Dimensionality',
            )

            btn = gr.Button('Submit')

        # Right column: the Plotly figure produced by plot_clustering.
        with gr.Column():
            plot = gr.Plot(label='MNIST Embeddings')

    # Render once when the page loads, and again on every click.
    btn.click(plot_clustering, inputs=[linkage, dim], outputs=[plot])
    demo.load(plot_clustering, inputs=[linkage, dim], outputs=[plot])

# Start the web server only when run as a script (e.g. ``python app.py``),
# so importing this module does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch()