import ast
import base64
import html
import io
import json
import logging
import math

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats

# Set up logging: configure the root logger at DEBUG level and create the
# module-level logger used by the functions below.
# NOTE(review): parse_input logs the raw pasted input at DEBUG, so with this
# level the full input text is written to the log output.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Function to safely parse JSON or Python dictionary input
def parse_input(json_input):
    """Parse pasted log-prob data given as JSON or as a Python literal.

    JSON is attempted first. If that fails, the text is parsed with
    ``ast.literal_eval`` (literals only, no code is executed) and the result
    is normalised to JSON-like containers.

    Args:
        json_input: Text holding a JSON document or a Python literal
            (e.g. a dict written with single quotes).

    Returns:
        The parsed data. For Python-literal input, dict keys are converted to
        ``str`` and tuples are converted to lists.

    Raises:
        ValueError: If the input is neither valid JSON nor a valid Python
            literal.
    """
    logger.debug("Attempting to parse input: %s", json_input)
    try:
        data = json.loads(json_input)
    except (json.JSONDecodeError, TypeError) as json_error:
        # Not an error yet: single-quoted Python dicts are expected to land
        # here and are handled by the literal fallback below. TypeError
        # covers non-string input such as None.
        logger.debug("JSON parsing failed: %s", json_error)
    else:
        logger.debug("Successfully parsed as JSON")
        return data

    try:
        data = ast.literal_eval(json_input)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as literal_error:
        # literal_eval may raise any of these depending on how the input is
        # malformed (e.g. TypeError for an unhashable dict key); callers only
        # expect ValueError.
        logger.error("Python literal parsing failed: %s", str(literal_error))
        raise ValueError(
            f"Malformed input: {literal_error}. Ensure property names are in double quotes (e.g., \"content\") or correct Python dictionary format."
        ) from literal_error

    logger.debug("Successfully parsed as Python literal")
    converted_data = _to_json_compatible(data)
    logger.debug("Converted to JSON-compatible format")
    return converted_data


def _to_json_compatible(obj):
    """Recursively stringify dict keys and turn tuples into lists."""
    if isinstance(obj, dict):
        return {str(key): _to_json_compatible(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_compatible(item) for item in obj]
    return obj

# Function to ensure a value is a float, converting from string if necessary
def ensure_float(value):
    """Coerce a value to ``float``, returning ``None`` when that is impossible.

    Args:
        value: A number, a numeric string, ``None``, or any other object.

    Returns:
        ``float(value)`` for ints, floats and parseable strings; ``None`` for
        ``None``, unparseable strings, numbers too large for a float, and any
        other type. The result may be ``nan`` or ``inf`` (e.g. from the
        string ``"nan"``); finiteness is left to the caller.
    """
    if value is None:
        return None
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            logger.error("Failed to convert string '%s' to float", value)
            return None
    if isinstance(value, (int, float)):
        try:
            return float(value)
        except OverflowError:
            # Arbitrarily large ints are valid JSON but do not fit a float.
            # Log only the type: formatting a huge int can itself fail.
            logger.error(
                "Numeric value of type %s is too large to convert to float",
                type(value).__name__,
            )
            return None
    return None

# Function to process and visualize log probs with interactive Plotly plots
def visualize_logprobs(json_input, prob_filter=-1e9, page_size=50, page=0):
    """Build every visualization for one page of token log probabilities.

    Args:
        json_input: Raw text (JSON or Python literal) holding either a list
            of token entries or a dict with a ``"content"`` list. Each entry
            is read for ``"token"``, ``"logprob"`` and optional
            ``"top_logprobs"``.
        prob_filter: Entries with a log probability below this are dropped.
        page_size: Number of kept tokens per page (values below 1 become 1).
        page: Zero-based page index; clamped into the valid range.

    Returns:
        An 8-tuple, always in this order: main figure, table DataFrame (or
        ``None``), colored-text HTML, top-3 alternatives HTML, drops figure,
        anomaly figure, total page count, page index actually shown.
        When nothing survives filtering, or on any error, the figure, table
        and alternatives slots are ``None``, the message is placed in the
        colored-text slot and the paging slots are ``1`` and ``0``.
    """
    try:
        data = parse_input(json_input)

        if isinstance(data, dict) and "content" in data:
            content = data["content"]
        elif isinstance(data, list):
            content = data
        else:
            raise ValueError("Input must be a list or dictionary with 'content' key")

        threshold = float("-inf") if prob_filter is None else prob_filter
        tokens, logprobs, top_alternatives, table_rows = _collect_token_records(content, threshold)

        if not logprobs or not tokens:
            # Same arity as the success path so every Gradio output is fed.
            return (None, None, "No finite log probabilities or tokens to visualize after filtering", None, None, None, 1, 0)

        # Paging operates on the *filtered* lists; clamp so a stale or
        # out-of-range page number still shows data.
        page_size = max(1, int(page_size)) if page_size is not None else 50
        total_pages = max(1, (len(logprobs) + page_size - 1) // page_size)
        page = min(max(int(page or 0), 0), total_pages - 1)
        start_idx = page * page_size
        end_idx = min(start_idx + page_size, len(logprobs))
        window = slice(start_idx, end_idx)

        paginated_tokens = tokens[window]
        paginated_logprobs = logprobs[window]
        paginated_alternatives = top_alternatives[window]

        main_fig = _build_main_figure(paginated_tokens, paginated_logprobs, start_idx)
        drops_fig = _build_drops_figure(paginated_tokens, paginated_logprobs, start_idx)
        anomaly_fig = _build_anomaly_figure(paginated_tokens, paginated_logprobs, start_idx)

        # Rows are aligned with the filtered tokens; None marks an entry
        # that had no top_logprobs and therefore gets no table row.
        page_rows = [row for row in table_rows[window] if row is not None]
        df = (
            pd.DataFrame(
                page_rows,
                columns=[
                    "Token",
                    "Log Prob",
                    "Top 1 Alternative",
                    "Top 2 Alternative",
                    "Top 3 Alternative",
                ],
            )
            if page_rows
            else None
        )

        colored_text_html = _render_colored_text(paginated_tokens, paginated_logprobs)
        alt_viz_html = _render_alternatives(paginated_tokens, paginated_alternatives, start_idx)

        return (main_fig, df, colored_text_html, alt_viz_html, drops_fig, anomaly_fig, total_pages, page)

    except Exception as e:
        # UI boundary: report the failure in the HTML slot instead of
        # raising; plot slots only accept figures or None.
        logger.exception("Visualization failed: %s", str(e))
        return (None, None, f"Error: {html.escape(str(e))}", None, None, None, 1, 0)


def _finite_top_logprobs(raw_top):
    """Return ``{token: logprob}`` keeping only finite values.

    Accepts a mapping of token to log prob, a list of dicts carrying
    ``"token"`` and ``"logprob"`` keys, or ``None``/empty.
    """
    if not raw_top:
        return {}
    if isinstance(raw_top, dict):
        pairs = list(raw_top.items())
    elif isinstance(raw_top, (list, tuple)):
        pairs = [
            (alt["token"], alt.get("logprob"))
            for alt in raw_top
            if isinstance(alt, dict) and "token" in alt
        ]
    else:
        return {}

    finite = {}
    for alt_token, raw_value in pairs:
        number = ensure_float(raw_value)
        if number is not None and math.isfinite(number):
            finite[alt_token] = number
    return finite


def _collect_token_records(content, prob_filter):
    """Filter entries and return parallel lists (tokens, logprobs, top-3, table rows)."""
    tokens, logprobs, top_alternatives, table_rows = [], [], [], []
    for entry in content:
        raw_logprob = entry.get("logprob", None)
        logprob = ensure_float(raw_logprob)
        keep = logprob is not None and math.isfinite(logprob) and logprob >= prob_filter
        if not keep:
            logger.debug("Skipping entry with logprob: %s (type: %s)", raw_logprob, type(raw_logprob))
            continue

        token = entry["token"]
        raw_top = entry.get("top_logprobs")
        alternatives = _finite_top_logprobs(raw_top)

        tokens.append(token)
        logprobs.append(logprob)

        # Top 3 including the selected token itself.
        candidates = {token: logprob}
        candidates.update(alternatives)
        ranked_all = sorted(candidates.items(), key=lambda pair: pair[1], reverse=True)
        top_alternatives.append(ranked_all[:3])

        # Table row lists alternatives only, and only when the entry
        # actually carried top_logprobs.
        if raw_top is None:
            table_rows.append(None)
        else:
            ranked_alts = sorted(alternatives.items(), key=lambda pair: pair[1], reverse=True)[:3]
            row = [token, f"{logprob:.4f}"]
            row.extend(f"{alt_token}: {alt_logprob:.4f}" for alt_token, alt_logprob in ranked_alts)
            row.extend([""] * (5 - len(row)))
            table_rows.append(row)
    return tokens, logprobs, top_alternatives, table_rows


def _build_main_figure(tokens, logprobs, offset):
    """Line/marker plot of the page's log probabilities."""
    hover = [
        f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i + offset}"
        for i, (tok, prob) in enumerate(zip(tokens, logprobs))
    ]
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=list(range(len(logprobs))),
            y=logprobs,
            mode='markers+lines',
            name='Log Prob',
            marker=dict(color='blue'),
            customdata=hover,
            hovertemplate='<b>%{customdata}</b><extra></extra>',
        )
    )
    fig.update_layout(
        title="Log Probabilities of Generated Tokens",
        xaxis_title="Token Position",
        yaxis_title="Log Probability",
        hovermode="closest",
        clickmode='event+select',
    )
    return fig


def _build_drops_figure(tokens, logprobs, offset):
    """Bar chart of the change in log prob between consecutive tokens."""
    # With fewer than two tokens there are no differences; the chart is
    # simply empty but still titled.
    drops = [nxt - cur for cur, nxt in zip(logprobs, logprobs[1:])]
    hover = [
        f"Drop: {drop:.4f}, From: {tokens[i]} to {tokens[i + 1]}, Position: {i + offset}"
        for i, drop in enumerate(drops)
    ]
    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=list(range(len(drops))),
            y=drops,
            name='Drop',
            marker_color='red',
            customdata=hover,
            hovertemplate='<b>%{customdata}</b><extra></extra>',
        )
    )
    fig.update_layout(
        title="Significant Probability Drops",
        xaxis_title="Token Position",
        yaxis_title="Log Probability Drop",
        hovermode="closest",
        clickmode='event+select',
    )
    return fig


def _build_anomaly_figure(tokens, logprobs, offset):
    """Log-prob plot with points whose |z-score| exceeds 2 highlighted."""
    values = np.asarray(logprobs, dtype=float)
    # z-scores are undefined for a single point or constant data; treat
    # those pages as having no outliers instead of producing NaNs.
    if values.size > 1 and values.max() > values.min():
        outliers = np.abs(stats.zscore(values)) > 2
    else:
        outliers = np.zeros(values.size, dtype=bool)
    outlier_positions = np.flatnonzero(outliers).tolist()

    hover = [
        f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i + offset}, Outlier: {bool(flag)}"
        for i, (tok, prob, flag) in enumerate(zip(tokens, logprobs, outliers))
    ]
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=list(range(len(logprobs))),
            y=logprobs,
            mode='markers+lines',
            name='Log Prob',
            marker_color='blue',
            customdata=hover,
            hovertemplate='<b>%{customdata}</b><extra></extra>',
        )
    )
    # Hover text is set per trace so each outlier marker shows its own
    # token rather than the text of the first tokens on the page.
    fig.add_trace(
        go.Scatter(
            x=outlier_positions,
            y=[logprobs[i] for i in outlier_positions],
            mode='markers',
            name='Outliers',
            marker_color='red',
            customdata=[hover[i] for i in outlier_positions],
            hovertemplate='<b>%{customdata}</b><extra></extra>',
        )
    )
    fig.update_layout(
        title="Log Probabilities with Outliers",
        xaxis_title="Token Position",
        yaxis_title="Log Probability",
        hovermode="closest",
        clickmode='event+select',
    )
    return fig


def _render_colored_text(tokens, logprobs):
    """HTML paragraph coloring each token from red (low) to green (high)."""
    if not logprobs:
        return "No finite log probabilities to display."
    low, high = min(logprobs), max(logprobs)
    spread = high - low
    spans = []
    for token, logprob in zip(tokens, logprobs):
        norm = 0.5 if spread == 0 else (logprob - low) / spread
        red = int(255 * (1 - norm))  # Red for low confidence
        green = int(255 * norm)      # Green for high confidence
        # Tokens come from user-pasted data; escape before embedding.
        safe_token = html.escape(str(token))
        spans.append(
            f'<span style="color: rgb({red}, {green}, 0); font-weight: bold;">{safe_token}</span>'
        )
    return f"<p>{' '.join(spans)}</p>"


def _render_alternatives(tokens, alternatives, offset):
    """HTML list of the top-3 log probs recorded for each token on the page."""
    if not tokens or not alternatives:
        return ""
    parts = ["<h3>Top 3 Token Log Probabilities (Paginated)</h3><ul>"]
    for i, (token, ranked) in enumerate(zip(tokens, alternatives)):
        parts.append(f"<li>Position {i + offset} (Token: {html.escape(str(token))}):<br>")
        parts.extend(f"{html.escape(str(tok))}: {prob:.4f}<br>" for tok, prob in ranked)
        parts.append("</li>")
    parts.append("</ul>")
    return "".join(parts)

# Gradio interface with interactive layout and pagination
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown(
        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter and pagination to navigate large inputs."
    )

    with gr.Row():
        with gr.Column(scale=1):
            json_input = gr.Textbox(
                label="JSON Input",
                lines=10,
                placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
            )
        with gr.Column(scale=1):
            prob_filter = gr.Slider(minimum=-1e9, maximum=0, value=-1e9, label="Log Probability Filter (≥)")
            page_size = gr.Number(value=50, label="Page Size", precision=0, minimum=10, maximum=1000)
            page = gr.Number(value=0, label="Page Number", precision=0, minimum=0)

    with gr.Row():
        plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
        drops_output = gr.Plot(label="Probability Drops (Click for Details)")

    with gr.Row():
        anomaly_output = gr.Plot(label="Anomaly Detection (Click for Details)")
        table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")

    with gr.Row():
        text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
        alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")

    btn = gr.Button("Visualize")

    # Pagination controls
    with gr.Row():
        prev_btn = gr.Button("Previous Page")
        next_btn = gr.Button("Next Page")
        total_pages_output = gr.Number(label="Total Pages", interactive=False, visible=False)
        current_page_output = gr.Number(label="Current Page", interactive=False, visible=False)

    # Wired after the pagination row exists so the last two return values
    # (total pages, current page) land in the hidden Number components.
    btn.click(
        fn=visualize_logprobs,
        inputs=[json_input, prob_filter, page_size, page],
        outputs=[
            plot_output,
            table_output,
            text_output,
            alt_viz_output,
            drops_output,
            anomaly_output,
            total_pages_output,
            current_page_output,
        ],
    )

    def update_page(json_input, prob_filter, page_size, current_page, action):
        """Compute the page number after a "prev" or "next" click.

        Returns updates for the page-number field and the hidden total-pages
        field. The page is only changed here; the plots are refreshed the
        next time "Visualize" is clicked.
        """
        current_page = int(current_page or 0)
        # Total pages is needed for both actions because it is always
        # written to total_pages_output.
        result = visualize_logprobs(json_input, prob_filter, page_size, 0)
        total_pages = result[6] if len(result) > 6 else 1
        if action == "prev":
            current_page = max(current_page - 1, 0)
        elif action == "next" and current_page < total_pages - 1:
            current_page += 1
        return gr.update(value=current_page), gr.update(value=total_pages)

    # The lambdas append the action, so exactly four component inputs are
    # passed to match update_page's five parameters.
    prev_btn.click(
        fn=lambda *args: update_page(*args, "prev"),
        inputs=[json_input, prob_filter, page_size, page],
        outputs=[page, total_pages_output],
    )

    next_btn.click(
        fn=lambda *args: update_page(*args, "next"),
        inputs=[json_input, prob_filter, page_size, page],
        outputs=[page, total_pages_output],
    )

app.launch()