#!/usr/bin/env python

import gradio as gr
import polars as pl

from app_mcp import demo as demo_mcp
from search import search
from table import df_orig

# Markdown heading rendered at the top of the page.
DESCRIPTION = "# CVPR 2025"

# TODO: remove this once https://github.com/gradio-app/gradio/issues/10916 https://github.com/gradio-app/gradio/issues/11001 https://github.com/gradio-app/gradio/issues/11002 are fixed  # noqa: TD002, FIX002
# Warning rendered directly above the results table about the Gradio sorting bugs linked above.
NOTE = (
    "Note: Sorting by upvotes or comments may not work correctly "
    "due to a known bug in Gradio.\n"
)

# Source table for the UI.
# - Keeps only the columns the app needs. "paper_id" is the key that search results are joined on,
#   and "abstract" is carried along as well.
# - Renames the display columns to their user-facing headers.
df_main = (
    df_orig.select(
        "title",
        "authors_str",
        "cvf_md",
        "paper_page_md",
        "upvotes",
        "num_comments",
        "project_page_md",
        "github_md",
        "Spaces",
        "Models",
        "Datasets",
        "claimed",
        "abstract",
        "paper_id",
    )
    # TODO: Fix this once https://github.com/gradio-app/gradio/issues/10916 is fixed # noqa: FIX002, TD002
    # Normalize the count columns to non-null Int64 (missing counts become 0) so they match the
    # "number" datatype the table declares for them. pl.col with several names keeps each
    # column's own name, so no alias is needed.
    .with_columns(pl.col("upvotes", "num_comments").fill_null(0).cast(pl.Int64))
    .rename(
        {
            "title": "Title",
            "authors_str": "Authors",
            "cvf_md": "CVF",
            "paper_page_md": "Paper page",
            "upvotes": "👍",
            "num_comments": "💬",
            "project_page_md": "Project page",
            "github_md": "GitHub",
        }
    )
)

# Display metadata for every column the UI can show, keyed by table header:
# header -> (Gradio datatype, column width; None leaves the width unset).
# Insertion order is significant: selected columns are displayed in this order.
COLUMN_INFO = {
    header: (datatype, width)
    for header, datatype, width in (
        ("Title", "str", "40%"),
        ("Authors", "str", "20%"),
        ("Paper page", "markdown", "135px"),
        ("👍", "number", "50px"),
        ("💬", "number", "50px"),
        ("CVF", "markdown", None),
        ("Project page", "markdown", None),
        ("GitHub", "markdown", None),
        ("Spaces", "markdown", None),
        ("Models", "markdown", None),
        ("Datasets", "markdown", None),
        ("claimed", "markdown", None),
    )
}


# Columns pre-selected when the page opens ("Title" is always shown regardless of selection).
DEFAULT_COLUMNS = [
    "Title", "Paper page", "👍", "💬", "CVF",
    "Project page", "GitHub", "Spaces", "Models", "Datasets",
]


def update_num_papers(df: pl.DataFrame) -> str:
    """Summarize how many papers are shown out of the full ``df_main`` table.

    Args:
        df: The currently displayed table.

    Returns:
        ``"<shown> / <total>"``. When ``df`` has a ``claimed`` column, the summary is followed
        by ``" (<k> claimed)"``, where ``k`` counts rows whose ``claimed`` value contains ``✅``.
    """
    summary = f"{len(df)} / {len(df_main)}"
    if "claimed" not in df.columns:
        return summary
    num_claimed = df.select(pl.col("claimed").str.contains("✅").sum()).item()
    return f"{summary} ({num_claimed} claimed)"


def update_df(
    search_query: str,
    candidate_pool_size: int,
    num_results: int,
    column_names: list[str],
) -> gr.Dataframe:
    """Build the results table for the current search query and column selection.

    Args:
        search_query: Free-text query. When empty, every paper in ``df_main`` is listed.
        candidate_pool_size: Candidate pool size forwarded to ``search``.
        num_results: Result count forwarded to ``search``. Must not exceed ``candidate_pool_size``.
        column_names: Columns chosen by the user. ``"Title"`` is always shown in addition.

    Returns:
        A ``gr.Dataframe`` holding the chosen columns in ``COLUMN_INFO`` order, with their
        matching datatypes and widths.

    Raises:
        gr.Error: If ``num_results`` is greater than ``candidate_pool_size``.
    """
    if num_results > candidate_pool_size:
        raise gr.Error("Number of results must be less than or equal to candidate pool size", print_exception=False)

    table = df_main.clone()
    if search_query:
        hits = search(search_query, candidate_pool_size, num_results)
        if hits:
            # Keep only matched papers, ordered by the search score (highest first).
            # The score column is dropped before display.
            table = (
                pl.DataFrame(hits)
                .join(table, on="paper_id", how="inner")
                .sort("ce_score", descending=True)
                .drop("ce_score")
            )
        else:
            # Nothing matched: show an empty table with the usual columns.
            table = table.head(0)

    requested = {"Title", *column_names}
    visible = [header for header in COLUMN_INFO if header in requested]
    return gr.Dataframe(
        value=table.select(visible),
        datatype=[COLUMN_INFO[header][0] for header in visible],
        column_widths=[COLUMN_INFO[header][1] for header in visible],
    )


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Search...")
    with gr.Accordion(label="Advanced Search Options", open=False) as advanced_search_options:
        with gr.Row():
            candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=600, step=1, value=200)
            num_results = gr.Slider(label="Number of Results", minimum=1, maximum=400, step=1, value=100)

    # "Title" is excluded from the choices because update_df always shows it.
    column_names = gr.CheckboxGroup(
        label="Columns",
        choices=[col for col in COLUMN_INFO if col != "Title"],
        value=[col for col in DEFAULT_COLUMNS if col != "Title"],
    )

    num_papers = gr.Textbox(label="Number of papers", value=update_num_papers(df_orig), interactive=False)

    gr.Markdown(NOTE)
    # Initial contents until demo.load runs update_df.
    # - `datatype` must be a list of datatype strings, one per column.
    # - The value is restricted to the COLUMN_INFO columns (df_main also carries "abstract"
    #   and "paper_id") so value, datatype and column_widths line up one-to-one.
    df = gr.Dataframe(
        value=df_main.select(list(COLUMN_INFO)),
        datatype=[datatype for datatype, _ in COLUMN_INFO.values()],
        type="polars",
        row_count=(0, "dynamic"),
        show_row_numbers=True,
        interactive=False,
        max_height=1000,
        elem_id="table",
        column_widths=[width for _, width in COLUMN_INFO.values()],
    )

    inputs = [
        search_query,
        candidate_pool_size,
        num_results,
        column_names,
    ]
    # Rebuild the table when a search is submitted or the column selection changes.
    # Then refresh the paper count from the table that was just rendered.
    gr.on(
        triggers=[
            search_query.submit,
            column_names.input,
        ],
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )
    # Populate the table (and the count) once on page load, using the controls' initial values.
    demo.load(
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )

    # The MCP demo is rendered inside a hidden row.
    # NOTE(review): presumably this registers its endpoints with this app (for the MCP server
    # enabled at launch) without showing them in the UI — confirm.
    with gr.Row(visible=False):
        demo_mcp.render()


if __name__ == "__main__":
    # Start the Gradio app; mcp_server=True also exposes it as an MCP server.
    demo.launch(mcp_server=True)