|
from time import time |
|
import gradio as gr |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import plotly.graph_objects as go |
|
|
|
from sklearn import manifold, datasets |
|
from sklearn.cluster import AgglomerativeClustering |
|
|
|
|
|
# Seed for reproducibility. SpectralEmbedding draws from NumPy's global RNG
# when no random_state is given, so the np.random.seed call below matters.
SEED = 0

# 8x8 handwritten-digit dataset: X is (n_samples, 64), y holds labels 0-9.
digits = datasets.load_digits()

X, y = digits.data, digits.target

n_samples, n_features = X.shape

# Seed the legacy global NumPy RNG so repeated runs produce the same embeddings.
np.random.seed(SEED)



# Non-interactive backend: matplotlib is only used for its nipy_spectral
# colormap here; all rendering goes through Plotly/Gradio.
import matplotlib

matplotlib.use('Agg')
|
|
|
|
|
|
|
def plot_clustering(linkage, dim, n_clusters=10):
    """Spectrally embed the digits, cluster the embedding, and plot it.

    Parameters
    ----------
    linkage : str
        Linkage criterion for AgglomerativeClustering:
        'ward', 'average', 'complete' or 'single'.
    dim : str
        '2D' or '3D' -- dimensionality of the embedding and the plot.
        Any other value falls back to a 2D embedding with no traces drawn.
    n_clusters : int, default=10
        Number of clusters to form (one per digit class by default).

    Returns
    -------
    plotly.graph_objects.Figure
        Digits drawn as text markers at their embedded positions, colored
        by cluster assignment.
    """
    n_components = 3 if dim == '3D' else 2
    # random_state pins the embedding so repeated calls give the same layout
    # (previously it depended on the mutable global NumPy RNG state).
    X_red = manifold.SpectralEmbedding(
        n_components=n_components, random_state=SEED
    ).fit_transform(X)

    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=n_clusters)

    t0 = time()
    clustering.fit(X_red)
    print("%s :\t%.2fs" % (linkage, time() - t0))

    labels = clustering.labels_

    # Rescale each embedding axis to [0, 1] for plotting.
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    X_red = (X_red - x_min) / (x_max - x_min)

    fig = go.Figure()

    # One trace per digit class, colored per-point by cluster label.
    for digit in digits.target_names:
        subset = X_red[y == digit]
        rgbas = plt.cm.nipy_spectral(labels[y == digit] / n_clusters)
        # CSS rgba() channels are 0-255; the colormap yields 0-1 floats,
        # so scale them (bug fix: unscaled values rendered nearly black).
        color = [
            f'rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, 0.8)'
            for r, g, b, _ in rgbas
        ]
        if dim == '2D':
            fig.add_trace(go.Scatter(
                x=subset[:, 0], y=subset[:, 1], mode='text',
                text=str(digit), textfont={'size': 16, 'color': color}))
        elif dim == '3D':
            fig.add_trace(go.Scatter3d(
                x=subset[:, 0], y=subset[:, 1], z=subset[:, 2], mode='text',
                text=str(digit), textfont={'size': 16, 'color': color}))

    fig.update_traces(showlegend=False)

    return fig
|
|
|
|
|
# Markdown heading shown at the top of the app (also used as the page title).
title = '# Agglomerative Clustering on MNIST'



# Explanatory text rendered below the title, adapted from the scikit-learn
# example this demo is based on.
description = """

An illustration of various linkage option for [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) on the digits dataset.



The goal of this example is to show intuitively how the metrics behave, and not to find good clusters for the digits.



What this example shows us is the behavior of "rich getting richer" in agglomerative clustering, which tends to create uneven cluster sizes.



This behavior is pronounced for the average linkage strategy, which ends up with a couple of clusters having few data points.



The case of single linkage is even more pathological, with a very large cluster covering most digits, an intermediate-sized (clean) cluster with mostly zero digits, and all other clusters being drawn from noise points around the fringes.



The other linkage strategies lead to more evenly distributed clusters, which are therefore likely to be less sensitive to random resampling of the dataset.

"""



# Attribution footer with author and upstream-example links.
author = '''

Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_linkage.html)

'''
|
|
|
# Gradio UI: controls on the left, the Plotly figure on the right.
# Component creation order inside the context managers defines the layout.
with gr.Blocks(analytics_enabled=False, title=title) as demo:

    gr.Markdown(title)

    gr.Markdown(description)

    gr.Markdown(author)



    with gr.Row():

        with gr.Column():

            # Clustering/embedding parameters passed to plot_clustering.
            linkage = gr.Radio(["ward", "average", "complete", "single"], value="average", interactive=True, label="Linkage Method")

            dim = gr.Radio(['2D', '3D'], label='Embedding Dimensionality', value='2D')



            btn = gr.Button('Submit')



        with gr.Column():

            plot = gr.Plot(label='MNIST Embeddings')



    # Recompute on button click; demo.load also renders an initial figure
    # with the default radio values when the page first loads.
    btn.click(plot_clustering, inputs=[linkage, dim], outputs=[plot])

    demo.load(plot_clustering, inputs=[linkage, dim], outputs=[plot])



demo.launch()