jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
from warnings import simplefilter
import numpy as np
from sklearn.metrics import silhouette_samples, silhouette_score
from sklearn.preprocessing import LabelEncoder
import wandb
from wandb.integration.sklearn import utils
# ignore all future warnings
simplefilter(action="ignore", category=FutureWarning)
def silhouette(clusterer, X, cluster_labels, labels, metric, kmeans): # noqa: N803
# Run clusterer for n_clusters in range(len(cluster_ranges), get cluster labels
# TODO - keep/delete once we decide if we should train clusterers
# or ask for trained models
# clusterer.set_params(n_clusters=n_clusters, random_state=42)
# cluster_labels = clusterer.fit_predict(X)
cluster_labels = np.asarray(cluster_labels)
labels = np.asarray(labels)
le = LabelEncoder()
_ = le.fit_transform(cluster_labels)
n_clusters = len(np.unique(cluster_labels))
# The silhouette_score gives the average value for all the samples.
# This gives a perspective into the density and separation of the formed
# clusters
silhouette_avg = silhouette_score(X, cluster_labels, metric=metric)
# Compute the silhouette scores for each sample
sample_silhouette_values = silhouette_samples(X, cluster_labels, metric=metric)
x_sil, y_sil, color_sil = [], [], []
count, y_lower = 0, 10
for i in range(n_clusters):
# Aggregate the silhouette scores for samples belonging to
# cluster i, and sort them
ith_cluster_silhouette_values = sample_silhouette_values[cluster_labels == i]
ith_cluster_silhouette_values.sort()
size_cluster_i = ith_cluster_silhouette_values.shape[0]
y_upper = y_lower + size_cluster_i
y_values = np.arange(y_lower, y_upper)
for j in range(len(y_values)):
y_sil.append(y_values[j])
x_sil.append(ith_cluster_silhouette_values[j])
color_sil.append(i)
count += 1
if utils.check_against_limit(count, "silhouette", utils.chart_limit):
break
# Compute the new y_lower for next plot
y_lower = y_upper + 10 # 10 for the 0 samples
if kmeans:
centers = clusterer.cluster_centers_
centerx = centers[:, 0]
centery = centers[:, 1]
else:
centerx = [None] * len(color_sil)
centery = [None] * len(color_sil)
table = make_table(
X[:, 0],
X[:, 1],
cluster_labels,
centerx,
centery,
y_sil,
x_sil,
color_sil,
silhouette_avg,
)
chart = wandb.visualize("wandb/silhouette_/v1", table)
return chart
def make_table(x, y, colors, centerx, centery, y_sil, x_sil, color_sil, silhouette_avg):
columns = [
"x",
"y",
"colors",
"centerx",
"centery",
"y_sil",
"x1",
"x2",
"color_sil",
"silhouette_avg",
]
data = [
[
x[i],
y[i],
colors[i],
centerx[colors[i]],
centery[colors[i]],
y_sil[i],
0,
x_sil[i],
color_sil[i],
silhouette_avg,
]
for i in range(len(color_sil))
]
table = wandb.Table(data=data, columns=columns)
return table