Spaces:
Running
Running
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
import re | |
# Heading text, rendered as a Markdown H1 at the top of the page.
title = "π Pickup Line Generator"

# Centered HTML intro blurb rendered beneath the heading.
description = (
    "\n"
    '<div style="text-align: center; max-width: 650px; margin: 0 auto;">\n'
    "<div>\n"
    '<p style="color: #333333; font-size: 1.1rem; font-weight: 500;">'
    "Generate fun, clever, or cringey pickup lines using SmolLM-135M! "
    "Select a vibe and click generate to get started! π</p>\n"
    "</div>\n"
    "</div>\n"
)
# Load the model and tokenizer once at import time; every request reuses them.
MODEL_NAME = "HuggingFaceTB/SmolLM-135M"
print("Loading SmolLM-135M model...")

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the padding token so padding-aware tokenizer/generate calls have one
# (presumably the checkpoint defines no pad token of its own — TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = model.to(device)
print(
    "Model loaded successfully! Memory footprint: "
    f"{model.get_memory_footprint() / 1e6:.2f} MB"
)
# Few-shot prompt section per vibe. Each entry states the task, shows one
# Input/Output example, and ends with an open "Now generate ..." cue that the
# model continues. Built once at import time instead of on every call.
_VIBE_PATTERNS = {
    "romantic": """Generate a romantic and sweet pickup line that's genuine and heartfelt.
Example:
Input: Generate a romantic pickup line
Output: Are you a magician? Because whenever I look at you, everyone else disappears. β€οΈ
Now generate a romantic pickup line: """,
    "cheesy": """Generate a super cheesy and over-the-top pickup line.
Example:
Input: Generate a cheesy pickup line
Output: Are you a parking ticket? Because you've got FINE written all over you! π
Now generate a cheesy pickup line: """,
    "nerdy": """Generate a nerdy, science-themed pickup line.
Example:
Input: Generate a nerdy pickup line
Output: Are you made of copper and tellurium? Because you're Cu-Te! π¬
Now generate a nerdy pickup line: """,
    "cringe": """Generate the most cringey and over-the-top pickup line imaginable.
Example:
Input: Generate a cringe pickup line
Output: Are you a dictionary? Because you're adding meaning to my life! π
Now generate a cringe pickup line: """,
    "flirty": """Generate a bold and flirty pickup line.
Example:
Input: Generate a flirty pickup line
Output: Is your name Google? Because you've got everything I've been searching for! π
Now generate a flirty pickup line: """,
}


def get_vibe_guidance(vibe):
    """Return the few-shot prompt section for *vibe*.

    Args:
        vibe: One of "romantic", "cheesy", "nerdy", "cringe" or "flirty".
            Any other value, including ``None`` or ``""``, gets a generic cue.

    Returns:
        Prompt text ending in an open "Now generate ... pickup line: " cue
        for the model to complete.
    """
    guidance = _VIBE_PATTERNS.get(vibe)
    if guidance is not None:
        return guidance
    # Unknown vibe: return a complete cue mirroring the last line of the known
    # entries, rather than a dangling sentence fragment the model must finish.
    if vibe:
        return f"Now generate a {vibe} pickup line: "
    return "Now generate a pickup line: "
# Emoji appended to each generated line, keyed by vibe.
_VIBE_EMOJI = {
    "romantic": " β€οΈ",
    "cheesy": " π",
    "nerdy": " π¬",
    "cringe": " π",
    "flirty": " π",
}

# Fragments meaning the model has started echoing the prompt scaffolding or a new
# few-shot example; the line is cut at each one found. "\n" keeps only the first line.
_STOP_MARKERS = ("Instructions:", "Generate a pickup line:", "Input:", "\n")

# Sampling attempts before giving up on producing a non-empty line.
_NUM_TRIES = 3

# Returned when every attempt came back empty after cleanup.
_FALLBACK_LINE = "Couldn't come up with a line this time. Please try again!"


def _clean_response(text):
    """Reduce a raw model continuation to a single pickup line (may return "")."""
    line = text.strip()
    # The few-shot examples label answers with "Output:", so the model may copy it.
    if line.startswith("Output:"):
        line = line[len("Output:"):].strip()
    for marker in _STOP_MARKERS:
        line = line.split(marker, 1)[0].strip()
    return line


def generate_pickup_line(vibe):
    """Generate a pickup line for the selected vibe.

    Samples up to ``_NUM_TRIES`` continuations of a few-shot prompt. Returns the
    first one that is non-empty after cleanup, with the vibe's emoji appended.

    Args:
        vibe: Vibe name chosen in the dropdown (e.g. "romantic", "nerdy").

    Returns:
        The generated line, or ``_FALLBACK_LINE`` if every attempt was empty.
    """
    prompt = (
        f"Instructions: Generate a pickup line with a {vibe} vibe.\n"
        f"{get_vibe_guidance(vibe)}"
    )

    # A single sequence needs no padding; pass the attention mask explicitly anyway,
    # since pad_token == eos_token would otherwise make generate() warn.
    encoded_input = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded_input["input_ids"].to(device)
    attention_mask = encoded_input["attention_mask"].to(device)
    prompt_len = input_ids.shape[-1]
    emoji = _VIBE_EMOJI.get(vibe, "")

    for _ in range(_NUM_TRIES):
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.8,
                top_p=0.92,
                top_k=50,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens. Removing the prompt from the decoded
        # full text breaks whenever decode(encode(prompt)) differs from the prompt.
        continuation = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        line = _clean_response(continuation)
        if line:
            return line + emoji
        # Nothing usable (e.g. the model emitted only a newline/EOS): sample again.

    return _FALLBACK_LINE
# Custom CSS injected into the Gradio page.
# The leading @import loads the Lobster web font used by the h1 rule. Without it the
# browser silently falls back to a generic cursive face. (@import must precede all
# other rules in a stylesheet.)
# NOTE(review): ``#component-0`` depends on Gradio's auto-generated element ids, and the
# ``.gr-*`` class names may come from an older Gradio release. Confirm they still match
# in the installed version; otherwise those rules are silently no-ops.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
#component-0 {
    background-color: #fef6f9 !important;
    max-width: 800px !important;
    margin: 0 auto !important;
    padding: 2rem !important;
    border-radius: 15px !important;
    box-shadow: 0 4px 6px rgba(255, 105, 180, 0.1) !important;
}
h1 {
    font-family: 'Lobster', cursive !important;
    color: #ff1493 !important;
    text-align: center !important;
    font-size: 2.5rem !important;
    margin-bottom: 1rem !important;
    text-shadow: 1px 1px 2px rgba(0,0,0,0.1) !important;
}
.generate-btn {
    background: linear-gradient(45deg, #ff69b4, #ff1493) !important;
    color: white !important;
    border: none !important;
    padding: 0.75rem 1.5rem !important;
    border-radius: 8px !important;
    font-weight: bold !important;
    transition: all 0.3s ease !important;
    width: 100% !important;
    margin-top: 1rem !important;
    font-size: 1.1rem !important;
}
.generate-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 4px 8px rgba(255, 105, 180, 0.3) !important;
}
.copy-btn {
    background: white !important;
    color: #ff1493 !important;
    border: 2px solid #ff1493 !important;
    padding: 0.75rem 1.5rem !important;
    border-radius: 8px !important;
    font-weight: bold !important;
    transition: all 0.3s ease !important;
    width: 100% !important;
    font-size: 1.1rem !important;
}
.copy-btn:hover {
    background: #ff1493 !important;
    color: white !important;
}
.gr-dropdown {
    border: 2px solid #ff69b4 !important;
    border-radius: 8px !important;
    background: white !important;
}
.gr-dropdown > select {
    color: #333333 !important;
    font-size: 1.1rem !important;
}
.gr-textbox {
    border: 2px solid #ff69b4 !important;
    border-radius: 8px !important;
    background: white !important;
    color: #333333 !important;
    font-size: 1.1rem !important;
}
.gr-form {
    background: white !important;
    padding: 1rem !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
}
/* Label styling */
label {
    color: #333333 !important;
    font-weight: 600 !important;
    font-size: 1.1rem !important;
}
/* Examples styling */
.gr-samples {
    background: white !important;
    border-radius: 8px !important;
    padding: 1rem !important;
    margin-top: 1rem !important;
}
.gr-samples button {
    color: #333333 !important;
    font-weight: 500 !important;
}
/* Footer styling */
.footer-text {
    color: #666666 !important;
    font-size: 1rem !important;
    margin-top: 2rem !important;
}
"""
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        # Left column: vibe picker and the generate trigger.
        with gr.Column():
            vibe_dropdown = gr.Dropdown(
                choices=[
                    "romantic",
                    "cheesy",
                    "nerdy",
                    "cringe",
                    "flirty",
                ],
                label="Pick a vibe",
                value="romantic",
            )
            generate_btn = gr.Button("Generate Line β¨", elem_classes=["generate-btn"])
        # Right column: read-only result plus a copy-to-clipboard button.
        with gr.Column():
            output = gr.Textbox(
                label="Your pickup line",
                lines=3,
                interactive=False,
            )
            copy_btn = gr.Button("π Copy to Clipboard", elem_classes=["copy-btn"])

    generate_btn.click(
        fn=generate_pickup_line,
        inputs=[vibe_dropdown],
        outputs=output,
    )

    # Copying happens entirely in the browser. fn=None runs only the JS and skips the
    # server round-trip that an identity Python function would trigger (whose return
    # value would have no output component to go to).
    copy_btn.click(
        fn=None,
        inputs=[output],
        outputs=[],
        js="(text) => { navigator.clipboard.writeText(text); }",
    )

    # Footer with custom class
    gr.Markdown("""
<div class="footer-text" style="text-align: center;">
Built by Nath with SmolLM π₯
</div>
""")
# Launch the app
if __name__ == "__main__":
    import os

    # Hugging Face Spaces sets SPACE_ID and already serves the app publicly. There,
    # Gradio does not support share=True and only logs a warning. Request a public
    # *.gradio.live tunnel only when running elsewhere (e.g. locally). Replace the
    # expression with False if you never want a public link.
    demo.launch(share=os.environ.get("SPACE_ID") is None)