import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

# Fetch the demo repository (package code, example data and the fine-tuned
# checkpoint used below) from the Hugging Face Hub. `git lfs install` runs
# first so LFS-tracked files (presumably the weights) are downloaded rather
# than left as pointer stubs; its exit status is deliberately not checked.
_REPO_URL = "https://huggingface.co/Pluto0616/L2_InternVL"
_REPO_DIR = "L2_InternVL"

os.system('git lfs install')
# Clone only when the checkout is missing: re-running the script would
# otherwise fail with "destination path already exists". A failed clone is
# reported here instead of surfacing later as a confusing ImportError.
if not os.path.isdir(_REPO_DIR):
    if os.system(f"git clone {_REPO_URL}") != 0:
        raise RuntimeError(f"Failed to clone {_REPO_URL} into ./{_REPO_DIR}")

from L2_InternVL.utils import load_json, init_logger
from L2_InternVL.demo import ConversationalAgent, CustomTheme

# All demo paths are relative to the working directory holding the clone.
_REPO_ROOT = "./L2_InternVL"

# JSON file listing the example food images for the shortcut panel.
FOOD_EXAMPLES = f"{_REPO_ROOT}/demo/food_for_demo.json"
# Checkpoint loaded by the agent (a LoRA fine-tune, per the directory name).
# To run the base model instead, point this at
# /root/share/new_models/OpenGVLab/InternVL2-2B.
MODEL_PATH = f"{_REPO_ROOT}/work_dirs/internvl_v2_internlm2_2b_lora_finetune_food/lr35_ep10"
# Logger directory and the agent's outputs_dir.
OUTPUT_PATH = f"{_REPO_ROOT}/outputs"

def setup_seeds(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed`` (the latter seeds every device). Defaults
            to 42, the value this function always used.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Give up cuDNN autotuning speed in exchange for reproducible kernel
    # selection across runs.
    cudnn.benchmark = False
    cudnn.deterministic = True


# Layout of the "食物快捷栏" (food shortcut) panel: each inner tuple is one
# gr.Row; each string is both a key of the food-examples JSON and the label
# of the gr.Examples gallery built from it.
_FOOD_CATEGORY_ROWS = (
    ("新疆菜", "川菜(四川,重庆)", "西北菜 (陕西,甘肃等地)"),
    ("黔菜 (贵州)", "苏菜(江苏)", "粤菜(广东等地)"),
    ("湘菜(湖南)", "闽菜(福建)", "浙菜(浙江)"),
    ("东北菜 (黑龙江等地)",),
)


def _build_food_shortcuts(food_examples, image):
    """Render the per-cuisine example galleries that populate ``image``.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        food_examples: Mapping loaded from FOOD_EXAMPLES; each category key
            maps to the ``examples`` list for one gallery (presumably image
            paths -- TODO confirm against the JSON).
        image: The ``gr.Image`` component the examples fill in.
    """
    with gr.Row():
        gr.Markdown("### 食物快捷栏")
    for categories in _FOOD_CATEGORY_ROWS:
        with gr.Row():
            for category in categories:
                gr.Examples(examples=food_examples[category], inputs=image, label=category)


def main():
    """Build and launch the InternVL2 food chat demo.

    The function does the following, in order:

    - seeds the RNGs;
    - initialises logging under OUTPUT_PATH;
    - loads the food examples;
    - creates the ConversationalAgent for MODEL_PATH;
    - wires up the Gradio UI;
    - serves it on 127.0.0.1:1096 with a public share link.

    ``launch()`` blocks the main thread, so this returns only when the server
    stops.
    """
    setup_seeds()
    # logging
    init_logger(OUTPUT_PATH)
    # Per-cuisine example images for the shortcut panel.
    food_examples = load_json(FOOD_EXAMPLES)

    agent = ConversationalAgent(model_path=MODEL_PATH,
                                outputs_dir=OUTPUT_PATH)

    theme = CustomTheme()

    # The comma between the entries matters: without it Python concatenates
    # adjacent string literals into a single title.
    titles = [
        """<center><B><font face="Comic Sans MS" size=10>书生大模型实战营</font></B></center>""",  # Kalam:wght@700
        """<center><B><font face="Courier" size=5>「进阶岛」InternVL 多模态模型部署微调实践</font></B></center>""",
    ]

    language = """Language: 中文 and English"""
    with gr.Blocks(theme=theme) as demo_chatbot:
        for title in titles:
            gr.Markdown(title)
        gr.Markdown(language)

        with gr.Row():
            with gr.Column(scale=3):
                # Only "Start Chat" is usable at first; the outputs of
                # agent.start_chat (wired below) include the other controls.
                start_btn = gr.Button("Start Chat", variant="primary", interactive=True)
                clear_btn = gr.Button("Clear Context", interactive=False)
                image = gr.Image(type="pil", interactive=False)
                upload_btn = gr.Button("🖼️ Upload Image", interactive=False)

                with gr.Accordion("Generation Settings"):
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.1,
                                      value=0.8,
                                      interactive=True,
                                      label='top-p value',
                                      visible=True)

                    temperature = gr.Slider(minimum=0, maximum=1.5, step=0.1,
                                            value=0.8,
                                            interactive=True,
                                            label='temperature',
                                            visible=True)

            with gr.Column(scale=7):
                chat_state = gr.State()
                # NOTE(review): avatars resolve to <this script's dir>/demo/,
                # whereas the food-examples JSON lives under ./L2_InternVL/demo
                # -- confirm these image files exist at this location.
                script_dir = os.path.dirname(__file__)
                avatar_images = (os.path.join(script_dir, 'demo/user.png'),
                                 os.path.join(script_dir, "demo/bot.png"))
                chatbot = gr.Chatbot(label='InternVL2', height=800, avatar_images=avatar_images)
                text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)
                gr.Markdown("### 输入示例")

                # Re-enable the textbox whenever its value changes --
                # presumably so that clicking an example prompt below makes it
                # editable even before "Start Chat" is pressed.
                def on_text_change(text):
                    return gr.update(interactive=True)
                text_input.change(fn=on_text_change, inputs=text_input, outputs=text_input)
                gr.Examples(
                    examples=[["图片中的食物通常属于哪个菜系?"],
                              ["如果让你简单形容一下品尝图片中的食物的滋味,你会描述它"],
                              ["去哪个地方游玩时应该品尝当地的特色美食图片中的食物?"],
                              ["食用图片中的食物时,一般它上菜或摆盘时的特点是?"]],
                    inputs=[text_input]
                )

        _build_food_shortcuts(food_examples, image)

        start_btn.click(agent.start_chat, [chat_state], [text_input, start_btn, clear_btn, image, upload_btn, chat_state])
        clear_btn.click(agent.restart_chat, [chat_state], [chatbot, text_input, start_btn, clear_btn, image, upload_btn, chat_state], queue=False)
        upload_btn.click(agent.upload_image, [image, chatbot, chat_state], [image, chatbot, chat_state])
        text_input.submit(
            agent.respond,
            inputs=[text_input, image, chatbot, top_p, temperature, chat_state],
            outputs=[text_input, image, chatbot, chat_state]
        )

    # queue() must be configured before launch(): launch() blocks the main
    # thread, so a queue() call placed after it would only run at shutdown.
    demo_chatbot.queue()
    # NOTE(review): with share=True, allowed_paths=['./'] lets the public link
    # serve any file under the working directory -- consider narrowing it.
    demo_chatbot.launch(share=True, server_name="127.0.0.1", server_port=1096, allowed_paths=['./'])


if __name__ == "__main__":
    main()