import json
import os
import re
from datetime import datetime

import pandas as pd
import streamlit as st

from utils import (
    load_model,
    get_hf_token,
    simulate_training,
    plot_training_metrics,
    load_finetuned_model,
    save_model,
)

st.title("🔥 Fine-tune the Gemma Model")

# -------------------------------
# Finetuning Option Selection
# -------------------------------
finetune_option = st.radio("Select Finetuning Option", ["Fine-tune from scratch", "Refinetune existing model"])

# -------------------------------
# Model Selection Logic
# -------------------------------
# After this section exactly one of these is set: `selected_model` (a Hugging
# Face model id) when fine-tuning from scratch, or `saved_model_path` (a local
# .pt checkpoint under models/) when re-finetuning.
selected_model = None
saved_model_path = None

if finetune_option == "Refinetune existing model":
    # Dynamically list all saved checkpoints from the models/ folder.
    model_dir = "models"
    # isdir rather than exists: os.listdir would raise if "models" were a file.
    if os.path.isdir(model_dir):
        saved_models = sorted(f for f in os.listdir(model_dir) if f.endswith(".pt"))
    else:
        saved_models = []

    if saved_models:
        saved_model_name = st.selectbox("Select a saved model to re-finetune", saved_models)
        saved_model_path = os.path.join(model_dir, saved_model_name)
        st.success(f"✅ Selected model for refinement: `{saved_model_path}`")
    else:
        st.warning("⚠️ No saved models found! Switching to fine-tuning from scratch.")
        finetune_option = "Fine-tune from scratch"

# Deliberately a separate `if` evaluated after the refinetune branch: when the
# fallback above switches modes, the user still gets a model picker. Otherwise
# `selected_model` would stay None and "Start" could only report an error.
if finetune_option == "Fine-tune from scratch":
    # Display Hugging Face model list
    model_list = [
        "google/gemma-3-1b-pt",
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it",
    ]
    selected_model = st.selectbox("🛠️ Select Gemma Model to Fine-tune", model_list)

# -------------------------------
# Dataset Selection
# -------------------------------
st.subheader("📚 Dataset Selection")
dataset_option = st.radio("Choose dataset:", ["Upload New Dataset", "Use Existing Dataset (`train_data.csv`)"])
dataset_path = "datasets/train_data.csv"

if dataset_option == "Upload New Dataset":
    uploaded_file = st.file_uploader("📤 Upload Dataset (CSV or JSON)", type=["csv", "json"])
    if uploaded_file is not None:
        # Streamlit reruns this whole script on every widget interaction
        # (including "Start Fine-tuning") and the uploader keeps its file, so
        # remember which uploads were already written. Otherwise the same rows
        # are appended again on each rerun.
        if "processed_uploads" not in st.session_state:
            st.session_state["processed_uploads"] = set()
        processed_uploads = st.session_state["processed_uploads"]
        upload_key = (uploaded_file.name, uploaded_file.size)

        if upload_key in processed_uploads:
            st.info(f"ℹ️ `{uploaded_file.name}` was already added to `{dataset_path}`.")
        else:
            lower_name = uploaded_file.name.lower()
            if not lower_name.endswith((".csv", ".json")):
                st.error("❌ Unsupported file format. Please upload CSV or JSON.")
                st.stop()

            try:
                if lower_name.endswith(".csv"):
                    new_data = pd.read_csv(uploaded_file)
                else:
                    new_data = pd.json_normalize(json.load(uploaded_file))
            except (ValueError, TypeError) as err:
                # ValueError covers JSONDecodeError, UnicodeDecodeError and
                # pandas' ParserError/EmptyDataError.
                st.error(f"❌ Could not parse `{uploaded_file.name}`: {err}")
                st.stop()

            os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
            if os.path.exists(dataset_path) and os.path.getsize(dataset_path) > 0:
                # A headerless append is positional, so reorder the new columns
                # to match the existing header and reject mismatched schemas.
                existing_columns = list(pd.read_csv(dataset_path, nrows=0).columns)
                if set(new_data.columns) != set(existing_columns):
                    st.error(
                        f"❌ Columns of `{uploaded_file.name}` do not match `{dataset_path}`. "
                        f"Expected: {existing_columns}"
                    )
                    st.stop()
                new_data[existing_columns].to_csv(dataset_path, mode='a', index=False, header=False)
                st.success(f"✅ Data appended to `{dataset_path}`!")
            else:
                new_data.to_csv(dataset_path, index=False)
                st.success(f"✅ Dataset saved as `{dataset_path}`!")
            processed_uploads.add(upload_key)
elif dataset_option == "Use Existing Dataset (`train_data.csv`)":
    if os.path.exists(dataset_path):
        st.success("✅ Using existing `train_data.csv` for fine-tuning.")
    else:
        st.error("❌ `train_data.csv` not found! Please upload a new dataset.")
        st.stop()

# -------------------------------
# Hyperparameters Configuration
# -------------------------------
st.subheader("🔧 Hyperparameter Configuration")
# Lower bounds reject meaningless values (zero/negative batch size or epochs).
# The explicit LR step replaces Streamlit's default float step of 0.01, which
# is 100x the default learning rate, so one "+" click would jump to 0.0101.
learning_rate = st.number_input("📊 Learning Rate", value=1e-4, min_value=0.0, step=1e-5, format="%.5f")
batch_size = st.number_input("🛠️ Batch Size", value=16, min_value=1, step=1)
epochs = st.number_input("⏱️ Epochs", value=3, min_value=1, step=1)


# -------------------------------
# Fine-tuning Execution with Real-Time Visualization
# -------------------------------
def _infer_base_model(checkpoint_path, default="google/gemma-3-1b-it"):
    """Return the Hugging Face id of the base model a saved checkpoint came from.

    Checkpoints written by this page are named
    ``fine_tuned_model_<model id with '/' replaced by '_'>_<timestamp>.pt``,
    so the id is recovered from the file name. Returns *default* (the base this
    page used unconditionally before) when the name does not contain one.
    """
    match = re.search(r"google_(gemma-3-\d+b-(?:pt|it))", os.path.basename(checkpoint_path))
    return f"google/{match.group(1)}" if match else default


if st.button("🚀 Start Fine-tuning"):
    st.info("Fine-tuning process initiated...")
    hf_token = get_hf_token()

    # Model loading logic
    if finetune_option == "Refinetune existing model" and saved_model_path:
        # load_finetuned_model presumably restores saved weights into `model`,
        # so the base architecture must match the one the checkpoint was
        # trained from rather than always being gemma-3-1b-it.
        base_model = _infer_base_model(saved_model_path)
        tokenizer, model = load_model(base_model, hf_token)
        if not model:
            st.error(f"❌ Failed to load the base model `{base_model}`. Aborting.")
            st.stop()
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success(f"✅ Loaded saved model: `{saved_model_path}` for refinement!")
        else:
            st.error("❌ Failed to load the saved model. Aborting.")
            st.stop()
    else:
        if not selected_model:
            st.error("❌ Please select a model to fine-tune.")
            st.stop()
        tokenizer, model = load_model(selected_model, hf_token)
        if model:
            st.success(f"✅ Base model loaded: `{selected_model}`")
        else:
            st.error("❌ Failed to load the base model. Aborting.")
            st.stop()

    # Create placeholders for training progress
    loss_chart = st.line_chart()  # Loss curve
    acc_chart = st.line_chart()   # Accuracy curve
    progress_text = st.empty()

    # Simulate training loop with real-time visualization
    losses_over_epochs = []
    accuracies_over_epochs = []

    for epoch, losses, accs in simulate_training(epochs, learning_rate, batch_size):
        progress_text.text(f"Epoch {epoch}/{epochs} in progress...")

        # Assumes simulate_training yields one aggregate loss/accuracy per epoch.
        losses_over_epochs.append(losses)
        accuracies_over_epochs.append(accs)

        # Update real-time charts
        loss_chart.add_rows(pd.DataFrame({"Loss": [losses]}))
        acc_chart.add_rows(pd.DataFrame({"Accuracy": [accs]}))

    progress_text.text("Fine-tuning completed!")

    # Save fine-tuned model with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if selected_model:
        model_identifier = selected_model.replace("/", "_")
    else:
        # Strip ".pt" so re-finetuned names don't embed the old extension.
        model_identifier = os.path.splitext(os.path.basename(saved_model_path))[0]
    # models/ may not exist yet (e.g. first from-scratch run).
    os.makedirs("models", exist_ok=True)
    new_model_name = f"models/fine_tuned_model_{model_identifier}_{timestamp}.pt"

    saved_model_path = save_model(model, new_model_name)
    if saved_model_path:
        st.success(f"✅ Fine-tuning completed! Model saved as `{saved_model_path}`")
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success("🛠️ Fine-tuned model loaded and ready for inference!")
        else:
            st.error("❌ Failed to load the fine-tuned model for inference.")
    else:
        st.error("❌ Failed to save the fine-tuned model.")