import gradio as gr
import transformers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings

# Disable warnings and progress bars.
# These calls run at import time and affect the whole process: transformers
# logging is limited to errors, its progress bars are turned off, and every
# Python warning (from any library, not just transformers) is ignored.
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

# Initialize model and tokenizer
def load_model(device='cpu'):
    """Load the nanoLLaVA model together with its tokenizer.

    Args:
        device: Accepted for interface compatibility only. The body never
            reads it; weight placement is left to ``device_map='auto'``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = 'qnguyen3/nanoLLaVA'

    # trust_remote_code=True lets transformers run code shipped in the
    # checkpoint repository, so only use it with a repository you trust.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        trust_remote_code=True,
        device_map='auto',
        torch_dtype=torch.float16,
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    return model, tokenizer

def generate_caption(image, model, tokenizer):
    """Generate a detailed textual description of ``image``.

    Args:
        image: The image to describe; passed as-is to ``model.process_images``.
        model: Vision-language model exposing ``process_images``, ``generate``,
            ``config``, ``dtype`` and ``device``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``
            (returning an object with ``input_ids``) and ``decode``.

    Returns:
        The decoded, whitespace-stripped text of the newly generated tokens
        (the prompt tokens are not included).
    """
    # Prepare the prompt
    prompt = "Describe this image in detail"
    messages = [
        {"role": "system", "content": "Answer the question"},
        {"role": "user", "content": f'<image>\n{prompt}'}
    ]

    # Apply chat template (as text, so the <image> tag can be located below)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # The model may have been placed on an accelerator (e.g. when loaded with
    # device_map='auto'); every input tensor must live on that same device or
    # generation fails with a device-mismatch error.
    device = model.device

    # Tokenize the text on either side of the <image> tag and put the
    # placeholder id -200 where the tag was.
    # NOTE(review): -200 is presumably the image-token index the model's remote
    # code expects (LLaVA convention) -- confirm against the checkpoint.
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(
        text_chunks[0] + [-200] + text_chunks[1],
        dtype=torch.long
    ).unsqueeze(0).to(device)
    image_tensor = model.process_images([image], model.config).to(
        device=device,
        dtype=model.dtype
    )

    # Generate caption; [0] selects the single sequence of the batch
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        max_new_tokens=2048,
        use_cache=True
    )[0]

    # Decode only the tokens that follow the prompt
    caption = tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
    return caption

def create_persona(caption):
    """Wrap an image caption in a system-prompt block describing a persona.

    Args:
        caption: Description of the image; interpolated verbatim into the
            first sentence of the prompt.

    Returns:
        A string that starts with ``<|im_start|>system`` and ends with
        ``<|im_end|>``, containing the role-play instructions.
    """
    prompt_lines = [
        "<|im_start|>system",
        f"You are a character based on this description: {caption}",
        "",
        "Role: An entity exactly as described in the image",
        "Background: Your appearance and characteristics match the image description",
        "Personality: Reflect the mood, style, and elements captured in the image",
        "Goal: Interact authentically based on your visual characteristics",
        "",
        "Please stay in character and respond as this entity would, incorporating visual elements from your description into your responses.<|im_end|>",
    ]
    return "\n".join(prompt_lines)

def process_image_to_persona(image, model, tokenizer):
    """Turn an uploaded image into a caption and a persona prompt.

    Args:
        image: A PIL image, an array accepted by ``Image.fromarray``, or
            ``None`` when nothing was uploaded.
        model: Model forwarded to ``generate_caption``.
        tokenizer: Tokenizer forwarded to ``generate_caption``.

    Returns:
        ``(caption, persona)``. When ``image`` is ``None`` the first element
        is a request to upload an image and the second is an empty string.
    """
    # Nothing uploaded: answer with a hint instead of running the model.
    if image is None:
        return "Please upload an image.", ""

    # Anything that is not already a PIL image is converted from an array.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    description = generate_caption(pil_image, model, tokenizer)
    return description, create_persona(description)

# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the persona generator.

    The model and tokenizer are loaded once here and captured by the button
    callback, so every click reuses the same instances.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    model, tokenizer = load_model()

    with gr.Blocks() as demo:
        # Page header
        gr.Markdown("# Image to Chatbot Persona Generator")
        gr.Markdown("Upload an image of a character to generate a persona for a chatbot based on the image.")

        # Input: the picture, delivered to the callback as a PIL image
        with gr.Row():
            uploaded_image = gr.Image(type="pil", label="Upload Character Image")

        with gr.Row():
            run_button = gr.Button("Generate Persona")

        # Outputs: caption and persona prompt side by side
        with gr.Row():
            caption_box = gr.Textbox(label="Generated Caption", lines=3)
            persona_box = gr.Textbox(label="Chatbot Persona", lines=10)

        # The lambda binds the already-loaded model and tokenizer.
        run_button.click(
            fn=lambda img: process_image_to_persona(img, model, tokenizer),
            inputs=[uploaded_image],
            outputs=[caption_box, persona_box],
        )

    return demo

# Launch the app.
# Runs only when the file is executed as a script: builds the UI (which also
# loads the model) and starts the Gradio server. share=True requests a
# publicly reachable share link in addition to the local address.
if __name__ == "__main__":
    app = create_interface()
    app.launch(share=True)