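"""AI Meme Generator (Hugging Face Space).

Pipeline: BLIP-2 captions the uploaded image, Phi-3-mini turns the caption
plus a user prompt into meme text, and Pillow renders the text over the
image behind a Gradio UI.

Assumed dependencies (not pinned in the original): torch, transformers,
accelerate, bitsandbytes, gradio, pillow.
"""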
from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
import torch
from PIL import ImageDraw, ImageFont
import gradio as gr
import os

# Disk folder for weights that accelerate offloads out of GPU/CPU memory
os.makedirs("./offload", exist_ok=True)

# Hugging Face token, read from the environment (set it as a Space secret)
HF_TOKEN = os.environ.get("HF_TOKEN")

# Allow TF32 matmuls for faster inference on Ampere+ GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# BLIP-2: image captioning (processor + model).
# device_map="auto" lets accelerate place layers across GPU/CPU and spill
# the remainder to ./offload; transformers derives the no-split modules
# from the model itself, so no extra argument is needed for that.
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder="./offload",
)
# Phi-3-mini: caption generation, quantized to 4 bits via bitsandbytes
# so it fits alongside BLIP-2 on limited VRAM
phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    offload_folder="./offload",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    token=HF_TOKEN,
)
phi_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    token=HF_TOKEN,
)
def analyze_image(image):
    """Caption the uploaded image with BLIP-2."""
    # Cast pixel values to fp16 to match the model's dtype
    inputs = blip_processor(image, return_tensors="pt").to(blip_model.device, torch.float16)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
def generate_meme_caption(image_desc, user_prompt):
    """Ask Phi-3 for meme captions in 'TOP TEXT | BOTTOM TEXT' format."""
    messages = [
        {"role": "system", "content": "You are a meme expert. Create funny captions in format: TOP TEXT | BOTTOM TEXT"},
        {"role": "user", "content": f"Image context: {image_desc}\nUser input: {user_prompt}\nGenerate 3 meme captions (max 10 words each):"},
    ]
    inputs = phi_tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(phi_model.device)
    outputs = phi_model.generate(
        inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
    )
    # Decode only the newly generated tokens, not the echoed prompt
    return phi_tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
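# Illustrative (made-up) model output that the parser in process_meme()
# expects, one caption per line:
#   WHEN THE CODE FINALLY COMPILES | AND YOU DIDN'T CHANGE ANYTHING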
def create_meme(image, top_text, bottom_text):
    """Draw top/bottom captions onto a copy of the image."""
    img = image.copy()
    draw = ImageDraw.Draw(img)
    # DejaVu Sans Bold ships with most Linux images (Colab/Spaces);
    # fall back to PIL's built-in font if it isn't found
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=min(img.size) // 12)
    except OSError:
        font = ImageFont.load_default()
    # Top text: centered horizontally, anchored to the top edge
    draw.text(
        (img.width / 2, 10),
        top_text,
        font=font,
        fill="white",
        anchor="mt",
        stroke_width=2,
        stroke_fill="black",
    )
    # Bottom text: centered horizontally, anchored to the bottom edge
    draw.text(
        (img.width / 2, img.height - 10),
        bottom_text,
        font=font,
        fill="white",
        anchor="mb",
        stroke_width=2,
        stroke_fill="black",
    )
    return img
def process_meme(image, user_prompt):
    """Full pipeline: caption the image, generate meme text, render up to 3 memes."""
    image_desc = analyze_image(image)
    raw_output = generate_meme_caption(image_desc, user_prompt)
    # Keep only lines shaped like "TOP TEXT | BOTTOM TEXT"
    captions = []
    for line in raw_output.split("\n"):
        if "|" in line:
            top, bottom = line.split("|", 1)
            captions.append((top.strip(), bottom.strip()))
    return [create_meme(image, top, bottom) for top, bottom in captions[:3]]
with gr.Blocks(title="AI Meme Generator") as demo:
    gr.Markdown("# 🚀 AI Meme Generator")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(label="Meme Theme/Prompt")
    submit_btn = gr.Button("Generate Memes!")
    gallery = gr.Gallery(label="Generated Memes", columns=3)
    submit_btn.click(
        fn=process_meme,
        inputs=[image_input, text_input],
        outputs=gallery,
    )
if __name__ == "__main__":
    demo.launch()
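# To run locally (assumed, not specified in the original):
#   pip install torch transformers accelerate bitsandbytes gradio pillow
#   python app.py
# On a Gradio Space this file runs as app.py and the UI launches automatically.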