import torch
from PIL import Image
import io
import os
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizerFast, CLIPImageProcessor
import numpy as np
from diffusers import DiffusionPipeline
import warnings
import gradio as gr

warnings.filterwarnings("ignore")
# Global evaluator instance (lazy loaded)
evaluator = None

# Check if running on Hugging Face Spaces with ZeroGPU
try:
    import spaces
    ZERO_GPU_AVAILABLE = True
except ImportError:
    ZERO_GPU_AVAILABLE = False

class TextToImageEvaluator:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        clip_model_name = "openai/clip-vit-large-patch14-336"
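        # Assembling the processor from its parts should be equivalent to
        # CLIPProcessor.from_pretrained(clip_model_name), but it makes the use
        # of the fast tokenizer explicit.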
        tokenizer = CLIPTokenizerFast.from_pretrained(clip_model_name)
        image_processor = CLIPImageProcessor.from_pretrained(clip_model_name)
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)
        self.clip_processor = CLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)
| print("Loading image generation model...") | |
| self.generator = DiffusionPipeline.from_pretrained( | |
| "Lykon/dreamshaper-xl-v2-turbo", | |
| torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 | |
| ) | |
| self.clip_model.to(self.device) | |
| self.generator.to(self.device) | |
| if self.device == "cuda": | |
| self.generator.enable_attention_slicing() | |
| self.generator.enable_vae_slicing() | |
| # Try to enable xformers if available | |
| try: | |
| self.generator.enable_xformers_memory_efficient_attention() | |
| print("xformers enabled for memory efficient attention") | |
| except Exception: | |
| pass | |
| else: | |
| # CPU optimizations | |
| self.generator.enable_attention_slicing(1) | |
| print(f"Models loaded successfully on {self.device}") | |
    def generate_image(self, text, num_inference_steps=6, guidance_scale=2):
        """Generate an image from the prompt with the DreamShaper XL Turbo pipeline"""
        self.generator.to(self.device)
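        # Fixed seed so the same prompt reproduces the same image across runs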
        generator = torch.Generator(device=self.generator.device).manual_seed(42)

        with torch.inference_mode():
            if self.device == "cuda":
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    image = self.generator(
                        text,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator
                    ).images[0]
            else:
                image = self.generator(
                    text,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=generator
                ).images[0]

        if self.device == "cuda":
            torch.cuda.empty_cache()
        return image
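
    # The value returned by calculate_clip_score is the raw cosine similarity
    # between CLIP image and text embeddings (well-matched pairs typically land
    # around 0.2-0.4). For reference, the CLIPScore metric of Hessel et al.
    # (2021) rescales this as w * max(cos, 0) with w = 2.5.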
    def calculate_clip_score(self, image, text):
        """Calculate CLIPScore between image and text"""
        self.clip_model.to(self.device)
        inputs = self.clip_processor(
            text=[text],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = self.clip_model(**inputs)
            image_embeds = outputs.image_embeds
            text_embeds = outputs.text_embeds

            # Normalize embeddings
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

            # Calculate cosine similarity
            similarity = (image_embeds * text_embeds).sum(dim=-1)
            score = similarity.cpu().item()

        if self.device == "cuda":
            torch.cuda.empty_cache()
        return score

    def process_prompt(self, text):
        """Process a single text prompt and return image with scores"""
        if not text or text.strip() == "":
            raise gr.Error("Please enter a prompt")
        text = text.strip()
        print(f"Processing prompt: {text}")

        # Generate image
        print("Generating image...")
        generated_image = self.generate_image(text)

        # Calculate CLIP score
        print("Calculating similarity scores...")
        clip_score = self.calculate_clip_score(generated_image, text)
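        # The 2.5 rescaling appears to match the w constant from the CLIPScore
        # metric (see the note above calculate_clip_score)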
        geneval_score = clip_score * 2.5

        return generated_image, round(clip_score, 4), round(geneval_score, 4)

def get_evaluator():
    """Lazy load the evaluator"""
    global evaluator
    if evaluator is None:
        evaluator = TextToImageEvaluator()
    return evaluator

def generate_and_evaluate(prompt):
    """Main function for Gradio interface"""
    eval_instance = get_evaluator()
    image, clip_score, geneval_score = eval_instance.process_prompt(prompt)
    return image, f"{clip_score}", f"{geneval_score}"

# Use the ZeroGPU decorator if available on HF Spaces; without @spaces.GPU the
# function would never be granted a GPU on a ZeroGPU Space
if ZERO_GPU_AVAILABLE:
    # Store reference to original function before reassignment
    _generate_and_evaluate_impl = generate_and_evaluate

    @spaces.GPU
    def generate_and_evaluate_gpu(prompt):
        return _generate_and_evaluate_impl(prompt)

    generate_and_evaluate = generate_and_evaluate_gpu
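    # If a single request can exceed the default ZeroGPU time slot, the
    # decorator also accepts a duration hint, e.g. @spaces.GPU(duration=120);
    # check the spaces package docs before relying on that exact signature.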

# Create Gradio interface
with gr.Blocks(title="Text-to-Image Generator & Evaluator") as demo:
    gr.Markdown(
        """
        # 🎨 Text-to-Image Generator & Evaluator
        Generate images from text prompts and evaluate them using CLIP scores.
        - **CLIP Score**: Cosine similarity between the CLIP embeddings of the image and the prompt; higher is better, and well-matched pairs typically score around 0.2-0.4
        - **GenEval Score**: Scaled evaluation score (CLIP Score × 2.5)
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="A beautiful sunset over mountains...",
                lines=3
            )
            generate_btn = gr.Button("🚀 Generate Image", variant="primary")

            gr.Markdown("### Evaluation Scores")
            with gr.Row():
                clip_score_output = gr.Textbox(label="CLIP Score", interactive=False)
                geneval_score_output = gr.Textbox(label="GenEval Score", interactive=False)

        with gr.Column(scale=1):
            image_output = gr.Image(label="Generated Image", type="pil")

    # Example prompts
    gr.Examples(
        examples=[
            ["A futuristic city with flying cars at night"],
            ["A cute cat wearing a wizard hat"],
            ["An astronaut riding a horse on Mars"],
            ["A cozy coffee shop interior with warm lighting"]
        ],
        inputs=prompt_input
    )

    # Connect the button to the function
    generate_btn.click(
        fn=generate_and_evaluate,
        inputs=prompt_input,
        outputs=[image_output, clip_score_output, geneval_score_output]
    )

    # Also allow Enter key to submit
    prompt_input.submit(
        fn=generate_and_evaluate,
        inputs=prompt_input,
        outputs=[image_output, clip_score_output, geneval_score_output]
    )

if __name__ == "__main__":
    print("TEXT-TO-IMAGE GENERATOR - GRADIO APP")
    print("=" * 60)
    print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    demo.queue(max_size=10).launch()
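    # queue(max_size=10) caps the number of pending requests so slow
    # generations cannot pile up unbounded work; when running locally, passing
    # share=True to launch() would also expose a temporary public URL.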