import argparse
import gradio as gr
import os
from PIL import Image
import spaces

from eagle_vl.serve.frontend import reload_javascript
from eagle_vl.serve.utils import (
    configure_logger,
    pil_to_base64,
    parse_ref_bbox,
    strip_stop_words,
    is_variable_assigned,
)
from eagle_vl.serve.gradio_utils import (
    cancel_outputing,
    delete_last_conversation,
    reset_state,
    reset_textbox,
    transfer_input,
    wrap_gen_fn,
)
from eagle_vl.serve.chat_utils import (
    generate_prompt_with_history,
    convert_conversation_to_prompts,
    to_gradio_chatbot,
    to_gradio_history,
)
from eagle_vl.serve.inference import eagle_vl_generate, load_model
from eagle_vl.serve.examples import get_examples

TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Eagle2-VL </h1>"""
DESCRIPTION_TOP = """<a href="https://github.com/NVlabs/EAGLE" target="_blank">Eagle2-VL</a> is a multi-modal LLM that can understand text, images and videos, and generate text"""
DESCRIPTION = """"""
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
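# Cache of loaded (model, processor) pairs, keyed by model name, so weights load only once.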
DEPLOY_MODELS = dict()
logger = configure_logger()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Eagle2-8B")
    parser.add_argument(
        "--local-path",
        type=str,
        default="",
        help="huggingface ckpt, optional",
    )
    parser.add_argument("--ip", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    return parser.parse_args()


def fetch_model(model_name: str):
    global args, DEPLOY_MODELS

    # Prefer an explicit local checkpoint; otherwise resolve from the NVEagle hub namespace.
    if args.local_path:
        model_path = args.local_path
    else:
        model_path = f"NVEagle/{model_name}"

    if model_name in DEPLOY_MODELS:
        model_info = DEPLOY_MODELS[model_name]
        logger.info(f"{model_name} has been loaded.")
    else:
        logger.info(f"{model_name} is loading...")
        DEPLOY_MODELS[model_name] = load_model(model_path)
        logger.info(f"Loaded {model_name} successfully.")
        model_info = DEPLOY_MODELS[model_name]

    return model_info


def preview_images(files) -> list[str]:
    if files is None:
        return []

    # Gradio file objects expose their temp-file path via .name
    return [file.name for file in files]


def get_prompt(conversation) -> str:
    """
    Return the formatted system prompt for the conversation.
    """
    system_prompt = conversation.system_template.format(system_message=conversation.system_message)
    return system_prompt


@wrap_gen_fn
@spaces.GPU(duration=180)
def predict(
    text,
    images,
    chatbot,
    history,
    top_p,
    temperature,
    max_generate_length,
    max_context_length_tokens,
    video_nframes,
    chunk_size: int = 512,
):
    """
    Predict the response for the input text and images.
    Args:
        text (str): The input text.
        images (list[PIL.Image.Image]): The input images.
        chatbot (list): The chatbot.
        history (list): The history.
        top_p (float): The top-p value.
        temperature (float): The temperature value.
        repetition_penalty (float): The repetition penalty value.
        max_generate_length (int): The max length tokens.
        max_context_length_tokens (int): The max context length tokens.
        chunk_size (int): The chunk size.
    """


    if images is None:
        images = []

    # Load uploaded files: keep video paths as plain strings, decode images to RGB PIL images.
    pil_images = []
    for img_or_file in images:
        try:
            logger.info(f"img_or_file: {img_or_file}")
            if isinstance(img_or_file, Image.Image):
                # already a PIL image
                pil_images.append(img_or_file)
            elif isinstance(img_or_file, str):
                if img_or_file.endswith((".mp4", ".mov", ".avi", ".webm")):
                    # pass video paths through; frames are sampled at generation time
                    pil_images.append(img_or_file)
                else:
                    image = Image.open(img_or_file).convert("RGB")
                    pil_images.append(image)
        except Exception as e:
            logger.error(f"Error loading image: {e}")

    print("running the prediction function")
    try:
        logger.info("fetching model")
        model, processor = fetch_model(args.model)
        logger.info("model fetched")
        if text == "":
            yield chatbot, history, "Empty context."
            return
    except KeyError:
        logger.info("no model found")
        yield [[text, "No Model Found"]], [], "No Model Found"
        return
    
    # generate prompt
    conversation = generate_prompt_with_history(
        text,
        pil_images,
        history,
        processor,
        max_length=max_context_length_tokens,
    )
    all_conv, last_image = convert_conversation_to_prompts(conversation)
    stop_words = conversation.stop_str
    gradio_chatbot_output = to_gradio_chatbot(conversation)

    full_response = ""
    for x in eagle_vl_generate(
        conversations=all_conv,
        model=model,
        processor=processor,
        stop_words=stop_words,
        max_length=max_generate_length,
        temperature=temperature,
        top_p=top_p,
        video_nframes=video_nframes,
    ):
        full_response += x
        response = strip_stop_words(full_response, stop_words)
        conversation.update_last_message(response)
        gradio_chatbot_output[-1][1] = response

        yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."

    # if last_image is not None:
    #     vg_image = parse_ref_bbox(response, last_image)
    #     if vg_image is not None:
    #         vg_base64 = pil_to_base64(vg_image, "vg", max_size=800, min_size=400)
    #         gradio_chatbot_output[-1][1] += vg_base64
    #         yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."

    logger.info("flushed result to gradio")

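    # `x` is only bound if the streaming loop above produced at least one chunk.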
    if is_variable_assigned("x"):
        print(
            f"temperature: {temperature}, "
            f"top_p: {top_p}, "
            f"max_generate_length: {max_generate_length}"
        )

    yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"


def retry(
    text,
    images,
    chatbot,
    history,
    top_p,
    temperature,
    max_generate_length,
    max_context_length_tokens,
    video_nframes,
    chunk_size: int = 512,
):
    """
    Retry the response for the input text and images.
    """
    if len(history) == 0:
        yield (chatbot, history, "Empty context")
        return

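    # Drop the last bot turn, then re-submit the most recent user message.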
    chatbot.pop()
    history.pop()
    text = history.pop()[-1]
    if isinstance(text, tuple):
        text, _ = text

    yield from predict(
        text,
        images,
        chatbot,
        history,
        top_p,
        temperature,
        max_generate_length,
        max_context_length_tokens,
        video_nframes,
        chunk_size,
    )


def build_demo(args: argparse.Namespace) -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
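        # Session state: chat history plus the staged text/images handed off by transfer_input.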
        history = gr.State([])
        input_text = gr.State()
        input_images = gr.State()

        with gr.Row():
            gr.HTML(TITLE)
            status_display = gr.Markdown("Success", elem_id="status_display")
        gr.Markdown(DESCRIPTION_TOP)

        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="Eagle2-VL-8B-chatbot",
                        show_share_button=True,
                        bubble_full_width=False,
                        height=600,
                    )
                with gr.Row():
                    with gr.Column(scale=4):
                        text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
                    with gr.Column(min_width=70):
                        submit_btn = gr.Button("Send")
                    with gr.Column(min_width=70):
                        cancel_btn = gr.Button("Stop")
                with gr.Row():
                    empty_btn = gr.Button("🧹 New Conversation")
                    retry_btn = gr.Button("🔄 Regenerate")
                    del_last_btn = gr.Button("🗑️ Remove Last Turn")

            with gr.Column():
                # Upload panel: accepts images or videos, previewed in the gallery below.
                gr.Markdown("Note: you can upload images or videos!")
                upload_images = gr.Files(file_types=["image", "video"], show_label=True)
                gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
                upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
                
                # Parameter-setting tab for controlling the generation parameters
                with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p")
                    temperature = gr.Slider(
                        minimum=0, maximum=1.0, value=0.8, step=0.1, interactive=True, label="Temperature"
                    )
                    max_generate_length = gr.Slider(
                        minimum=512, maximum=8192, value=4096, step=64, interactive=True, label="Max Generate Length"
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=512, maximum=65536, value=16384, step=64, interactive=True, label="Max Context Length Tokens"
                    )
                    video_nframes = gr.Slider(
                        minimum=1, maximum=128, value=16, step=1, interactive=True, label="Video Nframes"
                    )
                    show_images = gr.HTML(visible=False)
                gr.Markdown("This demo is based on `moonshotai/Kimi-VL-A3B-Thinking` & `deepseek-ai/deepseek-vl2-small` and extends it by adding support for video input.")

        gr.Examples(
            examples=get_examples(ROOT_DIR),
            inputs=[upload_images, show_images, text_box],
        )
        gr.Markdown()

        input_widgets = [
            input_text,
            input_images,
            chatbot,
            history,
            top_p,
            temperature,
            max_generate_length,
            max_context_length_tokens,
            video_nframes,
        ]
        output_widgets = [chatbot, history, status_display]

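        # transfer_input copies the textbox/file contents into the gr.State holders (see its outputs list) so predict can consume them.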
        transfer_input_args = dict(
            fn=transfer_input,
            inputs=[text_box, upload_images],
            outputs=[input_text, input_images, text_box, upload_images, submit_btn],
            show_progress=True,
        )

        predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])

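        # Enter and the Send button share the same transfer-then-predict chain; keeping the event handles lets the Stop button cancel them.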
        predict_events = [
            text_box.submit(**transfer_input_args).then(**predict_args),
            submit_btn.click(**transfer_input_args).then(**predict_args),
        ]

        empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
        empty_btn.click(**reset_args)
        retry_btn.click(**retry_args)
        del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
        cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)

    demo.title = "Eagle2-VL-8B Chatbot"
    return demo


def main(args: argparse.Namespace):
    demo = build_demo(args)
    reload_javascript()

    favicon_path = os.path.join(ROOT_DIR, "eagle_vl", "serve", "assets", "favicon.ico")
    demo.queue().launch(
        favicon_path=favicon_path,
        server_name=args.ip,
        server_port=args.port,
    )


if __name__ == "__main__":
    args = parse_args()
    main(args)