import matplotlib
matplotlib.use('Agg')

import functools
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


# PNG written by get_plot on every call (the same file is overwritten each time);
# its path is what get_plot returns to the Gradio image component.
FIGURE_PATH: str = "plt.png"
# Resolution, in dots per inch, used when saving that PNG.
FIG_DPI: int = 300


def get_plot(task, gpu, omit_offload):
    """Render the Greedy-vs-Assisted latency bar chart for one task/GPU and save it.

    ``data.csv`` is re-read on every call.

    Args:
        task: Value matched against the ``task`` column of ``data.csv``.
        gpu: Value matched against the ``gpu`` column of ``data.csv``.
        omit_offload: ``"Yes"`` keeps only rows whose ``offload`` column is 0;
            any other value keeps all rows.

    Returns:
        ``FIGURE_PATH``, the path of the PNG that was just (over)written.
    """
    # slice the dataframe according to the inputs
    df = pd.read_csv("data.csv")
    mask = (df["task"] == task) & (df["gpu"] == gpu)
    if omit_offload == "Yes":
        mask &= df["offload"] == 0
    # Take an explicit copy so adding a column below mutates an independent frame
    # instead of a slice of the original (avoids pandas' SettingWithCopyWarning).
    df = df.loc[mask].copy()

    # combine model name and dtype into one label: "<model_name>, <dtype>"
    df["model and dtype"] = df["model_name"].str.cat(df[["dtype"]], sep=", ")

    # fuse the two columns to be compared (original and assisted generation)
    df = df.melt(
        id_vars=["task", "gpu", "model and dtype", "offload"],
        value_vars=["Greedy", "Assisted"],
        var_name="generation_type",
        value_name="generation_time",
    )

    g = sns.catplot(
        data=df,
        kind="bar",
        x="model and dtype",
        y="generation_time",
        hue="generation_type",
        palette={"Greedy": "blue", "Assisted": "orange"},
        alpha=.9,
    )
    ax = g.facet_axis(0, 0)
    # Work on this grid's own figure rather than pyplot's "current figure", and
    # always close it: otherwise every call leaves a figure registered with pyplot.
    fig = ax.figure
    try:
        g.despine(left=True)
        g.set_axis_labels("Model size and dtype", "Latency (ms/token)")
        # Shrink tick labels without freezing tick locations (set_*ticklabels
        # would pin the ticks and can expand the view limits).
        for facet_ax in g.axes.flat:
            facet_ax.tick_params(labelsize=7)
        # g.legend is None if seaborn did not create a legend; guard so that
        # case does not raise AttributeError.
        if g.legend is not None:
            g.legend.set_title("Generation Type")
            plt.setp(g.legend.get_texts(), fontsize=7)  # for legend text

        # Add the number to the top of each bar
        for container in ax.containers:
            ax.bar_label(container, fontsize=7)

        fig.savefig(FIGURE_PATH, dpi=FIG_DPI)
    finally:
        plt.close(fig)
    return FIGURE_PATH


# GPUs benchmarked for most tasks; the last entry is the dropdown's default selection.
_ALL_GPUS = ("3090", "T4", "T4 *2", "A100 (80GB)")


def _add_benchmark_tab(tab_label, task, gpus, description):
    """Add one benchmark tab: GPU/offload selectors, the latency plot, and a description.

    Must be called inside an open ``gr.Tabs()`` context of a ``gr.Blocks`` app.

    Args:
        tab_label: Text shown on the tab header.
        task: Task key passed to ``get_plot`` (matched against the ``task`` column of
            ``data.csv``), so it must be spelled exactly as in the data.
        gpus: GPU names offered in the dropdown; the last one is selected by default.
        description: Markdown rendered below the plot.
    """
    plot_fn = functools.partial(get_plot, task)
    gpu_choices = list(gpus)
    default_gpu = gpu_choices[-1]
    default_omit_offload = "No"
    with gr.TabItem(tab_label):
        with gr.Row():
            with gr.Column():
                gpu_selector = gr.Dropdown(
                    gpu_choices, value=default_gpu, label="GPU", interactive=True
                )
            with gr.Column():
                omit_offload = gr.Radio(
                    ["Yes", "No"],
                    value=default_omit_offload,
                    label="Omit cases with memory offload?",
                    interactive=True,
                )

        # Show plot when the gradio app is initialized, using the same defaults the
        # selectors start with so the initial image always matches the controls.
        plot = gr.Image(value=plot_fn(default_gpu, default_omit_offload))
        gr.Markdown(description)

        # Update plot when any of the inputs change
        plot_inputs = [gpu_selector, omit_offload]
        for component in plot_inputs:
            component.change(fn=plot_fn, inputs=plot_inputs, outputs=plot)


demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Assisted Generation Benchmark
        """
    )

    with gr.Tabs():
        _add_benchmark_tab(
            "OPT: Open",
            "OPT: Open Text Generation",
            _ALL_GPUS,
            """
            ### Assistant Model
            - `facebook/opt-125m`

            ### Model Names:
            - 1.3B: `facebook/opt-1.3b`
            - 6.7B: `facebook/opt-6.7b`
            - 30B: `facebook/opt-30b`
            - 66B: `facebook/opt-66b`

            ### Dataset used as input prompt:
            - C4 (en, validation set)
            """,
        )
        _add_benchmark_tab(
            "OPT: Summ",
            "OPT: Summarization",
            _ALL_GPUS,
            """
            ### Assistant Model
            - `facebook/opt-125m`

            ### Model Names:
            - 1.3B: `facebook/opt-1.3b`
            - 6.7B: `facebook/opt-6.7b`
            - 30B: `facebook/opt-30b`
            - 66B: `facebook/opt-66b`

            ### Dataset used as input prompt:
            - CNN Dailymail (3.0.0, validation set)
            """,
        )
        _add_benchmark_tab(
            "Whisper: ASR",
            # Task key keeps its original spelling: it must match the data file.
            "Whisper: ARS",
            ("3090", "T4"),
            """
            ### Assistant Model
            - `openai/whisper-tiny`

            ### Model Names:
            - large-v2: `openai/whisper-large-v2`

            ### Dataset used as input prompt:
            - Librispeech ASR (clean, validation set)
            """,
        )
        _add_benchmark_tab(
            "CodeGen: Code",
            "CodeGen: Code Generation",
            _ALL_GPUS,
            """
            ### Assistant Model
            - `Salesforce/codegen-350M-mono`

            ### Model Names:
            - 2B: `Salesforce/codegen-2B-mono`
            - 6B: `Salesforce/codegen-6B-mono`
            - 16B: `Salesforce/codegen-16B-mono`

            ### Dataset used as input prompt:
            - The Stack (python)
            """,
        )
        _add_benchmark_tab(
            "Flan-T5: Summ",
            "Flan-T5: Summarization",
            _ALL_GPUS,
            """
            ### Assistant Model
            - `google/flan-t5-small`

            ### Model Names:
            - large: `google/flan-t5-large`
            - xl: `google/flan-t5-xl`
            - xxl: `google/flan-t5-xxl`
            - ul2: `google/flan-ul2`

            ### Dataset used as input prompt:
            - CNN Dailymail (3.0.0, validation set)
            """,
        )
        with gr.TabItem("Benchmark Info"):
            gr.Dataframe(
                headers=["Parameter", "Value"],
                value=[
                    ["Transformers Version", "4.29dev0"],
                    ["Pytorch Version", "2.0.0"],
                    ["OS", "22.04 LTS (3090) / Debian 10 (other GPUs)"],
                    ["CUDA", "11.8 (3090) / 11.3 (other GPUs)"],
                    ["Number of input samples", "20-100 (depending on the model size)"],
                    ["Is there code to reproduce?", "Yes -- https://github.com/gante/huggingface-demos/tree/main/experiments/faster_generation"],
                ],
            )

demo.launch()