File size: 3,321 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from warnings import simplefilter

import numpy as np
from sklearn.metrics import silhouette_samples, silhouette_score
from sklearn.preprocessing import LabelEncoder

import wandb
from wandb.integration.sklearn import utils

# Importing this module installs a process-wide filter that silences every
# FutureWarning, regardless of which module emits it (no module/message
# restriction is passed), not only warnings raised by the sklearn calls below.
simplefilter("ignore", FutureWarning)


def silhouette(clusterer, X, cluster_labels, labels, metric, kmeans):  # noqa: N803
    """Build the W&B silhouette chart for an already-clustered dataset.

    Args:
        clusterer: The fitted clusterer. Only read when ``kmeans`` is true, in
            which case ``clusterer.cluster_centers_`` supplies the centers.
        X: Feature matrix. It is sliced as ``X[:, 0]`` / ``X[:, 1]`` for the
            scatter plot, so it must support 2-D slicing with >= 2 columns.
        cluster_labels: Cluster assignment for each row of ``X``. Any label
            values are accepted (e.g. ``-1`` for noise, 1-based ids).
        labels: Converted to an array but otherwise unused here.
        metric: Distance metric forwarded to ``silhouette_score`` and
            ``silhouette_samples``.
        kmeans: If true, plot ``clusterer.cluster_centers_`` (indexed by
            cluster label); otherwise no centers are drawn.

    Returns:
        The object returned by ``wandb.visualize`` for the
        ``"wandb/silhouette_/v1"`` preset.
    """
    # TODO - keep/delete once we decide if we should train clusterers
    # or ask for trained models
    # clusterer.set_params(n_clusters=n_clusters, random_state=42)
    # cluster_labels = clusterer.fit_predict(X)
    cluster_labels = np.asarray(cluster_labels)
    labels = np.asarray(labels)

    # Group by the labels that actually occur instead of range(n_clusters):
    # labels need not be 0..n-1, and comparing against a range silently
    # dropped every cluster whose label fell outside it.
    unique_labels = np.unique(cluster_labels)

    # Mean silhouette over all samples: a single summary of cluster density
    # and separation.
    silhouette_avg = silhouette_score(X, cluster_labels, metric=metric)

    # Per-sample silhouette coefficients.
    sample_silhouette_values = silhouette_samples(X, cluster_labels, metric=metric)

    x_sil, y_sil, color_sil = [], [], []

    count, y_lower = 0, 10
    limit_reached = False
    for label in unique_labels:
        # Sorted silhouette values of the samples in this cluster.
        cluster_values = np.sort(sample_silhouette_values[cluster_labels == label])
        y_upper = y_lower + cluster_values.shape[0]

        for y_value, sil_value in zip(np.arange(y_lower, y_upper), cluster_values):
            y_sil.append(y_value)
            x_sil.append(sil_value)
            # Color bars by the real label so they match the scatter colors.
            color_sil.append(label)
            count += 1
            if utils.check_against_limit(count, "silhouette", utils.chart_limit):
                limit_reached = True
                break

        # Stop collecting across ALL clusters once the chart limit is hit;
        # otherwise each remaining cluster would add another point.
        if limit_reached:
            break

        # Leave a 10-unit vertical gap before the next cluster's bars.
        y_lower = y_upper + 10

    if kmeans:
        centers = clusterer.cluster_centers_
        centerx = centers[:, 0]
        centery = centers[:, 1]
    else:
        # No centers to draw. make_table looks centers up by cluster label,
        # so map every occurring label (whatever its value) to None.
        no_centers = {label: None for label in unique_labels}
        centerx = no_centers
        centery = no_centers

    table = make_table(
        X[:, 0],
        X[:, 1],
        cluster_labels,
        centerx,
        centery,
        y_sil,
        x_sil,
        color_sil,
        silhouette_avg,
    )
    chart = wandb.visualize("wandb/silhouette_/v1", table)

    return chart


def make_table(x, y, colors, centerx, centery, y_sil, x_sil, color_sil, silhouette_avg):
    """Assemble the ``wandb.Table`` backing the silhouette chart preset.

    One row is emitted per entry of ``color_sil``. Row ``i`` combines two
    things:

    * the ``i``-th scatter point (``x[i]``, ``y[i]``, its cluster
      ``colors[i]``) and that cluster's center;
    * the ``i``-th silhouette entry (``y_sil[i]``, ``x_sil[i]``,
      ``color_sil[i]``).

    Column ``x1`` is always 0 and ``x2`` is ``x_sil[i]``, presumably the
    start and end of each silhouette bar. Every row repeats
    ``silhouette_avg``.

    ``centerx`` / ``centery`` are indexed by cluster label, so they must
    support ``centerx[colors[i]]``.
    """
    column_names = [
        "x",
        "y",
        "colors",
        "centerx",
        "centery",
        "y_sil",
        "x1",
        "x2",
        "color_sil",
        "silhouette_avg",
    ]

    rows = []
    for idx, bar_color in enumerate(color_sil):
        cluster = colors[idx]
        rows.append(
            [
                x[idx],
                y[idx],
                cluster,
                centerx[cluster],
                centery[cluster],
                y_sil[idx],
                0,
                x_sil[idx],
                bar_color,
                silhouette_avg,
            ]
        )

    return wandb.Table(data=rows, columns=column_names)