from time import time

import gradio as gr
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend; matplotlib is only used for its colormaps
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from sklearn import manifold, datasets
from sklearn.cluster import AgglomerativeClustering


SEED = 0
np.random.seed(SEED)

digits = datasets.load_digits()
X, y = digits.data, digits.target
n_samples, n_features = X.shape


def plot_clustering(linkage, dim):
    # Embed the 64-dimensional digit images into 2 or 3 dimensions.
    n_components = 3 if dim == '3D' else 2
    X_red = manifold.SpectralEmbedding(n_components=n_components).fit_transform(X)

    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)

    t0 = time()
    clustering.fit(X_red)
    print("%s :\t%.2fs" % (linkage, time() - t0))

    labels = clustering.labels_

    # Rescale the embedding to the unit square/cube for plotting.
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    X_red = (X_red - x_min) / (x_max - x_min)

    fig = go.Figure()

    # One trace per digit, with each point coloured by its assigned cluster.
    for digit in digits.target_names:
        subset = X_red[y == digit]
        rgbas = plt.cm.nipy_spectral(labels[y == digit] / 10)
        # The colormap returns float channels in [0, 1]; CSS rgba() expects 0-255.
        color = [
            f'rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, 0.8)'
            for r, g, b, _ in rgbas
        ]
        if dim == '2D':
            fig.add_trace(go.Scatter(
                x=subset[:, 0], y=subset[:, 1],
                mode='text', text=str(digit),
                textfont={'size': 16, 'color': color},
            ))
        elif dim == '3D':
            fig.add_trace(go.Scatter3d(
                x=subset[:, 0], y=subset[:, 1], z=subset[:, 2],
                mode='text', text=str(digit),
                textfont={'size': 16, 'color': color},
            ))

    fig.update_traces(showlegend=False)

    return fig


title = '# Agglomerative Clustering on the Digits Dataset'

description = """
An illustration of the various linkage options for [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) on the digits dataset.

The goal of this example is to show intuitively how the metrics behave, and not to find good clusters for the digits.

What this example shows us is the "rich get richer" behavior of agglomerative clustering, which tends to create uneven cluster sizes.

This behavior is pronounced for the average linkage strategy, which ends up with a couple of clusters having few data points.

The case of single linkage is even more pathological, with a very large cluster covering most digits, an intermediate-sized (clean) cluster with mostly zero digits, and all other clusters being drawn from noise points around the fringes.

The other linkage strategies lead to more evenly distributed clusters, which are therefore likely to be less sensitive to random resampling of the dataset.
"""

author = '''
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_linkage.html)
'''

with gr.Blocks(analytics_enabled=False, title=title) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(author)
    
    with gr.Row():
        with gr.Column():
            linkage = gr.Radio(["ward", "average", "complete", "single"], value="average", interactive=True, label="Linkage Method")
            dim = gr.Radio(['2D', '3D'], label='Embedding Dimensionality', value='2D')

            btn = gr.Button('Submit')
        
        with gr.Column():
            plot = gr.Plot(label='Digits Embeddings')
    
    btn.click(plot_clustering, inputs=[linkage, dim], outputs=[plot])
    demo.load(plot_clustering, inputs=[linkage, dim], outputs=[plot])

demo.launch()