# HuggingFace Spaces: Running on Zero (ZeroGPU hardware)
import gc
import os
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any

import numpy as np
import spaces
import torch
from PIL import Image

# Model imports (to be implemented)
from models.stt_processor import KyutaiSTTProcessor
from models.text_generator import QwenTextGenerator
from models.image_generator import OmniGenImageGenerator
from models.model_3d_generator import Hunyuan3DGenerator
from models.rigging_processor import UniRigProcessor
from utils.fallbacks import FallbackManager
from utils.caching import ModelCache
class MonsterGenerationPipeline:
    """Main AI pipeline for monster generation.

    Runs up to five stages — speech-to-text, trait/dialogue generation,
    image generation, image-to-3D conversion, and optional rigging — while
    lazily loading each model and unloading it as soon as its stage is done,
    so that only one heavy model occupies GPU memory at a time (low-VRAM
    strategy). Every stage has an independent fallback, so a failure in one
    stage degrades the result instead of aborting the whole run.
    """

    def __init__(self, device: str = "cuda"):
        """Set up registries and configuration; no models are loaded yet.

        Args:
            device: Preferred compute device. Falls back to "cpu" when CUDA
                is not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.cache = ModelCache()
        self.fallback_manager = FallbackManager()
        # Lazily populated registry of stage name -> processor instance.
        self.models: Dict[str, Any] = {}
        # Tracks which stage models are currently resident in memory.
        self.model_loaded = {
            'stt': False,
            'text_gen': False,
            'image_gen': False,
            '3d_gen': False,
            'rigging': False
        }
        # Pipeline configuration
        self.config = {
            'max_retries': 3,
            'timeout': 180,          # seconds
            'enable_caching': True,
            'low_vram_mode': True    # unload each model after its stage
        }

    def _cleanup_memory(self):
        """Release cached GPU memory and trigger garbage collection."""
        if self.device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

    def _lazy_load_model(self, model_type: str):
        """Load a stage model on first use, freeing memory beforehand.

        Args:
            model_type: One of 'stt', 'text_gen', 'image_gen', '3d_gen',
                'rigging'.

        Returns:
            The processor instance, or None when loading fails or the
            model type is unknown (previously an unknown type raised
            KeyError; None is consistent with the other failure paths).
        """
        if self.model_loaded.get(model_type):
            return self.models[model_type]
        # Clear memory before loading a new model (only one heavy model
        # should be resident at a time).
        self._cleanup_memory()
        constructors = {
            'stt': KyutaiSTTProcessor,
            'text_gen': QwenTextGenerator,
            'image_gen': OmniGenImageGenerator,
            '3d_gen': Hunyuan3DGenerator,
            'rigging': UniRigProcessor,
        }
        constructor = constructors.get(model_type)
        if constructor is None:
            print(f"Failed to load {model_type}: unknown model type")
            return None
        try:
            self.models[model_type] = constructor(device=self.device)
            self.model_loaded[model_type] = True
            return self.models[model_type]
        except Exception as e:
            print(f"Failed to load {model_type}: {e}")
            return None

    def _unload_model(self, model_type: str):
        """Unload a model and release the memory it held.

        Safe to call for models that were never loaded.
        """
        if model_type in self.models and self.model_loaded[model_type]:
            # Move weights to CPU first so CUDA memory is actually released.
            if hasattr(self.models[model_type], 'to'):
                self.models[model_type].to('cpu')
            del self.models[model_type]
            self.model_loaded[model_type] = False
            self._cleanup_memory()

    def generate_monster(self,
                        audio_input: Optional[str] = None,
                        text_input: Optional[str] = None,
                        reference_images: Optional[List] = None,
                        user_id: Optional[str] = None) -> Dict[str, Any]:
        """Main monster generation pipeline.

        Args:
            audio_input: Path to an audio file describing the monster;
                transcribed in stage 1 when it exists.
            text_input: Text description, used when audio is absent or
                speech-to-text fails.
            reference_images: Optional reference images forwarded to the
                image generator.
            user_id: Identifier embedded in saved file names and the log.

        Returns:
            Dict with keys 'description', 'traits', 'dialogue', 'image',
            'model_3d', 'download_files', 'generation_log', 'status'; on a
            total pipeline failure, the fallback manager's result instead.
        """
        generation_log = {
            'user_id': user_id,
            'timestamp': datetime.now().isoformat(),
            'stages_completed': [],
            'fallbacks_used': [],
            'success': False,
            'errors': []
        }
        # Defined before the try so the outer except handler can always
        # reference it (previously a very early failure caused NameError).
        description = ""
        try:
            print("🚀 Starting monster generation pipeline...")
            # Stage 1: Speech to Text (if audio provided)
            if audio_input and os.path.exists(audio_input):
                try:
                    print("🎤 Processing audio input...")
                    stt_model = self._lazy_load_model('stt')
                    if stt_model:
                        description = stt_model.transcribe(audio_input)
                        generation_log['stages_completed'].append('stt')
                        print(f"✅ STT completed: {description[:100]}...")
                    else:
                        raise Exception("STT model failed to load")
                except Exception as e:
                    print(f"❌ STT failed: {e}")
                    description = text_input or "Create a friendly digital monster"
                    generation_log['fallbacks_used'].append('stt')
                    generation_log['errors'].append(f"STT error: {str(e)}")
                finally:
                    # Unload STT to free memory
                    self._unload_model('stt')
            else:
                description = text_input or "Create a friendly digital monster"
                print(f"📝 Using text input: {description}")
            # Stage 2: Generate monster characteristics
            monster_traits = {}
            monster_dialogue = ""
            try:
                print("🧠 Generating monster traits and dialogue...")
                text_gen = self._lazy_load_model('text_gen')
                if text_gen:
                    monster_traits = text_gen.generate_traits(description)
                    monster_dialogue = text_gen.generate_dialogue(monster_traits)
                    generation_log['stages_completed'].append('text_gen')
                    print(f"✅ Text generation completed: {monster_traits.get('name', 'Unknown')}")
                else:
                    raise Exception("Text generation model failed to load")
            except Exception as e:
                print(f"❌ Text generation failed: {e}")
                monster_traits, monster_dialogue = self.fallback_manager.handle_text_gen_failure(description)
                generation_log['fallbacks_used'].append('text_gen')
                generation_log['errors'].append(f"Text generation error: {str(e)}")
            finally:
                self._unload_model('text_gen')
            # Stage 3: Generate monster image
            monster_image = None
            try:
                print("🎨 Generating monster image...")
                image_gen = self._lazy_load_model('image_gen')
                if image_gen:
                    # Create enhanced prompt from traits
                    image_prompt = self._create_image_prompt(description, monster_traits)
                    monster_image = image_gen.generate(
                        prompt=image_prompt,
                        reference_images=reference_images,
                        width=512,
                        height=512
                    )
                    generation_log['stages_completed'].append('image_gen')
                    print("✅ Image generation completed")
                else:
                    raise Exception("Image generation model failed to load")
            except Exception as e:
                print(f"❌ Image generation failed: {e}")
                monster_image = self.fallback_manager.handle_image_gen_failure(description)
                generation_log['fallbacks_used'].append('image_gen')
                generation_log['errors'].append(f"Image generation error: {str(e)}")
            finally:
                self._unload_model('image_gen')
            # Stage 4: Convert to 3D model
            model_3d = None
            model_3d_path = None
            try:
                print("🔲 Converting to 3D model...")
                model_3d_gen = self._lazy_load_model('3d_gen')
                if model_3d_gen and monster_image:
                    model_3d = model_3d_gen.image_to_3d(monster_image)
                    # Save 3D model
                    model_3d_path = self._save_3d_model(model_3d, user_id)
                    generation_log['stages_completed'].append('3d_gen')
                    print("✅ 3D generation completed")
                else:
                    raise Exception("3D generation failed - no model or image")
            except Exception as e:
                print(f"❌ 3D generation failed: {e}")
                model_3d = self.fallback_manager.handle_3d_gen_failure(monster_image)
                generation_log['fallbacks_used'].append('3d_gen')
                generation_log['errors'].append(f"3D generation error: {str(e)}")
            finally:
                self._unload_model('3d_gen')
            # Stage 5: Add rigging (optional, can be skipped for performance)
            rigged_model = model_3d
            if model_3d and self.config.get('enable_rigging', False):
                try:
                    print("🦴 Adding rigging...")
                    rigging_proc = self._lazy_load_model('rigging')
                    if rigging_proc:
                        rigged_model = rigging_proc.rig_mesh(model_3d)
                        generation_log['stages_completed'].append('rigging')
                        print("✅ Rigging completed")
                except Exception as e:
                    # Rigging is best-effort: keep the unrigged mesh on failure.
                    print(f"❌ Rigging failed: {e}")
                    generation_log['fallbacks_used'].append('rigging')
                    generation_log['errors'].append(f"Rigging error: {str(e)}")
                finally:
                    self._unload_model('rigging')
            # Prepare download files
            download_files = self._prepare_download_files(
                rigged_model or model_3d,
                monster_image,
                user_id
            )
            generation_log['success'] = True
            print("🎉 Monster generation pipeline completed successfully!")
            return {
                'description': description,
                'traits': monster_traits,
                'dialogue': monster_dialogue,
                'image': monster_image,
                'model_3d': model_3d_path,
                'download_files': download_files,
                'generation_log': generation_log,
                'status': 'success'
            }
        except Exception as e:
            print(f"💥 Pipeline error: {e}")
            generation_log['error'] = str(e)
            generation_log['errors'].append(f"Pipeline error: {str(e)}")
            return self.fallback_generation(description or "digital monster", generation_log)

    def _create_image_prompt(self, base_description: str, traits: Dict) -> str:
        """Create an enhanced prompt for image generation.

        Combines the base description with any appearance/personality/color
        traits, then appends fixed quality keywords. Returns a single
        comma-separated prompt string.
        """
        prompt_parts = [base_description]
        if traits:
            if 'appearance' in traits:
                prompt_parts.append(traits['appearance'])
            if 'personality' in traits:
                prompt_parts.append(f"with {traits['personality']} personality")
            if 'color_scheme' in traits:
                prompt_parts.append(f"featuring {traits['color_scheme']} colors")
        prompt_parts.extend([
            "digital monster",
            "creature design",
            "game character",
            "high quality",
            "detailed"
        ])
        return ", ".join(prompt_parts)

    def _save_3d_model(self, model_3d, user_id: Optional[str]) -> Optional[str]:
        """Save a 3D model to persistent storage.

        Args:
            model_3d: Mesh object; if it has an ``export`` method that is
                used, otherwise its string form is written as a placeholder.
            user_id: Optional owner id embedded in the file name.

        Returns:
            The saved .glb file path, or None when model_3d is falsy.
        """
        if not model_3d:
            return None
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        user_id_str = user_id or "anonymous"
        filename = f"monster_{user_id_str}_{timestamp}.glb"
        # Use HuggingFace Spaces persistent storage when /data is mounted.
        # BUG FIX: the path previously interpolated a literal placeholder
        # instead of ``filename``, so every save overwrote the same file.
        if os.path.exists("/data"):
            filepath = f"/data/models/{filename}"
        else:
            filepath = f"./data/models/{filename}"
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        # Placeholder serialization - actual implementation depends on the
        # mesh library's export format.
        with open(filepath, 'wb') as f:
            if hasattr(model_3d, 'export'):
                model_3d.export(f)
            else:
                # Fallback: save as binary data
                f.write(str(model_3d).encode())
        return filepath

    def _prepare_download_files(self, model_3d, image, user_id: Optional[str]) -> List[str]:
        """Prepare downloadable files for the user.

        Saves the preview image to /tmp and returns a list of file paths.

        NOTE(review): the .glb/.obj paths are appended without the files
        being written — placeholder until model export is implemented here.
        """
        files = []
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        user_id_str = user_id or "anonymous"
        # Save image. BUG FIX: ``if image:`` raised ValueError for a
        # multi-element ndarray (ambiguous truth value); test identity.
        if image is not None:
            if isinstance(image, Image.Image):
                image_path = f"/tmp/monster_{user_id_str}_{timestamp}.png"
                image.save(image_path)
                files.append(image_path)
            elif isinstance(image, np.ndarray):
                image_path = f"/tmp/monster_{user_id_str}_{timestamp}.png"
                Image.fromarray(image).save(image_path)
                files.append(image_path)
        # Save 3D model in multiple formats if available
        if model_3d:
            # GLB format
            glb_path = f"/tmp/monster_{user_id_str}_{timestamp}.glb"
            files.append(glb_path)
            # OBJ format (optional)
            obj_path = f"/tmp/monster_{user_id_str}_{timestamp}.obj"
            files.append(obj_path)
        return files

    def fallback_generation(self, description: str, generation_log: Dict) -> Dict[str, Any]:
        """Complete fallback generation when the whole pipeline fails."""
        return self.fallback_manager.complete_fallback_generation(description, generation_log)

    def cleanup(self):
        """Unload every loaded model and clear remaining GPU memory."""
        for model_type in list(self.models.keys()):
            self._unload_model(model_type)
        self._cleanup_memory()