import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from PIL import Image
import numpy as np
from typing import Optional, List, Union
import gc

# Make TorchDynamo fall back to eager execution instead of raising on
# unsupported constructs (works around "ConstantVariable" errors). This only
# suppresses Dynamo errors; it does not turn Dynamo off.
torch._dynamo.config.suppress_errors = True


class OmniGenImageGenerator:
    """Monster image generator backed by a diffusers text-to-image pipeline.

    Despite the class name, the checkpoint actually loaded is the Stable
    Diffusion model in ``model_id`` (``runwayml/stable-diffusion-v1-5``); the
    OmniGen-specific path (``_omnigen_generate``) currently just calls the same
    pipeline. Weights are loaded lazily on the first ``generate`` call.
    """

    def __init__(self, device: str = "cuda"):
        """Configure the generator without loading any weights.

        Args:
            device: Requested device. Replaced by ``"cpu"`` whenever CUDA is
                not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.model_id = "runwayml/stable-diffusion-v1-5"  # Using working Stable Diffusion model

        # Generation parameters
        self.default_width = 512
        self.default_height = 512
        self.num_inference_steps = 30
        self.guidance_scale = 7.5

        # Memory optimization
        self.enable_attention_slicing = True
        self.enable_vae_slicing = True
        self.enable_cpu_offload = self.device == "cuda"

    def load_model(self):
        """Lazily load the image generation pipeline (no-op if already loaded).

        On ``"cuda"`` the fp16 weight variant is loaded, then either sequential
        CPU offload or a direct move to CUDA is applied, followed by
        attention/VAE slicing. If anything fails, ``_load_fallback_model`` is
        tried; if that also fails, its exception propagates to the caller.
        """
        if self.pipeline is not None:
            return

        try:
            # Determine torch dtype
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

            # Load pipeline with optimizations
            self.pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch_dtype,
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )

            # Apply optimizations and device placement
            if self.device == "cuda":
                if self.enable_cpu_offload:
                    self.pipeline.enable_sequential_cpu_offload()
                else:
                    # Safely move pipeline to CUDA
                    try:
                        self.pipeline = self.pipeline.to(self.device)
                    except RuntimeError as move_err:
                        if "meta tensor" not in str(move_err):
                            raise
                        # Weights still on the "meta" device could not be moved
                        # to CUDA; continue on CPU instead.
                        # NOTE(review): the pipeline keeps its fp16 dtype here.
                        print(f"Meta tensor issue detected, using CPU fallback: {move_err}")
                        self.device = "cpu"
                        self.pipeline = self.pipeline.to("cpu")
                self._apply_memory_optimizations()
            else:
                self.pipeline = self.pipeline.to(self.device)

            # Disable torch.compile to avoid dynamo issues that cause ConstantVariable errors
            print("Skipping torch.compile to avoid dynamo compatibility issues")
        except Exception as e:
            print(f"Failed to load image generation model: {e}")
            # Drop any partially-initialised pipeline: otherwise a failing
            # fallback would leave a broken object that later calls treat as
            # already loaded.
            self.pipeline = None
            # Try fallback to stable diffusion. If it raises, that exception
            # propagates with the original failure attached as its context.
            self.model_id = "runwayml/stable-diffusion-v1-5"
            self._load_fallback_model()

    def _apply_memory_optimizations(self):
        """Enable attention/VAE slicing when configured and supported by the pipeline."""
        if self.enable_attention_slicing and hasattr(self.pipeline, "enable_attention_slicing"):
            self.pipeline.enable_attention_slicing(1)
        if self.enable_vae_slicing and hasattr(self.pipeline, "enable_vae_slicing"):
            self.pipeline.enable_vae_slicing()

    def _load_fallback_model(self):
        """Load ``self.model_id`` without requesting the fp16 weight variant.

        ``self.pipeline`` is assigned only after loading and device placement
        have both succeeded, so a failure here never leaves a half-set-up
        pipeline behind.
        """
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        # NOTE(security): trust_remote_code=True lets the model repository run
        # arbitrary Python while loading; keep it only for trusted model_ids.
        pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            trust_remote_code=True,
        )

        if self.device == "cuda" and self.enable_cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to(self.device)

        self.pipeline = pipeline
        if self.device == "cuda":
            # Same memory settings as the primary load path.
            self._apply_memory_optimizations()

    def _truncate_prompt(self, prompt: str, max_tokens: int = 75) -> str:
        """Truncate ``prompt`` to at most ``max_tokens`` whitespace-separated words.

        This approximates the CLIP token limit by counting words, not
        tokenizer tokens, so a long word may still expand to several tokens.
        Whitespace is normalised to single spaces only when truncation happens.
        """
        words = prompt.split()
        if len(words) <= max_tokens:
            return prompt
        truncated = " ".join(words[:max_tokens])
        print(f"Warning: Prompt truncated from {len(words)} to {max_tokens} words")
        return truncated

    def generate(self, prompt: str,
                 reference_images: Optional[List[Union[str, Image.Image]]] = None,
                 negative_prompt: Optional[str] = None,
                 width: Optional[int] = None,
                 height: Optional[int] = None,
                 num_images: int = 1,
                 seed: Optional[int] = None) -> Union[Image.Image, List[Image.Image]]:
        """Generate monster image(s) from a text prompt.

        Args:
            prompt: Base description; truncated, then extended by
                ``_enhance_prompt``.
            reference_images: Forwarded only to ``_omnigen_generate`` (which
                currently ignores them); unused on the standard pipeline path.
            negative_prompt: Concepts to avoid. ``None`` selects a built-in
                quality-oriented default.
            width, height: Output size. Default to ``default_width`` /
                ``default_height`` and are rounded down to a multiple of 8
                (minimum 8).
            num_images: Number of images to generate.
            seed: Optional seed for a reproducible ``torch.Generator``.

        Returns:
            A single image when ``num_images == 1``, otherwise a list. This
            method does not raise: any exception is printed and procedurally
            drawn placeholder image(s) of the requested size are returned.
        """
        # Resolve the output size before anything that can fail, so the error
        # fallback below always receives concrete integer dimensions.
        width = max(8, ((width or self.default_width) // 8) * 8)
        height = max(8, ((height or self.default_height) // 8) * 8)

        try:
            # Load model if needed
            self.load_model()

            # Truncate prompts to avoid CLIP token limit issues. The
            # enhancement suffix added below is not counted, so the final
            # prompt can exceed this word budget.
            prompt = self._truncate_prompt(prompt)
            if negative_prompt:
                negative_prompt = self._truncate_prompt(negative_prompt)

            # Enhance prompt for monster generation
            enhanced_prompt = self._enhance_prompt(prompt)

            # Default negative prompt for quality
            if negative_prompt is None:
                negative_prompt = (
                    "low quality, blurry, distorted, disfigured, "
                    "bad anatomy, wrong proportions, ugly, duplicate, "
                    "morbid, mutilated, extra limbs, malformed"
                )

            # Set seed for reproducibility
            generator = None
            if seed is not None:
                generator = torch.Generator(device=self.device).manual_seed(seed)

            with torch.no_grad():
                if callable(self.pipeline):
                    # Standard diffusion pipeline
                    images = self.pipeline(
                        prompt=enhanced_prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=self.num_inference_steps,
                        guidance_scale=self.guidance_scale,
                        num_images_per_prompt=num_images,
                        generator=generator,
                    ).images
                else:
                    # OmniGen specific generation (if different API)
                    images = self._omnigen_generate(
                        enhanced_prompt, reference_images, width, height, num_images
                    )

            # Return single image or list
            if num_images == 1:
                return images[0]
            return images

        except Exception as e:
            print(f"Image generation error: {e}")
            # Honour the same single-vs-list return contract as the success path.
            if num_images <= 1:
                return self._generate_fallback_image(width, height)
            return [self._generate_fallback_image(width, height) for _ in range(num_images)]

        finally:
            # Release cached CUDA memory after both successes and failures (e.g. OOM).
            if self.device == "cuda":
                torch.cuda.empty_cache()

    def _enhance_prompt(self, base_prompt: str) -> str:
        """Append a fixed list of style keywords to ``base_prompt``, comma-separated."""
        enhancements = [
            "digital art",
            "creature design",
            "game character",
            "detailed",
            "vibrant colors",
            "fantasy creature",
            "high quality",
            "professional artwork",
        ]

        # Combine base prompt with enhancements
        enhanced = f"{base_prompt}, {', '.join(enhancements)}"
        return enhanced

    def _omnigen_generate(self, prompt: str, reference_images: Optional[List],
                          width: int, height: int, num_images: int) -> List[Image.Image]:
        """Placeholder for OmniGen-specific multimodal generation.

        Currently ignores ``reference_images`` and calls the loaded pipeline
        with only prompt, size and image count (no negative prompt, steps,
        guidance scale or seed).
        """
        # This would be implemented based on OmniGen's specific API
        # For now, fall back to standard generation
        return self.pipeline(
            prompt=prompt,
            width=width,
            height=height,
            num_images_per_prompt=num_images,
        ).images

    def _generate_fallback_image(self, width: int, height: int) -> Image.Image:
        """Draw a placeholder "monster" on black: a randomly coloured disc with two white eyes.

        ``None`` or non-positive dimensions fall back to ``default_width`` /
        ``default_height`` so this method can always produce an image.
        """
        width = width if width and width > 0 else self.default_width
        height = height if height and height > 0 else self.default_height

        # Create a simple procedural monster image
        img_array = np.zeros((height, width, 3), dtype=np.uint8)

        # Add some basic shapes and colors
        center_x, center_y = width // 2, height // 2
        radius = min(width, height) // 3

        # Create circular body
        y, x = np.ogrid[:height, :width]
        mask = (x - center_x) ** 2 + (y - center_y) ** 2 <= radius ** 2

        # Random monster color
        color = np.random.randint(50, 200, size=3)
        img_array[mask] = color

        # Add eyes
        eye_y = center_y - radius // 3
        eye_left_x = center_x - radius // 3
        eye_right_x = center_x + radius // 3
        eye_radius = radius // 8

        # Left eye
        eye_mask = (x - eye_left_x) ** 2 + (y - eye_y) ** 2 <= eye_radius ** 2
        img_array[eye_mask] = [255, 255, 255]

        # Right eye
        eye_mask = (x - eye_right_x) ** 2 + (y - eye_y) ** 2 <= eye_radius ** 2
        img_array[eye_mask] = [255, 255, 255]

        # Convert to PIL Image
        return Image.fromarray(img_array)

    def edit_image(self, image: Union[str, Image.Image], prompt: str,
                   mask: Optional[Union[str, Image.Image]] = None) -> Image.Image:
        """Edit an existing image (for future monster customization).

        Raises:
            NotImplementedError: always; editing is not implemented yet.
        """
        # This would implement image editing capabilities
        raise NotImplementedError("Image editing not yet implemented")

    def to(self, device: str):
        """Move the pipeline (if loaded) to ``device``.

        As in ``__init__``, a CUDA request falls back to ``"cpu"`` when CUDA
        is unavailable. On ``"cuda"`` with ``enable_cpu_offload`` set,
        sequential CPU offload is enabled instead of a direct move.
        """
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            device = "cpu"
        self.device = device
        if self.pipeline is not None:
            if device == "cuda" and self.enable_cpu_offload:
                self.pipeline.enable_sequential_cpu_offload()
            else:
                self.pipeline = self.pipeline.to(device)

    def __del__(self):
        """Best-effort release of the pipeline and cached CUDA memory."""
        try:
            if getattr(self, "pipeline", None) is not None:
                del self.pipeline
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        except Exception:
            # During interpreter shutdown module globals (torch, gc) may
            # already be torn down; exceptions here would only surface as
            # "Exception ignored in __del__" noise.
            pass