from time import time

import gradio as gr
import matplotlib
matplotlib.use("Agg")  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from sklearn import datasets, manifold
from sklearn.cluster import AgglomerativeClustering

SEED = 0
np.random.seed(SEED)

digits = datasets.load_digits()
X, y = digits.data, digits.target
n_samples, n_features = X.shape


def plot_clustering(linkage, dim):
    # Embed the 64-dimensional digits into 2 or 3 dimensions.
    n_components = 3 if dim == "3D" else 2
    X_red = manifold.SpectralEmbedding(
        n_components=n_components, random_state=SEED
    ).fit_transform(X)

    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)
    t0 = time()
    clustering.fit(X_red)
    print(f"{linkage}:\t{time() - t0:.2f}s")
    labels = clustering.labels_

    # Rescale the embedding to the unit square (or cube) for plotting.
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    X_red = (X_red - x_min) / (x_max - x_min)

    fig = go.Figure()
    for digit in digits.target_names:
        subset = X_red[y == digit]
        # Colour each point by its cluster assignment, not its true label.
        # The colormap returns RGBA floats in [0, 1]; CSS rgba() expects the
        # colour channels in [0, 255], so scale them up.
        rgbas = plt.cm.nipy_spectral(labels[y == digit] / 10)
        color = [
            f"rgba({int(255 * r)}, {int(255 * g)}, {int(255 * b)}, 0.8)"
            for r, g, b, _ in rgbas
        ]
        if dim == "3D":
            fig.add_trace(
                go.Scatter3d(
                    x=subset[:, 0],
                    y=subset[:, 1],
                    z=subset[:, 2],
                    mode="text",
                    text=str(digit),
                    textfont={"size": 16, "color": color},
                )
            )
        else:
            fig.add_trace(
                go.Scatter(
                    x=subset[:, 0],
                    y=subset[:, 1],
                    mode="text",
                    text=str(digit),
                    textfont={"size": 16, "color": color},
                )
            )
    fig.update_traces(showlegend=False)
    return fig


title = "Agglomerative Clustering on the Digits Dataset"
description = """
An illustration of the various linkage options for [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) on the digits dataset. The goal of this example is to show intuitively how the metrics behave, not to find good clusters for the digits.

What this example shows us is the "rich getting richer" behavior of agglomerative clustering, which tends to create uneven cluster sizes. This behavior is pronounced for the average linkage strategy, which ends up with a couple of clusters having few data points.

The case of single linkage is even more pathological, with a very large cluster covering most digits, an intermediate-sized (clean) cluster with mostly zero digits, and all other clusters being drawn from noise points around the fringes.

The other linkage strategies lead to more evenly distributed clusters, which are therefore likely to be less sensitive to random resampling of the dataset.
"""
author = """
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_linkage.html)
"""

with gr.Blocks(analytics_enabled=False, title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    gr.Markdown(author)

    with gr.Row():
        with gr.Column():
            linkage = gr.Radio(
                ["ward", "average", "complete", "single"],
                value="average",
                interactive=True,
                label="Linkage Method",
            )
            dim = gr.Radio(["2D", "3D"], value="2D", label="Embedding Dimensionality")
            btn = gr.Button("Submit")
        with gr.Column():
            plot = gr.Plot(label="Digits Embedding")

    btn.click(plot_clustering, inputs=[linkage, dim], outputs=[plot])
    demo.load(plot_clustering, inputs=[linkage, dim], outputs=[plot])

demo.launch()
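

# A minimal diagnostic sketch, not part of the upstream scikit-learn example
# (the helper name `report_cluster_sizes` is an assumption for illustration).
# It makes the "rich getting richer" claim in the description concrete by
# printing how many points land in each of the 10 clusters for each linkage;
# "single" typically yields one dominant cluster, "ward" a more even split.
# Note: `demo.launch()` above blocks while the app is running, so call this
# helper from an interactive session rather than relying on it running here.
def report_cluster_sizes():
    X_red = manifold.SpectralEmbedding(
        n_components=2, random_state=SEED
    ).fit_transform(X)
    for method in ("ward", "average", "complete", "single"):
        labels = AgglomerativeClustering(
            linkage=method, n_clusters=10
        ).fit_predict(X_red)
        # np.bincount counts the points assigned to each cluster label (0..9).
        print(f"{method:>8}: {np.bincount(labels, minlength=10)}")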