File size: 3,707 Bytes
a7fd807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b9d6c2
87b9b04
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from datasets import load_dataset

# Page header: app title plus a collapsible description of what the app does.
ABOUT_TEXT = (
    "This app allows you to upload any type of unlabeled dataset "
    "and automatically clusters the data using K-means clustering. "
    "It visualizes the clusters using PCA and provides time series and cluster distribution plots "
    "to help you identify patterns and groupings within your data."
)

st.title("Unsupervised Data Clustering App")

with st.expander("About this App"):
    st.write(ABOUT_TEXT)

# File uploader
uploaded_file = st.file_uploader("Upload Custom CSV file", type=["csv"])

# Example demo dataset. st.button returns True only on the rerun triggered by
# the click, so the loaded frame is stored in session_state; otherwise it would
# disappear the moment the user touches any other widget (e.g. the slider).
if st.button("Test With An Example Dataset"):
    dataset = load_dataset('kheejay88/country_data', split='train')
    st.session_state["example_df"] = pd.DataFrame(dataset)
    st.success("Loaded example dataset from Hugging Face.")

with st.expander("Dataset Columns"):
    st.write("""
        **country** – Name of the country\n
        **child_mort** – Death of children under 5 years of age per 1000 live births\n
        **exports** – Exports of goods and services per capita (as a percentage of GDP)\n
        **health** – Total health spending per capita (as a percentage of GDP)\n
        **imports** – Imports of goods and services per capita (as a percentage of GDP)\n
        **income** – Net income per person\n
        **inflation** – Annual inflation rate (percentage)\n
        **life_expec** – Average life expectancy at birth (in years)\n
        **total_fer** – Total fertility rate (average number of children per woman)\n
        **gdpp** – GDP per capita\n
    """)

# Choose the dataset for this run. An uploaded CSV takes precedence over the
# example dataset; `df` is only bound when a dataset was actually obtained, so
# the analysis section below is skipped otherwise.
loaded_df = None
if uploaded_file is not None:
    try:
        loaded_df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        # Report unreadable uploads in the UI instead of crashing the app.
        st.error(f"Could not read the uploaded CSV file: {exc}")
elif "example_df" in st.session_state:
    # Hand out a copy so later processing can never alter the stored frame.
    loaded_df = st.session_state["example_df"].copy()

if loaded_df is not None:
    df = loaded_df

def _show_figure(fig):
    """Render a Matplotlib figure in the app, then close it.

    Streamlit re-executes this script on every widget interaction; closing the
    figure keeps pyplot from accumulating one open figure per plot per rerun.
    """
    st.pyplot(fig)
    plt.close(fig)


def _pca_coordinates(scaled_data):
    """Project ``scaled_data`` onto (up to) its first two principal components.

    Returns an array with exactly two columns. When only one component can be
    computed (e.g. a single feature column), the second column is all zeros so
    the 2-D scatter plot still works; ``PCA(n_components=2)`` would raise there.
    """
    n_components = min(2, *scaled_data.shape)
    coords = PCA(n_components=n_components).fit_transform(scaled_data)
    if n_components < 2:
        coords = np.column_stack([coords, np.zeros(len(coords))])
    return coords


# At module scope locals() is globals(): true once the loading section bound `df`.
if 'df' in locals():
    # Keep only numeric (and boolean) columns: text, category and datetime
    # columns cannot be standardised or clustered. Reassigning (rather than
    # dropping in place) leaves the loaded frame itself untouched.
    df = df.select_dtypes(include=[np.number, "bool"])
    st.write("### Raw Data:")
    st.write(df.head())
    feature_cols = df.columns.tolist()

    # KMeans rejects NaN, so drop incomplete rows. The original index is kept
    # so the per-row plot below still lines up with the source rows.
    n_rows_before = len(df)
    df = df.dropna()
    if len(df) < n_rows_before:
        st.warning(f"Dropped {n_rows_before - len(df)} row(s) with missing values before clustering.")

    # User input for clusters
    num_clusters = st.slider("Select number of clusters", min_value=2, max_value=10, value=3)

    if not feature_cols:
        st.error("The dataset has no numeric columns to cluster.")
    elif len(df) < num_clusters:
        st.error(
            f"At least {num_clusters} complete rows are needed for {num_clusters} clusters; "
            f"the dataset has {len(df)}."
        )
    else:
        # Preprocessing: standardise so no single feature dominates the distances.
        scaled_data = StandardScaler().fit_transform(df[feature_cols])

        # K-Means Clustering. n_init is pinned because its default differs
        # across scikit-learn versions.
        kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=42)
        clusters = kmeans.fit_predict(scaled_data)

        # PCA for visualization; assign() returns a new frame (no chained-assignment warnings).
        pca_data = _pca_coordinates(scaled_data)
        df = df.assign(Cluster=clusters, PCA1=pca_data[:, 0], PCA2=pca_data[:, 1])

        # Plot Clusters
        st.write("### Cluster Visualization:")
        fig, ax = plt.subplots()
        sns.scatterplot(x='PCA1', y='PCA2', hue='Cluster', data=df, palette='viridis', ax=ax)
        _show_figure(fig)

        # Per-row plot of one original feature, one line per cluster. Only the
        # input features are offered, not the derived Cluster/PCA columns.
        selected_col = st.selectbox("Select column for time series visualization", feature_cols)
        st.write("### Time Series Plot:")
        fig, ax = plt.subplots()
        for cluster in sorted(df['Cluster'].unique()):
            cluster_data = df[df['Cluster'] == cluster]
            ax.plot(cluster_data.index, cluster_data[selected_col], label=f'Cluster {cluster}')
        ax.legend()
        _show_figure(fig)

        # Cluster distribution
        st.write("### Cluster Distribution:")
        fig, ax = plt.subplots()
        sns.countplot(x='Cluster', data=df, palette='viridis', ax=ax)
        _show_figure(fig)

    st.markdown("---")  # Adds a horizontal line
    st.markdown("**Thanks!**")