Hnabil's picture
Update app.py
2b037b3
from time import time
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from sklearn import manifold, datasets
from sklearn.cluster import AgglomerativeClustering
import matplotlib

# Select the non-interactive 'Agg' backend. In this file matplotlib is only
# used for a colormap (plt.cm.nipy_spectral); it never opens a window.
matplotlib.use('Agg')

# Seed NumPy's global RNG once at start-up, so any randomised step that falls
# back to the global generator starts from a known state on every app launch.
SEED = 0
np.random.seed(SEED)

# scikit-learn's bundled handwritten-digits data: X holds the feature matrix
# and y the true digit of each sample.
digits = datasets.load_digits()
X = digits.data
y = digits.target
n_samples, n_features = X.shape
def rgba_to_css(rgba, alpha=0.8):
    """Convert matplotlib RGBA colours to Plotly CSS ``rgba(...)`` strings.

    Matplotlib colormaps return each channel as a float in [0, 1]. CSS
    ``rgba()`` (which Plotly parses) expects red/green/blue on a 0-255 scale,
    so the floats must be rescaled or every colour comes out almost black.

    Args:
        rgba: 2-D array-like of shape (n, 3) or (n, 4) with channels in
            [0, 1]; a fourth (alpha) column, if present, is ignored.
        alpha: Opacity written into every output string.

    Returns:
        A list of ``n`` strings such as ``'rgba(255, 128, 0, 0.8)'``.
    """
    rgb = np.rint(np.asarray(rgba, dtype=float)[:, :3] * 255).astype(int)
    return [f'rgba({r}, {g}, {b}, {alpha})' for r, g, b in rgb]


def plot_clustering(linkage, dim):
    """Embed the digits, cluster the embedding and return a Plotly figure.

    Every sample is drawn as the text of its true digit (from ``y``). It is
    coloured by the cluster that agglomerative clustering assigned to it, so
    a glyph/colour mismatch shows how the linkage strategy groups the digits.

    Args:
        linkage: Linkage criterion forwarded to ``AgglomerativeClustering``
            (the UI offers "ward", "average", "complete" and "single").
        dim: ``'3D'`` embeds into three components and draws ``Scatter3d``
            traces. ``'2D'`` embeds into two components and draws ``Scatter``
            traces. Any other value embeds into two components but adds no
            traces.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    n_components = 3 if dim == '3D' else 2
    # An explicit random_state makes the eigensolver initialisation
    # reproducible per call. Left as None, it would draw from the global
    # NumPy RNG, whose state keeps advancing between button clicks.
    X_red = manifold.SpectralEmbedding(
        n_components=n_components, random_state=SEED
    ).fit_transform(X)

    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)
    t0 = time()
    clustering.fit(X_red)
    print("%s :\t%.2fs" % (linkage, time() - t0))
    labels = clustering.labels_

    # Min-max scale each axis to [0, 1]. A constant axis would otherwise
    # divide by zero, so it gets a unit span (and maps to 0).
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    span = np.where(x_max > x_min, x_max - x_min, 1.0)
    X_red = (X_red - x_min) / span

    fig = go.Figure()
    for digit in digits.target_names:
        mask = y == digit
        subset = X_red[mask]
        # n_clusters=10 gives labels 0..9; /10 spreads them over the colormap.
        color = rgba_to_css(plt.cm.nipy_spectral(labels[mask] / 10))
        if dim == '2D':
            fig.add_trace(go.Scatter(
                x=subset[:, 0], y=subset[:, 1],
                mode='text', text=str(digit),
                textfont={'size': 16, 'color': color},
            ))
        elif dim == '3D':
            fig.add_trace(go.Scatter3d(
                x=subset[:, 0], y=subset[:, 1], z=subset[:, 2],
                mode='text', text=str(digit),
                textfont={'size': 16, 'color': color},
            ))
    fig.update_traces(showlegend=False)
    return fig
# Markdown blocks rendered at the top of the page (title, description, credit).
# NOTE(review): the data is loaded with scikit-learn's `datasets.load_digits()`,
# not MNIST proper; the "MNIST" wording is kept so the published title is
# unchanged.
title = '# Agglomerative Clustering on MNIST'
description = """
An illustration of various linkage options for [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) on the digits dataset.
The goal of this example is to show intuitively how the metrics behave, and not to find good clusters for the digits.
What this example shows us is the behavior of "rich getting richer" in agglomerative clustering, which tends to create uneven cluster sizes.
This behavior is pronounced for the average linkage strategy, which ends up with a couple of clusters having few data points.
The case of single linkage is even more pathological, with a very large cluster covering most digits, an intermediate-sized (clean) cluster with mostly zero digits, and all other clusters being drawn from noise points around the fringes.
The other linkage strategies lead to more evenly distributed clusters, which are therefore likely to be less sensitive to random resampling of the dataset.
"""
author = '''
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_linkage.html)
'''
# Page layout: the Markdown header, then a row with the controls on the left
# and the Plotly figure on the right. The same callback runs when "Submit" is
# clicked and once when the page loads, so a default plot is shown at once.
# `gr.Blocks(title=...)` sets the browser-tab text, which is plain text rather
# than Markdown, so the leading "# " heading marker is stripped there.
with gr.Blocks(analytics_enabled=False, title=title.lstrip('# ')) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(author)
    with gr.Row():
        with gr.Column():
            linkage = gr.Radio(
                ["ward", "average", "complete", "single"],
                value="average",
                interactive=True,
                label="Linkage Method",
            )
            dim = gr.Radio(['2D', '3D'], label='Embedding Dimensionality', value='2D')
            btn = gr.Button('Submit')
        with gr.Column():
            plot = gr.Plot(label='MNIST Embeddings')
    btn.click(plot_clustering, inputs=[linkage, dim], outputs=[plot])
    demo.load(plot_clustering, inputs=[linkage, dim], outputs=[plot])
demo.launch()