from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
import torch
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import os
import re

# Folder accelerate uses when weights have to be offloaded to disk
os.makedirs("./offload", exist_ok=True)

# Optional Hugging Face token from the environment; Phi-3-mini is public,
# so None is acceptable
HF_TOKEN = os.environ.get("HF_TOKEN")

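# TF32 matmuls: on Ampere and newer GPUs this speeds up fp32 operations at a
# small precision cost; it is a no-op on hardware without TF32 support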
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# BLIP-2: vision-language model used to describe the uploaded image
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16,
    device_map="auto",          # shard across GPU/CPU, spilling to disk if needed
    offload_folder="./offload",
)

# Phi-3-mini: instruction-tuned LLM that writes the captions, loaded in
# 4-bit via bitsandbytes so it fits alongside BLIP-2
phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    device_map="auto",
    offload_folder="./offload",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
    token=HF_TOKEN,
)
phi_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    token=HF_TOKEN,
)

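# Pipeline: BLIP-2 describes the image, Phi-3 turns the description into
# captions, and Pillow renders them onto the image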
def analyze_image(image):
    # Cast pixel values to fp16 to match the model's weights
    inputs = blip_processor(images=image, return_tensors="pt").to(blip_model.device, torch.float16)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

def generate_meme_caption(image_desc, user_prompt):
    messages = [
        {"role": "system", "content": "You are a meme expert. Create funny captions in format: TOP TEXT | BOTTOM TEXT"},
        {"role": "user", "content": f"Image context: {image_desc}\nUser input: {user_prompt}\nGenerate 3 meme captions (max 10 words each):"}
    ]
    
    inputs = phi_tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    ).to(phi_model.device)
    
    outputs = phi_model.generate(
        inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        pad_token_id=phi_tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens; decoding the whole sequence would
    # echo the prompt, whose "TOP TEXT | BOTTOM TEXT" instruction would then be
    # mis-parsed as a caption by process_meme
    return phi_tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)

def create_meme(image, top_text, bottom_text):
    img = image.copy()
    draw = ImageDraw.Draw(img)
    
    # Try fonts that are actually present on typical Linux hosts first
    # (Colab/Spaces ship DejaVu but usually not Arial), then fall back to
    # Pillow's built-in bitmap font
    font_size = min(img.size) // 12
    font = None
    for name in ("DejaVuSans-Bold.ttf", "arial.ttf"):
        try:
            font = ImageFont.truetype(name, size=font_size)
            break
        except OSError:
            continue
    if font is None:
        font = ImageFont.load_default()
    
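    # Pillow text anchors: "mt" = middle-top, "mb" = middle-bottom, so each
    # caption is centered horizontally and pinned to the top/bottom edge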
    # Top text
    draw.text(
        (img.width/2, 10), 
        top_text, 
        font=font, 
        fill="white", 
        anchor="mt",
        stroke_width=2, 
        stroke_fill="black"
    )
    
    # Bottom text
    draw.text(
        (img.width/2, img.height-10), 
        bottom_text, 
        font=font, 
        fill="white", 
        anchor="mb",
        stroke_width=2, 
        stroke_fill="black"
    )
    return img

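# For reference, generate_meme_caption returns free-form text; an illustrative
# (not guaranteed) run might look like:
#   1. WHEN THE MODEL FINALLY LOADS | AND THE GPU IS ALREADY FULL
#   2. ME ASKING FOR 3 CAPTIONS | PHI-3: HERE ARE 5
# process_meme below extracts the "TOP | BOTTOM" pairs from lines like these.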
def process_meme(image, user_prompt):
    image_desc = analyze_image(image)
    raw_output = generate_meme_caption(image_desc, user_prompt)
    
    # Parse lines shaped like "TOP TEXT | BOTTOM TEXT"; the model often
    # numbers its suggestions ("1. ...", "2) ..."), so strip that prefix
    captions = []
    for line in raw_output.split("\n"):
        if "|" not in line:
            continue
        top, bottom = line.split("|", 1)
        top = re.sub(r"^\s*\d+[.)]?\s*", "", top).strip()
        bottom = bottom.strip()
        if top and bottom:
            captions.append((top, bottom))
    
    return [create_meme(image, top, bottom) for top, bottom in captions[:3]]

with gr.Blocks(title="AI Meme Generator") as demo:
    gr.Markdown("# 🚀 AI Meme Generator")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(label="Meme Theme/Prompt")
    submit_btn = gr.Button("Generate Memes!")
    gallery = gr.Gallery(label="Generated Memes", columns=3)
    
    submit_btn.click(
        fn=process_meme,
        inputs=[image_input, text_input],
        outputs=gallery
    )

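# Tip: on shared hosts such as Hugging Face Spaces, calling demo.queue()
# before launch() queues concurrent requests instead of running them all at
# once against a single GPU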
if __name__ == "__main__":
    demo.launch()