import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Function to load the model and tokenizer (only needs to run once)
def load_model():
    model_id = "microsoft/bitnet-b1.58-2B-4T"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto"  # places the model on a GPU if one is available (requires the accelerate package)
    )
    return model, tokenizer
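
# A minimal CPU-only variant (a sketch, assuming no CUDA device is available).
# load_model_cpu is illustrative and not part of the original app; float32 is
# the conservative dtype choice, since bfloat16 works on many modern CPUs but
# float32 avoids silent precision issues on older hardware.
def load_model_cpu():
    model_id = "microsoft/bitnet-b1.58-2B-4T"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # safe default on CPU
    )
    return model, tokenizer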

# Load the model and tokenizer
print("Loading model, please wait...")
model, tokenizer = load_model()
print("Model loaded successfully!")

def generate_response(message, chat_history, max_new_tokens=512):
    """
    Generates a response from the BitNet model, conditioning on the full
    chat history. max_new_tokens defaults to 512 so the prompt plus the
    reply stays comfortably within the model's 4096-token context window.
    """
    if not message.strip():
        return "", chat_history
    
    # Build a plain-text transcript from the history plus the new message
    # (see build_prompt_with_template below for a chat-template alternative)
    full_prompt = ""
    for user_msg, bot_msg in chat_history:
        full_prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n\n"
    
    full_prompt += f"User: {message}\nAssistant:"

    # Create inputs for the model
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,  # moderate temperature: varied but still coherent replies
            top_p=0.95,
        )
    
    # Decode only the newly generated tokens (everything after the prompt)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

    # With a free-form transcript prompt the model may run on and invent the
    # next "User:" turn; truncate there if it does
    response = response.split("User:")[0].strip()

    # Append the exchange so the Chatbot component re-renders with it
    chat_history.append((message, response))

    return "", chat_history

# Define the Gradio interface
def create_chat_interface():
    with gr.Blocks(title="BitNet Chat Assistant") as demo:
        gr.Markdown("# 💬 BitNet Chat Assistant")
        gr.Markdown("A lightweight chat application powered by Microsoft's BitNet b1.58 2B4T model.")
        
        chatbot = gr.Chatbot(height=400)  # tuple-format history: a list of (user, bot) pairs
        msg = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            container=False
        )
        
        clear = gr.Button("Clear Conversation")
        
        def clear_convo():
            return "", []
        
        msg.submit(
            fn=generate_response,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        
        clear.click(fn=clear_convo, inputs=[], outputs=[msg, chatbot])
        
        # Add some example inputs
        examples = [
            ["Hello, how are you today?"],
            ["Can you tell me about artificial intelligence?"],
            ["What's your favorite book?"],
            ["Write a short poem about technology."],
        ]
        gr.Examples(examples=examples, inputs=[msg])
        
        gr.Markdown("""
        ## About
        This application uses Microsoft's BitNet b1.58 2B4T, a 1-bit Large Language Model, for conversational AI.
        The model runs efficiently on consumer hardware due to its 1-bit architecture, offering significant
        advantages in memory usage, energy consumption, and latency.
        
        Note: This is a demonstration of the lightweight model's capabilities.
        """)
        
    return demo
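
# Optional extension (illustrative sketch, not wired in): expose the generation
# budget as a UI control by adding a gr.Slider inside the Blocks context and
# passing it as a third input to msg.submit; generate_response already accepts
# max_new_tokens as its third parameter. The names below are hypothetical.
#
#     max_tokens = gr.Slider(64, 1024, value=512, step=64,
#                            label="Max new tokens")
#     msg.submit(fn=generate_response,
#                inputs=[msg, chatbot, max_tokens],
#                outputs=[msg, chatbot])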

# Create and launch the Gradio interface
if __name__ == "__main__":
    demo = create_chat_interface()
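    # Enable Gradio's request queue so concurrent users are served in turn;
    # a judgment call here, since the single in-process model instance is not
    # safe to run generation on from multiple threads at once.
    demo.queue()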
    demo.launch(share=True)  # share=True creates a temporary public link; set to False for local-only use