import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# ============================================================
# 1️⃣ Load model and tokenizer
# ============================================================
MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"

# Instruction sent as the system turn of every request.
SYSTEM_PROMPT = "Enhance and expand the following prompt with more details and context:"

# Use GPU if available
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,  # 0 for GPU, -1 for CPU
)


# ============================================================
# 2️⃣ Define the generation function (chat-template style)
# ============================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Expand *user_prompt* with the model and append the exchange to the chat.

    Args:
        user_prompt: Text entered by the user.
        temperature: Sampling temperature from the UI slider. Values <= 0
            select greedy decoding, because sampling requires a strictly
            positive temperature.
        max_tokens: Maximum number of new tokens to generate.
        chat_history: Current chatbot value in "messages" format (a list of
            ``{"role": ..., "content": ...}`` dicts), or ``None``.

    Returns:
        A new list holding the previous history plus the user turn and the
        model's reply. A blank prompt returns the history unchanged.
    """
    # Copy so the caller's list is never mutated in place.
    chat_history = list(chat_history or [])

    # Nothing to enhance: skip the (comparatively expensive) model call.
    if not user_prompt or not user_prompt.strip():
        return chat_history

    # Build messages using proper roles
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    # Use tokenizer chat template to build the input
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    temperature = float(temperature)
    gen_kwargs = {"max_new_tokens": int(max_tokens)}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Sampling with temperature 0 is rejected by transformers; the
        # slider allows 0, so treat it as "deterministic" (greedy) decoding.
        gen_kwargs["do_sample"] = False

    # Generate output
    result = pipe(
        prompt,
        # The chat template already emitted the BOS token; re-tokenizing
        # with special tokens would prepend a second one.
        add_special_tokens=False,
        # Only return the newly generated text, not the echoed prompt.
        return_full_text=False,
        **gen_kwargs,
    )
    output = result[0]["generated_text"].strip()

    # Append conversation to history
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": output})
    return chat_history


# ============================================================
# 3️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# ✨ Prompt Enhancer (Gemma 3 270M)
Enter a short prompt, and the model will **expand it with details and creative context** using the Gemma chat-template interface.
"""
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    # Bind UI actions
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        """
---
💡 **Tips:**
- Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
- Increase *Temperature* for more creative output.
"""
    )

# ============================================================
# 4️⃣ Launch
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)