import os
import re
import subprocess
import sys
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria

from modeling_llava_qwen2 import LlavaQwen2ForCausalLM
# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE presumably
# keeps pip from compiling CUDA kernels on a machine without a GPU — confirm
# against flash-attn's setup.py.
# pip runs through the current interpreter so the package lands in the same
# environment as this app. The environment is *extended*, not replaced: a bare
# env dict would drop PATH, HOME and virtualenv variables for the child process.
# Like the original, a failed install is not treated as fatal (no check=True).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

# Hugging Face Hub repository that provides both the tokenizer and the weights.
MODEL_ID = 'qnguyen3/nanoLLaVA-1.5'

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True)

model = LlavaQwen2ForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map='auto')

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once every sequence in the batch hits a stop keyword.

    Each keyword is detected two ways:

    * the trailing output token ids equal the keyword's token ids, or
    * the keyword string appears in the decoded text of the most recent
      (at most ``max_keyword_len``) tokens beyond the prompt length.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: Stop strings, e.g. ``['<|im_end|>']``.
            tokenizer: Tokenizer used to encode the keywords and decode output.
            input_ids: Prompt ids; only ``input_ids.shape[1]`` is recorded, as
                the prompt length used to locate newly generated tokens.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS the tokenizer may prepend to the keyword.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids, dtype=torch.long))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the single row ``output_ids`` (shape ``(1, seq)``) should stop."""
        # Move the keyword tensors to the output device once, not on every step.
        if self.keyword_ids and self.keyword_ids[0].device != output_ids.device:
            self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]

        for keyword_id in self.keyword_ids:
            keyword_len = keyword_id.shape[0]
            # An empty keyword would become a ``[-0:]`` slice (the whole row);
            # a row shorter than the keyword cannot end with it.
            if keyword_len == 0 or output_ids.shape[1] < keyword_len:
                continue
            if torch.equal(output_ids[0, -keyword_len:], keyword_id):
                return True

        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        # Nothing beyond the prompt length yet. Without this guard,
        # ``output_ids[:, -0:]`` decodes the *entire* sequence (prompt included)
        # and could stop on a keyword that only occurs in the prompt.
        if offset <= 0:
            return False
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every row of the batch should stop."""
        return all(self.call_for_batch(row.unsqueeze(0), scores) for row in output_ids)


def _file_path(file):
    """Return the filesystem path of one entry of a multimodal message's ``files``.

    Dict entries are read via their ``"path"`` key, as the original code did.
    Plain path strings are also accepted — NOTE(review): which form arrives
    depends on the installed Gradio version.
    """
    return file if isinstance(file, str) else file["path"]


def _find_image(message, history):
    """Locate the image the model should see.

    Returns ``(image_path, image_turn)``:

    * An image attached to the current message wins; ``image_turn`` is None.
    * Otherwise the *last* history entry whose user side is a file tuple is
      used, and ``image_turn`` is its index.
    * ``(None, None)`` means no image was uploaded at all.
    """
    if message["files"]:
        return _file_path(message["files"][-1]), None
    image, image_turn = None, None
    for i, hist in enumerate(history):
        if isinstance(hist[0], tuple):
            image, image_turn = hist[0][0], i
    return image, image_turn


def _build_messages(message, history, image_turn):
    """Convert tuple-style Gradio history plus the new message into chat messages.

    Exactly one user turn gets the ``<image>`` placeholder. It is the first
    turn after the file entry at ``image_turn``, or the current message when
    the image was attached to it (``image_turn is None``). File tuples are
    never sent as text content.
    """
    messages = []
    attach_image = False
    for i, (human, assistant) in enumerate(history):
        if isinstance(human, tuple):
            if i == image_turn:
                attach_image = True
            if assistant is None:
                # The text sent together with the file follows in the next entry.
                continue
            # A file sent without text carries the reply on the file entry itself.
            human = ""
        if attach_image:
            human = f"<image>\n{human}"
            attach_image = False
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant or ""})

    text = message["text"]
    if image_turn is None or attach_image:
        text = f"<image>\n{text}"
    messages.append({"role": "user", "content": text})
    return messages


@spaces.GPU
def bot_streaming(message, history):
    """Stream nanoLLaVA's reply to ``message`` given the chat ``history``.

    Args:
        message: Multimodal message dict with ``"text"`` and ``"files"`` keys.
        history: Tuple-format chat history of ``[user, assistant]`` pairs. An
            uploaded file appears as a tuple (its first element is the path)
            on the user side.

    Yields:
        The reply text accumulated so far, after each streamed chunk.

    Raises:
        gr.Error: if no image was uploaded in this message or earlier in the chat.
    """
    image_path, image_turn = _find_image(message, history)
    if image_path is None:
        # The prompt below needs exactly one image. Show a readable error
        # instead of crashing in Image.open(None).
        raise gr.Error("You need to upload an image for LLaVA to work.")
    messages = _build_messages(message, history, image_turn)

    # Use a context manager so the source file handle is closed; convert()
    # returns an independent, fully loaded image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True)
    # Tokenize the text on either side of the single placeholder and splice -200
    # in between. NOTE(review): -200 is assumed to be the image-token index that
    # modeling_llava_qwen2 replaces with image features — confirm there.
    before, _, after = text.partition('<image>')
    input_ids = torch.tensor(
        tokenizer(before).input_ids + [-200] + tokenizer(after).input_ids,
        dtype=torch.long).unsqueeze(0)

    stopping_criteria = KeywordsStoppingCriteria(['<|im_end|>'], tokenizer, input_ids)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
    generation_kwargs = dict(
        input_ids=input_ids.to('cuda'),
        images=image_tensor.to('cuda'),
        streamer=streamer,
        max_new_tokens=512,
        stopping_criteria=[stopping_criteria],
        temperature=0.01)
    # generate() blocks, so run it on a worker thread and consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Small pause between UI updates (kept from the original).
        time.sleep(0.04)
        yield buffer
    # The streamer finishes when generate() does; join so the worker does not
    # outlive this call.
    thread.join()


# Example prompts shown below the chat box, each paired with a bundled image.
_examples = [
    {"text": "Who is this guy?", "files": ["./demo_1.jpg"]},
    {"text": "What does the text say?", "files": ["./demo_2.jpeg"]},
]

# Markdown blurb shown above the chat box.
_description = (
    "Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA-1.5) in this demo. "
    "Built on top of [Quyen-SE-v0.1](https://huggingface.co/vilm/Quyen-SE-v0.1) (Qwen1.5-0.5B) "
    "and [Google SigLIP-400M](https://huggingface.co/google/siglip-so400m-patch14-384). "
    "Upload an image and start chatting about it, or simply try one of the examples below. "
    "If you don't upload an image, you will receive an error."
)

# Multimodal chat UI that streams replies from bot_streaming.
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="🚀nanoLLaVA-1.5",
    description=_description,
    examples=_examples,
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.queue().launch()