# digiPal/models/image_generator.py
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
from typing import Optional, List, Union
import gc

# Disable torch dynamo to avoid ConstantVariable errors
torch._dynamo.config.suppress_errors = True


class OmniGenImageGenerator:
    """Image generation for digiPal monsters.

    Named for the originally planned OmniGen2 backend; currently backed by a
    Stable Diffusion pipeline.
    """

    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.model_id = "runwayml/stable-diffusion-v1-5"  # Known-working Stable Diffusion checkpoint

        # Generation parameters
        self.default_width = 512
        self.default_height = 512
        self.num_inference_steps = 30
        self.guidance_scale = 7.5

        # Memory optimization
        self.enable_attention_slicing = True
        self.enable_vae_slicing = True
        self.enable_cpu_offload = self.device == "cuda"

    def load_model(self):
        """Lazy load the image generation model on first use"""
        if self.pipeline is None:
            try:
                # Half precision on GPU, full precision on CPU
                torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

                # Load pipeline with optimizations
                self.pipeline = StableDiffusionPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=torch_dtype,
                    use_safetensors=True,
                    variant="fp16" if self.device == "cuda" else None
                )

                # Apply optimizations and device placement
                if self.device == "cuda":
                    if self.enable_cpu_offload:
                        self.pipeline.enable_sequential_cpu_offload()
                    else:
                        # Safely move pipeline to CUDA
                        try:
                            self.pipeline = self.pipeline.to(self.device)
                        except RuntimeError as e:
                            if "meta tensor" in str(e):
                                # Meta tensors cannot be moved; fall back to CPU
                                print(f"Meta tensor issue detected, using CPU fallback: {e}")
                                self.device = "cpu"
                                self.pipeline = self.pipeline.to("cpu")
                            else:
                                raise

                    if self.enable_attention_slicing and hasattr(self.pipeline, 'enable_attention_slicing'):
                        self.pipeline.enable_attention_slicing(1)
                    if self.enable_vae_slicing and hasattr(self.pipeline, 'enable_vae_slicing'):
                        self.pipeline.enable_vae_slicing()
                else:
                    self.pipeline = self.pipeline.to(self.device)

                # Disable torch.compile to avoid dynamo issues that cause ConstantVariable errors
                print("Skipping torch.compile to avoid dynamo compatibility issues")
            except Exception as e:
                print(f"Failed to load image generation model: {e}")
                # Retry with the fallback Stable Diffusion loader
                try:
                    self.model_id = "runwayml/stable-diffusion-v1-5"
                    self._load_fallback_model()
                except Exception:
                    raise

    def _load_fallback_model(self):
        """Load fallback Stable Diffusion model"""
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            trust_remote_code=True
        )

        if self.device == "cuda" and self.enable_cpu_offload:
            self.pipeline.enable_sequential_cpu_offload()
        else:
            self.pipeline = self.pipeline.to(self.device)

    def _truncate_prompt(self, prompt: str, max_tokens: int = 75) -> str:
        """Truncate prompt to fit the CLIP token limit (word count used as a rough token proxy)"""
        words = prompt.split()
        if len(words) <= max_tokens:
            return prompt
        truncated = ' '.join(words[:max_tokens])
        print(f"Warning: Prompt truncated from {len(words)} to {max_tokens} words")
        return truncated

    def generate(self,
                 prompt: str,
                 reference_images: Optional[List[Union[str, Image.Image]]] = None,
                 negative_prompt: Optional[str] = None,
                 width: Optional[int] = None,
                 height: Optional[int] = None,
                 num_images: int = 1,
                 seed: Optional[int] = None) -> Union[Image.Image, List[Image.Image]]:
        """Generate monster image from prompt"""
        try:
            # Load model if needed
            self.load_model()

            # Truncate prompts to avoid CLIP token limit issues
            prompt = self._truncate_prompt(prompt)
            if negative_prompt:
                negative_prompt = self._truncate_prompt(negative_prompt)

            # Set dimensions
            width = width or self.default_width
            height = height or self.default_height

            # Ensure dimensions are multiples of 8
            width = (width // 8) * 8
            height = (height // 8) * 8

            # Enhance prompt for monster generation
            enhanced_prompt = self._enhance_prompt(prompt)

            # Default negative prompt for quality
            if negative_prompt is None:
                negative_prompt = (
                    "low quality, blurry, distorted, disfigured, "
                    "bad anatomy, wrong proportions, ugly, duplicate, "
                    "morbid, mutilated, extra limbs, malformed"
                )

            # Set seed for reproducibility
            generator = None
            if seed is not None:
                generator = torch.Generator(device=self.device).manual_seed(seed)

            # Generate images
            with torch.no_grad():
                if hasattr(self.pipeline, '__call__'):
                    # Standard diffusion pipeline
                    images = self.pipeline(
                        prompt=enhanced_prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=self.num_inference_steps,
                        guidance_scale=self.guidance_scale,
                        num_images_per_prompt=num_images,
                        generator=generator
                    ).images
                else:
                    # OmniGen specific generation (if different API)
                    images = self._omnigen_generate(
                        enhanced_prompt,
                        reference_images,
                        width,
                        height,
                        num_images
                    )

            # Clean up memory
            if self.device == "cuda":
                torch.cuda.empty_cache()

            # Return single image or list
            if num_images == 1:
                return images[0]
            return images
        except Exception as e:
            print(f"Image generation error: {e}")
            # Return fallback image; width/height may still be None if the
            # failure happened before they were resolved above
            return self._generate_fallback_image(
                width or self.default_width,
                height or self.default_height
            )

    def _enhance_prompt(self, base_prompt: str) -> str:
        """Enhance prompt for better monster generation"""
        enhancements = [
            "digital art",
            "creature design",
            "game character",
            "detailed",
            "vibrant colors",
            "fantasy creature",
            "high quality",
            "professional artwork"
        ]

        # Combine base prompt with enhancements
        enhanced = f"{base_prompt}, {', '.join(enhancements)}"
        return enhanced

    def _omnigen_generate(self, prompt: str, reference_images: Optional[List],
                          width: int, height: int, num_images: int) -> List[Image.Image]:
        """OmniGen specific generation with multimodal inputs"""
        # This would be implemented based on OmniGen's specific API.
        # For now, fall back to standard generation (reference_images are ignored).
        return self.pipeline(
            prompt=prompt,
            width=width,
            height=height,
            num_images_per_prompt=num_images
        ).images

    def _generate_fallback_image(self, width: int, height: int) -> Image.Image:
        """Generate a simple procedural monster image as a fallback"""
        img_array = np.zeros((height, width, 3), dtype=np.uint8)

        # Circular body centered in the frame
        center_x, center_y = width // 2, height // 2
        radius = min(width, height) // 3
        y, x = np.ogrid[:height, :width]
        mask = (x - center_x)**2 + (y - center_y)**2 <= radius**2

        # Random monster color
        color = np.random.randint(50, 200, size=3)
        img_array[mask] = color

        # Add eyes
        eye_y = center_y - radius // 3
        eye_left_x = center_x - radius // 3
        eye_right_x = center_x + radius // 3
        eye_radius = radius // 8

        # Left eye
        eye_mask = (x - eye_left_x)**2 + (y - eye_y)**2 <= eye_radius**2
        img_array[eye_mask] = [255, 255, 255]

        # Right eye
        eye_mask = (x - eye_right_x)**2 + (y - eye_y)**2 <= eye_radius**2
        img_array[eye_mask] = [255, 255, 255]

        # Convert to PIL Image
        return Image.fromarray(img_array)

    def edit_image(self,
                   image: Union[str, Image.Image],
                   prompt: str,
                   mask: Optional[Union[str, Image.Image]] = None) -> Image.Image:
        """Edit existing image (for future monster customization)"""
        # This would implement image editing capabilities
        raise NotImplementedError("Image editing not yet implemented")

    def to(self, device: str):
        """Move pipeline to specified device"""
        self.device = device
        if self.pipeline:
            if device == "cuda" and self.enable_cpu_offload:
                self.pipeline.enable_sequential_cpu_offload()
            else:
                self.pipeline = self.pipeline.to(device)

    def __del__(self):
        """Cleanup when object is destroyed"""
        try:
            if hasattr(self, 'pipeline') and self.pipeline is not None:
                del self.pipeline
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        except Exception:
            # Modules may already be torn down during interpreter shutdown
            pass
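

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this class is meant to be driven: the pipeline is
# loaded lazily on the first generate() call, a seed gives reproducible output,
# and the result is a single PIL image (or a list when num_images > 1).
# The prompt text and output filename below are hypothetical.
if __name__ == "__main__":
    generator = OmniGenImageGenerator(device="cuda")

    image = generator.generate(
        prompt="friendly dragon hatchling with teal scales",
        num_images=1,
        seed=42,
    )

    # generate() returns a single PIL.Image.Image when num_images == 1
    image.save("monster_image.png")
    print(f"Saved {image.size[0]}x{image.size[1]} image")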