Update app.py
Browse files
app.py
CHANGED
@@ -1,96 +1,96 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import pandas as pd
|
3 |
-
import numpy as np
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
import seaborn as sns
|
6 |
-
from sklearn.cluster import KMeans
|
7 |
-
from sklearn.preprocessing import StandardScaler
|
8 |
-
from sklearn.decomposition import PCA
|
9 |
-
|
10 |
-
|
11 |
-
# App Title
|
12 |
-
st.title("Unsupervised Data Clustering App")
|
13 |
-
|
14 |
-
# About App
|
15 |
-
with st.expander("About this App"):
|
16 |
-
st.write(
|
17 |
-
"This app allows you to upload any type of unlabeled dataset "
|
18 |
-
"and automatically clusters the data using K-means clustering. "
|
19 |
-
"It visualizes the clusters using PCA and provides time series and cluster distribution plots "
|
20 |
-
"to help you identify patterns and groupings within your data."
|
21 |
-
)
|
22 |
-
|
23 |
-
# File uploader
|
24 |
-
uploaded_file = st.file_uploader("Upload Custom CSV file", type=["csv"])
|
25 |
-
|
26 |
-
# # Example Demo Dataset
|
27 |
-
if st.button("Test With An Example Dataset"):
|
28 |
-
dataset = load_dataset('kheejay88/country_data', split='train')
|
29 |
-
df = pd.DataFrame(dataset)
|
30 |
-
st.success("Loaded example dataset from Hugging Face.")
|
31 |
-
|
32 |
-
with st.expander("Dataset Columns"):
|
33 |
-
st.write("""
|
34 |
-
**country** – Name of the country\n
|
35 |
-
**child_mort** – Death of children under 5 years of age per 1000 live births\n
|
36 |
-
**exports** – Exports of goods and services per capita (as a percentage of GDP)\n
|
37 |
-
**health** – Total health spending per capita (as a percentage of GDP)\n
|
38 |
-
**imports** – Imports of goods and services per capita (as a percentage of GDP)\n
|
39 |
-
**income** – Net income per person\n
|
40 |
-
**inflation** – Annual inflation rate (percentage)\n
|
41 |
-
**life_expec** – Average life expectancy at birth (in years)\n
|
42 |
-
**total_fer** – Total fertility rate (average number of children per woman)\n
|
43 |
-
**gdpp** – GDP per capita\n
|
44 |
-
""")
|
45 |
-
|
46 |
-
if uploaded_file is not None:
|
47 |
-
df = pd.read_csv(uploaded_file)
|
48 |
-
|
49 |
-
if 'df' in locals():
|
50 |
-
# Drop non-numeric columns
|
51 |
-
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
|
52 |
-
df.drop(columns=categorical_cols, inplace=True)
|
53 |
-
st.write("### Raw Data:")
|
54 |
-
st.write(df.head())
|
55 |
-
|
56 |
-
# Preprocessing
|
57 |
-
scaler = StandardScaler()
|
58 |
-
scaled_data = scaler.fit_transform(df)
|
59 |
-
|
60 |
-
# User input for clusters
|
61 |
-
num_clusters = st.slider("Select number of clusters", min_value=2, max_value=10, value=3)
|
62 |
-
|
63 |
-
# K-Means Clustering
|
64 |
-
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
|
65 |
-
clusters = kmeans.fit_predict(scaled_data)
|
66 |
-
df['Cluster'] = clusters
|
67 |
-
|
68 |
-
# PCA for visualization
|
69 |
-
pca = PCA(n_components=2)
|
70 |
-
pca_data = pca.fit_transform(scaled_data)
|
71 |
-
df['PCA1'] = pca_data[:, 0]
|
72 |
-
df['PCA2'] = pca_data[:, 1]
|
73 |
-
|
74 |
-
# Plot Clusters
|
75 |
-
st.write("### Cluster Visualization:")
|
76 |
-
fig, ax = plt.subplots()
|
77 |
-
sns.scatterplot(x='PCA1', y='PCA2', hue='Cluster', data=df, palette='viridis', ax=ax)
|
78 |
-
st.pyplot(fig)
|
79 |
-
|
80 |
-
# Time Series Plot (if available)
|
81 |
-
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
82 |
-
if len(numeric_cols) >= 2:
|
83 |
-
selected_col = st.selectbox("Select column for time series visualization", numeric_cols)
|
84 |
-
st.write("### Time Series Plot:")
|
85 |
-
fig, ax = plt.subplots()
|
86 |
-
for cluster in df['Cluster'].unique():
|
87 |
-
cluster_data = df[df['Cluster'] == cluster]
|
88 |
-
ax.plot(cluster_data.index, cluster_data[selected_col], label=f'Cluster {cluster}')
|
89 |
-
ax.legend()
|
90 |
-
st.pyplot(fig)
|
91 |
-
|
92 |
-
# Cluster distribution
|
93 |
-
st.write("### Cluster Distribution:")
|
94 |
-
fig, ax = plt.subplots()
|
95 |
-
sns.countplot(x='Cluster', data=df, palette='viridis', ax=ax)
|
96 |
-
st.pyplot(fig)
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import seaborn as sns
|
6 |
+
from sklearn.cluster import KMeans
|
7 |
+
from sklearn.preprocessing import StandardScaler
|
8 |
+
from sklearn.decomposition import PCA
|
9 |
+
from datasets import load_dataset
|
10 |
+
|
11 |
+
# App Title
|
12 |
+
st.title("Unsupervised Data Clustering App")
|
13 |
+
|
14 |
+
# About App
|
15 |
+
with st.expander("About this App"):
|
16 |
+
st.write(
|
17 |
+
"This app allows you to upload any type of unlabeled dataset "
|
18 |
+
"and automatically clusters the data using K-means clustering. "
|
19 |
+
"It visualizes the clusters using PCA and provides time series and cluster distribution plots "
|
20 |
+
"to help you identify patterns and groupings within your data."
|
21 |
+
)
|
22 |
+
|
23 |
+
# File uploader
|
24 |
+
uploaded_file = st.file_uploader("Upload Custom CSV file", type=["csv"])
|
25 |
+
|
26 |
+
# # Example Demo Dataset
|
27 |
+
if st.button("Test With An Example Dataset"):
|
28 |
+
dataset = load_dataset('kheejay88/country_data', split='train')
|
29 |
+
df = pd.DataFrame(dataset)
|
30 |
+
st.success("Loaded example dataset from Hugging Face.")
|
31 |
+
|
32 |
+
with st.expander("Dataset Columns"):
|
33 |
+
st.write("""
|
34 |
+
**country** – Name of the country\n
|
35 |
+
**child_mort** – Death of children under 5 years of age per 1000 live births\n
|
36 |
+
**exports** – Exports of goods and services per capita (as a percentage of GDP)\n
|
37 |
+
**health** – Total health spending per capita (as a percentage of GDP)\n
|
38 |
+
**imports** – Imports of goods and services per capita (as a percentage of GDP)\n
|
39 |
+
**income** – Net income per person\n
|
40 |
+
**inflation** – Annual inflation rate (percentage)\n
|
41 |
+
**life_expec** – Average life expectancy at birth (in years)\n
|
42 |
+
**total_fer** – Total fertility rate (average number of children per woman)\n
|
43 |
+
**gdpp** – GDP per capita\n
|
44 |
+
""")
|
45 |
+
|
46 |
+
if uploaded_file is not None:
|
47 |
+
df = pd.read_csv(uploaded_file)
|
48 |
+
|
49 |
+
if 'df' in locals():
|
50 |
+
# Drop non-numeric columns
|
51 |
+
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
|
52 |
+
df.drop(columns=categorical_cols, inplace=True)
|
53 |
+
st.write("### Raw Data:")
|
54 |
+
st.write(df.head())
|
55 |
+
|
56 |
+
# Preprocessing
|
57 |
+
scaler = StandardScaler()
|
58 |
+
scaled_data = scaler.fit_transform(df)
|
59 |
+
|
60 |
+
# User input for clusters
|
61 |
+
num_clusters = st.slider("Select number of clusters", min_value=2, max_value=10, value=3)
|
62 |
+
|
63 |
+
# K-Means Clustering
|
64 |
+
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
|
65 |
+
clusters = kmeans.fit_predict(scaled_data)
|
66 |
+
df['Cluster'] = clusters
|
67 |
+
|
68 |
+
# PCA for visualization
|
69 |
+
pca = PCA(n_components=2)
|
70 |
+
pca_data = pca.fit_transform(scaled_data)
|
71 |
+
df['PCA1'] = pca_data[:, 0]
|
72 |
+
df['PCA2'] = pca_data[:, 1]
|
73 |
+
|
74 |
+
# Plot Clusters
|
75 |
+
st.write("### Cluster Visualization:")
|
76 |
+
fig, ax = plt.subplots()
|
77 |
+
sns.scatterplot(x='PCA1', y='PCA2', hue='Cluster', data=df, palette='viridis', ax=ax)
|
78 |
+
st.pyplot(fig)
|
79 |
+
|
80 |
+
# Time Series Plot (if available)
|
81 |
+
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
82 |
+
if len(numeric_cols) >= 2:
|
83 |
+
selected_col = st.selectbox("Select column for time series visualization", numeric_cols)
|
84 |
+
st.write("### Time Series Plot:")
|
85 |
+
fig, ax = plt.subplots()
|
86 |
+
for cluster in df['Cluster'].unique():
|
87 |
+
cluster_data = df[df['Cluster'] == cluster]
|
88 |
+
ax.plot(cluster_data.index, cluster_data[selected_col], label=f'Cluster {cluster}')
|
89 |
+
ax.legend()
|
90 |
+
st.pyplot(fig)
|
91 |
+
|
92 |
+
# Cluster distribution
|
93 |
+
st.write("### Cluster Distribution:")
|
94 |
+
fig, ax = plt.subplots()
|
95 |
+
sns.countplot(x='Cluster', data=df, palette='viridis', ax=ax)
|
96 |
+
st.pyplot(fig)
|