import string

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pretrained "gpt2" checkpoint's tokenizer and language-model head
# once, at import time, so every prediction request reuses the same objects.
_CHECKPOINT = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(_CHECKPOINT)
model = GPT2LMHeadModel.from_pretrained(_CHECKPOINT)

# Translation table that makes a decoded token safe and readable inside
# gr.Markdown:
# - every ASCII punctuation character is backslash-escaped, because CommonMark
#   renders any escaped ASCII punctuation literally, so tokens such as "*",
#   "_", "<" or "\\" are not interpreted as markup;
# - invisible whitespace is swapped for visible glyphs, so a space, newline or
#   tab token is still readable on its own line.
_TOKEN_DISPLAY_TABLE = str.maketrans(
    {
        **{ch: "\\" + ch for ch in string.punctuation},
        " ": "␣",
        "\n": "↵",
        "\t": "⇥",
    }
)


def _format_prediction(rank, token, prob):
    """Return one Markdown display line such as ``1. "␣sunny" (12.3%)``.

    Args:
        rank: 1-based position of the prediction in the ranking.
        token: Decoded token text.
        prob: Probability in [0, 1] as a Python float.
    """
    display_token = token.translate(_TOKEN_DISPLAY_TABLE)
    return f'{rank}. "{display_token}" ({prob * 100:.1f}%)'


def get_next_token_probs(text, top_k=5):
    """Return the ``top_k`` most likely next tokens after ``text``, formatted.

    Used as a Gradio callback: each returned string feeds one output
    component, so ``top_k`` must match the number of wired outputs.

    Args:
        text: Prompt text. ``None`` or whitespace-only input is treated as
            "no input".
        top_k: Number of predictions to return (default 5).

    Returns:
        A list of ``top_k`` strings, ranked most to least likely, each like
        ``1. "␣sunny" (12.3%)``. For empty input, ``top_k`` copies of
        ``"No input text"``.
    """
    if not text or not text.strip():
        return ["No input text"] * top_k

    input_ids = tokenizer.encode(text, return_tensors="pt")
    # GPT-2 only has position embeddings for config.n_positions tokens, and
    # longer sequences fail inside the model. For next-token prediction the
    # most recent tokens are what matter, so keep only the tail of the prompt.
    input_ids = input_ids[:, -model.config.n_positions:]

    with torch.no_grad():
        logits = model(input_ids).logits

    # Distribution over the vocabulary for the position after the last token.
    next_token_probs = torch.softmax(logits[0, -1, :], dim=-1)
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    return [
        _format_prediction(rank, tokenizer.decode([token_id]), prob)
        for rank, (token_id, prob) in enumerate(
            zip(topk_indices.tolist(), topk_probs.tolist()), start=1
        )
    ]

# Single-page UI: a prompt textbox followed by five Markdown lines showing the
# ranked next-token predictions for whatever is currently typed.
with gr.Blocks(css="footer {display: none}") as demo:
    gr.Markdown("### GPT-2 Next Token Predictor")

    input_text = gr.Textbox(
        value="The weather tomorrow will be",
        placeholder="Type here and watch predictions update...",
        label="Text Input",
    )

    gr.Markdown("##### Most likely next tokens:")

    # One Markdown component per ranked prediction, stacked top to bottom.
    token_outputs = [gr.Markdown() for _ in range(5)]
    token1, token2, token3, token4, token5 = token_outputs

    # The same callback wiring serves two triggers: every edit of the textbox,
    # and the initial page load (so the default prompt shows predictions).
    prediction_wiring = dict(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )
    input_text.change(**prediction_wiring)
    demo.load(**prediction_wiring)

# Start the web server; this runs whenever the module is executed.
demo.launch()