import pickle

import pandas as pd
import gradio as gr
import plotly.express as px

from utils import (
    KEY_TO_CATEGORY_NAME,
    PROPRIETARY_LICENSES,
    CAT_NAME_TO_EXPLANATION,
    download_latest_data_from_space,
    get_constants,
)

###################
### Load Data
###################

# gather ELO data (pickled results from the lmsys/chatbot-arena-leaderboard repo)
latest_elo_file_local = download_latest_data_from_space(
    repo_id="lmsys/chatbot-arena-leaderboard", file_type="pkl"
)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# The file is fetched from an external repository, so this is only acceptable
# as long as that repository is trusted.
with open(latest_elo_file_local, "rb") as fin:
    elo_results = pickle.load(fin)

# One leaderboard table per category, keyed by the category's display name.
# Categories listed in KEY_TO_CATEGORY_NAME but absent from the results are skipped.
arena_dfs = {
    KEY_TO_CATEGORY_NAME[key]: elo_results[key]["leaderboard_table_df"]
    for key in KEY_TO_CATEGORY_NAME
    if key in elo_results
}

# gather model metadata (leaderboard CSV from the same repo); it supplies the
# "key" column used for joining plus columns such as Model, Organization,
# License and Link that the plot uses
latest_leaderboard_file_local = download_latest_data_from_space(
    repo_id="lmsys/chatbot-arena-leaderboard", file_type="csv"
)
leaderboard_df = pd.read_csv(latest_leaderboard_file_local)

###################
### Prepare Data
###################

# Merge each category's ELO table with the model metadata, order by rating
# (highest first), then attach release dates.
release_date_mapping = pd.read_json("release_date_mapping.json", orient="records")
release_dates = release_date_mapping[["key", "Release Date"]]

merged_dfs = {
    category: (
        # the ELO table's index is matched against the metadata's "key" column
        pd.merge(arena_df, leaderboard_df, left_index=True, right_on="key")
        .sort_values("rating", ascending=False)
        .reset_index(drop=True)
        # inner join (pandas default): models with no entry in
        # release_date_mapping.json are dropped from the plot data
        .merge(release_dates, on="key")
    )
    for category, arena_df in arena_dfs.items()
}


# format dataframes
def format_data(df):
    """Return a plot-ready copy of a merged leaderboard DataFrame.

    The input frame is not modified. On the copy:

    - ``License`` is collapsed to ``"Proprietary LLM"`` when the value is in
      ``PROPRIETARY_LICENSES`` and ``"Open LLM"`` otherwise (including missing
      values).
    - ``Release Date`` is parsed with ``pd.to_datetime``.
    - ``Month-Year`` is added as the monthly period of the release date. It is
      used to group models per month when plotting.
    - ``rating`` is rounded to zero decimals.

    Args:
        df: Frame with at least ``License``, ``Release Date`` and ``rating``
            columns.

    Returns:
        The formatted copy with a fresh 0..n-1 index.
    """
    # Work on a copy so the caller's frame is never mutated in place.
    df = df.copy()
    df["License"] = df["License"].apply(
        lambda lic: "Proprietary LLM" if lic in PROPRIETARY_LICENSES else "Open LLM"
    )
    df["Release Date"] = pd.to_datetime(df["Release Date"])
    df["Month-Year"] = df["Release Date"].dt.to_period("M")
    df["rating"] = df["rating"].round()
    return df.reset_index(drop=True)


# Apply display formatting to every category's frame.
merged_dfs = {category: format_data(frame) for category, frame in merged_dfs.items()}


# Bounds/defaults for the UI sliders, computed by get_constants from the
# formatted frames.
min_elo_score, max_elo_score, upper_models_per_month = get_constants(merged_dfs)

# Keep only the text before the first space of the timestamp (the date part).
date_updated = elo_results["full"]["last_updated_datetime"].partition(" ")[0]


###################
### Plot Data
###################


def get_data_split(dfs, set_name):
    """Return an independent deep copy of ``dfs[set_name]`` re-indexed from 0.

    Callers can freely modify the result without touching the shared frames.
    """
    return dfs[set_name].copy(deep=True).reset_index(drop=True)


def build_plot(min_score, max_models_per_month, toggle_annotations, set_selector):
    """Build the release-date vs. Arena ELO scatter plot for one category.

    Args:
        min_score: Models whose ``rating`` is below this value are dropped.
        max_models_per_month: Maximum number of highest-rated models kept for
            each (``Month-Year``, ``License``) pair. It is coerced to ``int`` so
            a float value (e.g. from a slider) is accepted.
        toggle_annotations: When truthy, the best-rated model of each
            (``Month-Year``, ``License``) pair is labelled on the plot.
        set_selector: Category name used as the key into ``merged_dfs``.

    Returns:
        The figure created by ``px.scatter``, coloured by ``License`` and
        drawn with ``trendline="ols"``.
    """
    group_cols = ["Month-Year", "License"]
    top_n = int(max_models_per_month)

    df = get_data_split(merged_dfs, set_name=set_selector)

    # filter data
    filtered_df = df[df["rating"] >= min_score]

    # Keep the top-N rated models per (month, license) group.
    # GroupBy.apply is avoided on purpose. Its behaviour of passing the grouping
    # columns to the function is deprecated (pandas 2.2), and a future pandas
    # drops them from the result, which would lose "License"/"Month-Year" here.
    # The descending stable sort keeps ties in original order, like
    # nlargest(keep="first"). The second stable sort restores the group-ordered
    # row layout the previous apply produced. Plotly express assigns
    # colours/legend entries by first appearance, so row order matters.
    filtered_df = (
        filtered_df.sort_values("rating", ascending=False, kind="stable")
        .groupby(group_cols)
        .head(top_n)
        .sort_values(group_cols, kind="stable")
        .reset_index(drop=True)
    )

    fig = px.scatter(
        filtered_df,
        x="Release Date",
        y="rating",
        color="License",
        hover_name="Model",
        hover_data=["Organization", "License", "Link"],
        trendline="ols",
        title=f"Open vs Proprietary LLMs by LMSYS Arena ELO Score (as of {date_updated})",
        labels={"rating": "Arena ELO", "Release Date": "Release Date"},
        height=800,
        template="seaborn",
    )

    fig.update_traces(marker=dict(size=10, opacity=0.6))

    if toggle_annotations:
        # get the points to annotate (only the highest rated model per month per license)
        idx_to_annotate = filtered_df.groupby(group_cols)["rating"].idxmax()
        points_to_annotate_df = filtered_df.loc[idx_to_annotate]

        for _, row in points_to_annotate_df.iterrows():
            fig.add_annotation(
                x=row["Release Date"],
                y=row["rating"],
                text=row["Model"],
                showarrow=True,
                arrowhead=0,
            )

    return fig


with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.sky,
        secondary_hue=gr.themes.colors.green,
        font=[
            gr.themes.GoogleFont("Open Sans"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
    )
) as demo:
    gr.Markdown(
        """
        <div style="text-align: center; max-width: 650px; margin: auto;">
            <h1 style="font-weight: 900; margin-top: 5px;">🔬 Progress Tracker: Open vs. Proprietary LLMs
            </h1>
            <p style="text-align: left; margin-top: 10px; margin-bottom: 10px; line-height: 20px;">
            This app visualizes the progress of proprietary and open-source LLMs in the LMSYS Arena ELO leaderboard. The idea is inspired by <a href="https://www.linkedin.com/posts/maxime-labonne_arena-elo-graph-updated-with-new-models-activity-7187062633735368705-u2jB?utm_source=share&utm_medium=member_desktop">this great work</a> from <a href="https://huggingface.co/mlabonne/">Maxime Labonne</a>.
            </p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column():
            set_selector = gr.Dropdown(
                choices=list(CAT_NAME_TO_EXPLANATION.keys()),
                label="Select Category",
                value="Overall",
                info="Select the category to visualize",
            )
            toggle_annotations = gr.Radio(
                choices=[True, False],
                label="Overlay Best Model Name",
                value=True,
                info="Toggle to overlay the name of the best model per month per license",
            )
        with gr.Column():
            min_score = gr.Slider(
                minimum=min_elo_score,
                maximum=max_elo_score,
                value=(max_elo_score - min_elo_score) * 0.3 + min_elo_score,
                step=50,
                label="Minimum ELO Score",
                info="Filter out low scoring models",
            )
            max_models_per_month = gr.Slider(
                value=upper_models_per_month - 2,
                minimum=1,
                maximum=upper_models_per_month,
                step=1,
                label="Max Models per Month (per License)",
                info="Limit to N best models per month per license to reduce clutter",
            )

    # Show plot
    plot = gr.Plot()

    # Arguments of build_plot, in its positional order. Every event shares this
    # one list so the wiring cannot drift between triggers.
    plot_inputs = [min_score, max_models_per_month, toggle_annotations, set_selector]

    # Render once when the page loads, then re-render whenever any control changes.
    demo.load(fn=build_plot, inputs=plot_inputs, outputs=plot)
    for control in plot_inputs:
        control.change(fn=build_plot, inputs=plot_inputs, outputs=plot)

demo.launch()