import string

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# A single checkpoint id is shared by the model and its tokenizer so the two
# can never silently drift apart.
MODEL_ID = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Translation table for rendering a decoded token inside a gr.Markdown component:
# - whitespace becomes a visible symbol so space/newline/tab tokens can be seen;
# - every ASCII punctuation character is backslash-escaped (CommonMark permits
#   escaping any ASCII punctuation), so tokens like "*", "_", "#", "<" or "`"
#   are displayed literally instead of being interpreted as Markdown/HTML.
_TOKEN_DISPLAY_TABLE = str.maketrans(
    {
        " ": "␣",
        "\n": "↵",
        "\t": "⇥",
        "\r": "␍",
        **{char: "\\" + char for char in string.punctuation},
    }
)


def get_next_token_probs(text, top_k=20):
    """Return the ``top_k`` most likely next tokens after ``text`` as Markdown lines.

    The prompt is tokenized with the module-level ``tokenizer``, run once through
    the module-level ``model``, and the softmax over the last position's logits
    is ranked.

    Args:
        text: Prompt typed by the user.
        top_k: Number of predictions to return (must be >= 1). Defaults to 20,
            the number of output components wired up in the UI.

    Returns:
        A list of exactly ``top_k`` strings, one per output component, each
        shaped like ``'1. "token" (12.3%)'`` with the token escaped for Markdown.
        For blank input the first entry is a "No input text" notice and the
        remaining entries are empty strings, so the notice is shown only once.
    """
    # Gradio expects one value per output component even when there is nothing
    # to predict; show the notice once and clear the remaining rows.
    if not text.strip():
        return ["No input text"] + [""] * (top_k - 1)

    input_ids = tokenizer.encode(text, return_tensors="pt")

    with torch.no_grad():
        logits = model(input_ids).logits

    # The logits at the last position score the token that follows the prompt.
    next_token_probs = torch.softmax(logits[0, -1, :], dim=-1)
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    results = []
    for rank, (token_id, prob) in enumerate(
        zip(topk_indices.tolist(), topk_probs.tolist()), start=1
    ):
        display_token = tokenizer.decode([token_id]).translate(_TOKEN_DISPLAY_TABLE)
        results.append(f"{rank}. \"{display_token}\" ({prob * 100:.1f}%)")
    return results

# Build the UI: a prompt box followed by one Markdown row per predicted token.
with gr.Blocks(css="footer {display: none}") as demo:
    gr.Markdown("### SmolLM2 Next Token Predictor")

    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be",
    )

    gr.Markdown("##### Most likely next tokens:")

    # One component per prediction; this count must match the number of
    # strings get_next_token_probs returns (20).
    token_outputs = [gr.Markdown() for _ in range(20)]

    # Recompute the predictions every time the prompt changes.
    input_text.change(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )

    # Also run once when the page loads so the default prompt is populated
    # without the user having to type anything.
    demo.load(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or other tooling) does not block inside launch().
if __name__ == "__main__":
    demo.launch()