import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import io    # in-memory buffer for Excel export
import json  # parsing uploaded JSON files
from utils import (
    load_dataset,
    save_dataset,
    clean_dataset,
    compute_dataset_score,
    detect_outliers,
    apply_transformation,
    list_datasets,
    detect_inconsistent_types
)

# -------------------------------
# Constants & Setup
# -------------------------------
DATASET_DIR = "datasets"
DEFAULT_DATASET = "train_data.csv"
os.makedirs(DATASET_DIR, exist_ok=True)  # Ensure directory exists
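
# Optional page setup; st.set_page_config must be the first Streamlit call.
st.set_page_config(page_title="The Data Hub", page_icon="📊", layout="wide")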

# -------------------------------
# Sidebar: Dataset Selection
# -------------------------------
st.sidebar.header("📊 Dataset Selection")

# List available datasets from the datasets folder
available_datasets = list_datasets(DATASET_DIR)
dataset_choice = st.sidebar.radio("Choose Dataset Source:", ["Select Existing Dataset", "Upload New Dataset"])

dataset_path = None

if dataset_choice == "Select Existing Dataset":
    if available_datasets:
        selected_dataset = st.sidebar.selectbox("Select Dataset:", available_datasets)
        dataset_path = os.path.join(DATASET_DIR, selected_dataset)
        st.sidebar.success(f"Using `{selected_dataset}` dataset.")
    else:
        st.sidebar.warning("No datasets found. Please upload a new dataset.")
elif dataset_choice == "Upload New Dataset":
    uploaded_file = st.sidebar.file_uploader("Upload Dataset (CSV, JSON, or Excel)", type=["csv", "json", "xlsx"])
    if uploaded_file:
        file_ext = uploaded_file.name.split('.')[-1].lower()
        try:
            if file_ext == "csv":
                new_df = pd.read_csv(uploaded_file)
            elif file_ext == "json":
                new_df = pd.json_normalize(json.load(uploaded_file))
            elif file_ext == "xlsx":
                new_df = pd.read_excel(uploaded_file)
            else:
                st.error("Unsupported file format.")
                st.stop()
        except Exception as e:
            st.error(f"Error reading file: {e}")
            st.stop()

        # Save the new dataset with its filename
        dataset_path = os.path.join(DATASET_DIR, uploaded_file.name)
        save_dataset(new_df, dataset_path)
        st.sidebar.success(f"Dataset `{uploaded_file.name}` uploaded successfully!")
        available_datasets = list_datasets(DATASET_DIR)  # Refresh list
    else:
        st.sidebar.warning("Please upload a dataset.")

# -------------------------------
# Load the Selected Dataset
# -------------------------------
if dataset_path:
    df = load_dataset(dataset_path)
    if df.empty:
        st.warning("Dataset is empty or failed to load.")
else:
    df = pd.DataFrame()
    st.warning("No dataset selected. Please choose or upload a dataset.")

# -------------------------------
# Main App Title & Description
# -------------------------------
st.title("📊 The Data Hub")
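# Short description under the title (wording is illustrative).
st.caption("Explore, clean, visualize, profile, transform, and export your datasets in one place.")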

# -------------------------------
# Tabs for Operations
# -------------------------------
tabs = st.tabs([
    "View & Summary", "Clean Data",
    "Visualize Data", "Data Profiling",
    "Outlier Detection", "Custom Transformations",
    "Export"
])

# -------------------------------
# Tab 1: View & Summary
# -------------------------------
with tabs[0]:
    st.subheader("📋 Current Dataset Preview")
    if not df.empty:
        st.dataframe(df)
        st.markdown("#### 🔎 Basic Statistics")
        st.write(df.describe(include="all"))
    else:
        st.warning("No dataset available. Please choose or upload a dataset.")

# -------------------------------
# Tab 2: Clean Data
# -------------------------------
with tabs[1]:
    st.subheader("🧼 Clean Your Dataset")
    if not df.empty:
        remove_duplicates = st.checkbox("Remove Duplicate Rows", value=True)
        fill_missing = st.checkbox("Fill Missing Values", value=False)
        fill_value = st.text_input("Fill missing values with:", value="0")

        st.markdown("#### Optional: Rename Columns")
        new_names = {}
        for col in df.columns:
            new_names[col] = st.text_input(f"Rename column '{col}'", value=col)

        if st.button("Clean Dataset"):
            cleaned_df = clean_dataset(df, remove_duplicates, fill_missing, fill_value)
            cleaned_df = cleaned_df.rename(columns=new_names)
            save_dataset(cleaned_df, dataset_path)
            st.success("✅ Dataset cleaned successfully!")
            st.dataframe(cleaned_df.head())
            df = cleaned_df
    else:
        st.warning("No dataset available for cleaning.")

# -------------------------------
# Tab 3: Visualize Data
# -------------------------------
with tabs[2]:
    st.subheader("📊 Visualize Your Data")

    if not df.empty:
        viz_type = st.selectbox("Select Visualization Type", ["Histogram", "Scatter", "Box Plot", "Heatmap", "Line Chart"])
        numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()

        if numeric_cols:
            fig, ax = plt.subplots()

            if viz_type == "Scatter":
                # Scatter needs two columns, so it gets its own selectors.
                x_col = st.selectbox("X-axis", numeric_cols)
                y_col = st.selectbox("Y-axis", numeric_cols)
                ax.scatter(df[x_col], df[y_col], color="green")
            elif viz_type == "Heatmap":
                # The heatmap uses all numeric columns; no single-column choice needed.
                corr = df[numeric_cols].corr()
                sns.heatmap(corr, annot=True, cmap="coolwarm", ax=ax)
            else:
                # Single-column charts share one selector.
                col = st.selectbox("Select Column", numeric_cols)
                if viz_type == "Histogram":
                    ax.hist(df[col].dropna(), bins=20, color="skyblue", edgecolor="black")
                elif viz_type == "Box Plot":
                    sns.boxplot(x=df[col].dropna(), ax=ax)
                elif viz_type == "Line Chart":
                    ax.plot(df.index, df[col], marker="o")

            st.pyplot(fig)
        else:
            st.warning("No numeric columns available for visualization.")
    else:
        st.warning("No dataset available for visualization.")

# -------------------------------
# Tab 4: Data Profiling
# -------------------------------
with tabs[3]:
    st.subheader("🧾 Data Profiling")
    if not df.empty:

        # -------------------------------
        # 1. General Dataset Info
        # -------------------------------
        st.markdown("### 🛠️ General Information")
        st.write(f"✅ **Total Rows:** `{df.shape[0]}`")
        st.write(f"✅ **Total Columns:** `{df.shape[1]}`")
        st.write(f"✅ **Memory Usage:** `{df.memory_usage(deep=True).sum() / (1024 ** 2):.2f} MB`")
        st.write(f"✅ **Dataset Shape:** `{df.shape}`")

        # -------------------------------
        # 2. Dataset Quality Score
        # -------------------------------
        st.markdown("### 📊 Dataset Quality Score")
        score = compute_dataset_score(df)
        st.success(f"💯 Dataset Quality Score: `{score} / 100`")

        # -------------------------------
        # 3. Column Overview with Stats
        # -------------------------------
        st.markdown("### 🔥 Column Overview")

        # Numeric and categorical columns
        numeric_cols = df.select_dtypes(include=["number"]).columns
        categorical_cols = df.select_dtypes(include=["object"]).columns

        profile = pd.DataFrame({
            "Column": df.columns,
            "Data Type": df.dtypes.values,
            "Missing Values": df.isnull().sum().values,
            "Missing %": (df.isnull().sum() / len(df) * 100).values,
            "Unique Values": df.nunique().values
        })

        # Add numeric statistics
        if len(numeric_cols) > 0:
            numeric_stats = pd.DataFrame({
                "Column": numeric_cols,
                "Min": df[numeric_cols].min().values,
                "Max": df[numeric_cols].max().values,
                "Mean": df[numeric_cols].mean().values,
                "Std Dev": df[numeric_cols].std().values,
                "Skewness": df[numeric_cols].skew().values,
                "Kurtosis": df[numeric_cols].kurt().values
            })

            # Merge stats with the profile
            profile = profile.merge(numeric_stats, on="Column", how="left")

        st.dataframe(profile)

        # -------------------------------
        # 4. Missing Values Visualization
        # -------------------------------
        st.markdown("### 🔎 Missing Values Distribution")
        missing_values = df.isnull().sum()
        missing_values = missing_values[missing_values > 0]

        if not missing_values.empty:
            fig, ax = plt.subplots(figsize=(12, 5))
            sns.barplot(x=missing_values.index, y=missing_values.values, ax=ax, color="skyblue")
            ax.set_title("Missing Values per Column")
            ax.set_ylabel("Missing Count")
            ax.tick_params(axis="x", rotation=45)
            st.pyplot(fig)
        else:
            st.success("No missing values found!")

        # -------------------------------
        # 5. Duplicates Detection
        # -------------------------------
        st.markdown("### 🔥 Duplicates & Constant Columns Detection")
        
        # Duplicates
        duplicate_count = df.duplicated().sum()
        st.write(f"🔍 **Duplicate Rows:** `{duplicate_count}`")

        # Constant Columns
        constant_cols = [col for col in df.columns if df[col].nunique() == 1]
        if constant_cols:
            st.write(f"🚩 **Constant Columns:** `{constant_cols}`")
        else:
            st.success("No constant columns detected!")

        # -------------------------------
        # 6. Cardinality Analysis
        # -------------------------------
        st.markdown("### 🧬 Cardinality Analysis")
        
        high_cardinality = [col for col in df.columns if df[col].nunique() > len(df) * 0.8]
        if high_cardinality:
            st.write(f"🔢 **High-Cardinality Columns:** `{high_cardinality}`")
        else:
            st.success("No high-cardinality columns detected!")

        # -------------------------------
        # 7. Top Frequent & Rare Values
        # -------------------------------
        st.markdown("### 🎯 Frequent & Rare Values")

        for col in categorical_cols:
            st.write(f"✅ **{col}**")
            
            top_values = df[col].value_counts().nlargest(5)
            rare_values = df[col].value_counts().nsmallest(5)

            st.write("📊 **Top Frequent Values:**")
            st.dataframe(top_values)

            st.write("🧪 **Rare Values:**")
            st.dataframe(rare_values)

        # -------------------------------
        # 8. Correlation Matrix
        # -------------------------------
        st.markdown("### 📊 Correlation Matrix")
        
        if len(numeric_cols) > 1:
            corr = df[numeric_cols].corr()

            fig, ax = plt.subplots(figsize=(12, 8))
            sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", square=True, ax=ax)
            st.pyplot(fig)
        else:
            st.info("Not enough numeric columns for correlation analysis.")

        # -------------------------------
        # 9. Pair Plot (Numerical Relationships)
        # -------------------------------
        st.markdown("### 🔥 Pair Plot (Numerical Relationships)")
        
        if len(numeric_cols) >= 2:
            pairplot = sns.pairplot(df[numeric_cols], diag_kind='kde')
            st.pyplot(pairplot.fig)
        else:
            st.info("Not enough numeric columns for pair plot visualization.")

        # -------------------------------
        # 10. Outlier Detection
        # -------------------------------
        st.markdown("### 🚩 Outlier Detection")
        
        outliers = detect_outliers(df)
        if outliers:
            st.write("✅ **Outliers Detected:**")
            st.dataframe(pd.DataFrame(outliers.items(), columns=["Column", "Outlier Count"]))
        else:
            st.success("No significant outliers detected!")

        # -------------------------------
        # 11. Inconsistent Data Types
        # -------------------------------
        st.markdown("### 🚫 Inconsistent Data Types")
        
        inconsistent_types = detect_inconsistent_types(df)
        if inconsistent_types:
            st.write("⚠️ **Inconsistent Data Types Detected:**")
            st.write(inconsistent_types)
        else:
            st.success("No inconsistent data types detected!")

    else:
        st.warning("No dataset available for profiling.")

# -------------------------------
# Tab 5: Outlier Detection
# -------------------------------
with tabs[4]:
    st.subheader("🚀 Outlier Detection")
    if not df.empty:
        outliers = detect_outliers(df)
        if outliers:
            st.dataframe(pd.DataFrame(outliers.items(), columns=["Column", "Outlier Count"]))
        else:
            st.success("No significant outliers detected!")
    else:
        st.warning("No dataset available for outlier detection.")
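
# -------------------------------
# Tab 6: Custom Transformations
# -------------------------------
# Sketch for the previously empty tab, wiring up the imported but unused
# apply_transformation helper. The signature apply_transformation(df, column,
# transformation) and the transformation names below are assumptions; adjust
# them to match the actual implementation in utils.
with tabs[5]:
    st.subheader("🔧 Custom Transformations")
    if not df.empty:
        numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
        if numeric_cols:
            transform_col = st.selectbox("Column to Transform", numeric_cols)
            transform_type = st.selectbox("Transformation", ["log", "sqrt", "zscore", "minmax"])
            if st.button("Apply Transformation"):
                # Assumed signature: apply_transformation(df, column, transformation)
                transformed_df = apply_transformation(df, transform_col, transform_type)
                save_dataset(transformed_df, dataset_path)
                st.success(f"✅ Applied `{transform_type}` to `{transform_col}`.")
                st.dataframe(transformed_df.head())
                df = transformed_df
        else:
            st.warning("No numeric columns available to transform.")
    else:
        st.warning("No dataset available for transformation.")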

# -------------------------------
# Tab 7: Export
# -------------------------------
with tabs[6]:
    st.subheader("📤 Export Dataset")
    export_format = st.selectbox("Export Format", ["CSV", "Excel", "JSON"])
    if not df.empty:
        # Serialize in the chosen format; Excel needs an in-memory binary
        # buffer (and an Excel writer such as openpyxl installed).
        if export_format == "Excel":
            buffer = io.BytesIO()
            df.to_excel(buffer, index=False)
            data, ext = buffer.getvalue(), "xlsx"
        elif export_format == "JSON":
            data, ext = df.to_json(orient="records"), "json"
        else:
            data, ext = df.to_csv(index=False), "csv"
        st.download_button("Download", data, file_name=f"dataset.{ext}")
    else:
        st.warning("No dataset available for export.")