# imageGenerator / app.py
# gg3554's picture
# Update app.py
# 9d9eb70 verified
import torch
from PIL import Image
import io
import os
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizerFast, CLIPImageProcessor
import numpy as np
from diffusers import DiffusionPipeline
import warnings
import gradio as gr
# Silence every Python warning raised by this process (blanket "ignore" filter).
warnings.filterwarnings("ignore")

# Global evaluator instance (lazy loaded): stays None until the first request
# builds it, so importing this module does not load any model.
evaluator = None

# Check if running on Hugging Face Spaces with ZeroGPU: the `spaces` package is
# only importable there, so a failed import simply means "no ZeroGPU".
try:
    import spaces
    ZERO_GPU_AVAILABLE = True
except ImportError:
    ZERO_GPU_AVAILABLE = False
class TextToImageEvaluator:
    """Generate an image from a text prompt and score it against that prompt.

    Holds two models: a CLIP model ("openai/clip-vit-large-patch14-336") used
    for scoring and a diffusion pipeline ("Lykon/dreamshaper-xl-v2-turbo")
    used for generation. Both are placed on CUDA when it is available and on
    the CPU otherwise.
    """

    def __init__(self):
        """Load the CLIP model and the diffusion pipeline onto the device."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only used on the GPU; the CPU path stays in float32.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        clip_model_name = "openai/clip-vit-large-patch14-336"
        tokenizer = CLIPTokenizerFast.from_pretrained(clip_model_name)
        image_processor = CLIPImageProcessor.from_pretrained(clip_model_name)
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)
        self.clip_processor = CLIPProcessor(
            tokenizer=tokenizer,
            image_processor=image_processor,
        )

        print("Loading image generation model...")
        self.generator = DiffusionPipeline.from_pretrained(
            "Lykon/dreamshaper-xl-v2-turbo",
            torch_dtype=self.dtype,
        )

        self.clip_model.to(self.device)
        self.generator.to(self.device)

        if self.device == "cuda":
            self.generator.enable_attention_slicing()
            self.generator.enable_vae_slicing()
            # xformers is optional: keep going without it, but say so instead
            # of hiding the reason it could not be enabled.
            try:
                self.generator.enable_xformers_memory_efficient_attention()
                print("xformers enabled for memory efficient attention")
            except Exception as exc:
                print(f"xformers not enabled, using default attention: {exc}")
        else:
            # CPU optimizations
            self.generator.enable_attention_slicing(1)

        print(f"Models loaded successfully on {self.device}")

    def generate_image(self, text, num_inference_steps=6, guidance_scale=2, seed=42):
        """Generate one image for ``text`` with the diffusion pipeline.

        Args:
            text: Prompt passed to the pipeline.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance scale forwarded to the pipeline.
            seed: Seed for the torch random generator handed to the pipeline.
                The default (42) keeps the previous fixed-seed behaviour;
                pass ``None` to run without a seeded generator.

        Returns:
            The first image of the pipeline output (``.images[0]``).
        """
        self.generator.to(self.device)
        on_cuda = self.device == "cuda"

        rng = None
        if seed is not None:
            rng = torch.Generator(device=self.generator.device).manual_seed(seed)

        def run_pipeline():
            # Single definition of the pipeline call shared by both devices.
            return self.generator(
                text,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=rng,
            ).images[0]

        try:
            with torch.inference_mode():
                if on_cuda:
                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        return run_pipeline()
                return run_pipeline()
        finally:
            # Release cached GPU memory even when generation raised.
            if on_cuda:
                torch.cuda.empty_cache()

    def calculate_clip_score(self, image, text):
        """Return the cosine similarity between CLIP embeddings of image and text.

        Args:
            image: Image handed to the CLIP processor.
            text: Prompt handed to the CLIP processor (truncated to 77 tokens).

        Returns:
            The similarity as a Python float. It is a cosine similarity and is
            not clamped, so it lies in [-1, 1].
        """
        self.clip_model.to(self.device)
        inputs = self.clip_processor(
            text=[text],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        )
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        try:
            with torch.inference_mode():
                outputs = self.clip_model(**inputs)
                # Normalize embeddings so the dot product is a cosine similarity.
                image_embeds = outputs.image_embeds
                image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
                text_embeds = outputs.text_embeds
                text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
                similarity = (image_embeds * text_embeds).sum(dim=-1)
            return similarity.cpu().item()
        finally:
            # Release cached GPU memory even when scoring raised.
            if self.device == "cuda":
                torch.cuda.empty_cache()

    def process_prompt(self, text):
        """Generate an image for ``text`` and return it with its two scores.

        Args:
            text: Raw prompt; surrounding whitespace is stripped.

        Returns:
            ``(image, clip_score, geneval_score)`` where both scores are
            rounded to 4 decimals and ``geneval_score`` is the CLIP score
            multiplied by 2.5.

        Raises:
            gr.Error: If the prompt is empty or only whitespace.
        """
        if not text or not text.strip():
            raise gr.Error("Please enter a prompt")

        text = text.strip()
        print(f"Processing prompt: {text}")

        # Generate image
        print("Generating image...")
        generated_image = self.generate_image(text)

        # Calculate CLIP score
        print("Calculating similarity scores...")
        clip_score = self.calculate_clip_score(generated_image, text)
        geneval_score = clip_score * 2.5

        return generated_image, round(clip_score, 4), round(geneval_score, 4)
def get_evaluator():
    """Return the shared evaluator, building it on first use.

    The instance is stored in the module-level ``evaluator`` global so the
    models are loaded once and reused by every later call.
    """
    global evaluator
    if evaluator is None:
        evaluator = TextToImageEvaluator()
    return evaluator
def generate_and_evaluate(prompt):
    """Main function for Gradio interface.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        ``(image, clip_score, geneval_score)`` with both scores converted to
        strings so they can be shown in text boxes.
    """
    image, clip_score, geneval_score = get_evaluator().process_prompt(prompt)
    return image, str(clip_score), str(geneval_score)
# Use ZeroGPU decorator if available on HF Spaces
if ZERO_GPU_AVAILABLE:
    # Store reference to original function before reassignment, so the wrapper
    # below calls the undecorated implementation rather than itself.
    _generate_and_evaluate_impl = generate_and_evaluate

    @spaces.GPU
    def generate_and_evaluate_gpu(prompt):
        """Run ``generate_and_evaluate`` under the ``spaces.GPU`` decorator."""
        return _generate_and_evaluate_impl(prompt)

    # From here on the public name refers to the GPU-decorated wrapper.
    generate_and_evaluate = generate_and_evaluate_gpu
# Create Gradio interface
with gr.Blocks(title="Text-to-Image Generator & Evaluator") as demo:
    # The Markdown text is kept flush-left so the string content carries no
    # code indentation. The CLIP score is an unclamped cosine similarity, so
    # its range is stated as -1 to 1.
    gr.Markdown(
        """
# 🎨 Text-to-Image Generator & Evaluator
Generate images from text prompts and evaluate them using CLIP scores.
- **CLIP Score**: Measures how well the generated image matches the text prompt (cosine similarity, -1 to 1; higher is better)
- **GenEval Score**: Scaled evaluation score (CLIP Score × 2.5)
"""
    )

    with gr.Row():
        # Left column: prompt entry, trigger button and the two score boxes.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="A beautiful sunset over mountains...",
                lines=3,
            )
            generate_btn = gr.Button("🚀 Generate Image", variant="primary")
            gr.Markdown("### Evaluation Scores")
            with gr.Row():
                clip_score_output = gr.Textbox(label="CLIP Score", interactive=False)
                geneval_score_output = gr.Textbox(label="GenEval Score", interactive=False)

        # Right column: the generated image.
        with gr.Column(scale=1):
            image_output = gr.Image(label="Generated Image", type="pil")

    # Example prompts
    gr.Examples(
        examples=[
            ["A futuristic city with flying cars at night"],
            ["A cute cat wearing a wizard hat"],
            ["An astronaut riding a horse on Mars"],
            ["A cozy coffee shop interior with warm lighting"],
        ],
        inputs=prompt_input,
    )

    # Connect the button to the function, and also allow Enter key to submit:
    # both events run the same handler with the same inputs and outputs.
    for register_event in (generate_btn.click, prompt_input.submit):
        register_event(
            fn=generate_and_evaluate,
            inputs=prompt_input,
            outputs=[image_output, clip_score_output, geneval_score_output],
        )
if __name__ == "__main__":
    # Startup banner: report which device the models will run on.
    print("TEXT-TO-IMAGE GENERATOR - GRADIO APP")
    print("=" * 60)
    cuda_available = torch.cuda.is_available()
    print(f"Device: {'CUDA' if cuda_available else 'CPU'}")
    if cuda_available:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Enable the request queue with max_size=10, then start the app.
    demo.queue(max_size=10).launch()