# digiPal / core/ai_pipeline.py
# Last change by BladeSzaSza — commit e8293cd: "added more logs"
import spaces
import torch
import gc
import os
from typing import Optional, List, Dict, Any
from datetime import datetime
from pathlib import Path
import numpy as np
from PIL import Image
import tempfile
# Model imports (to be implemented)
from models.stt_processor import KyutaiSTTProcessor
from models.text_generator import QwenTextGenerator
from models.image_generator import OmniGenImageGenerator
from models.model_3d_generator import Hunyuan3DGenerator
from models.rigging_processor import UniRigProcessor
from utils.fallbacks import FallbackManager
from utils.caching import ModelCache
class MonsterGenerationPipeline:
    """Main AI pipeline for monster generation.

    ``generate_monster`` runs up to five stages in order: speech-to-text,
    trait/dialogue text generation, image generation, image-to-3D conversion
    and (optionally) rigging. Each stage lazily loads its model right before
    use and unloads it right after, and every failed stage is replaced by the
    matching ``FallbackManager`` handler and recorded in the generation log.
    """

    def __init__(self, device: str = "cuda"):
        """Create the pipeline; no model is loaded until a stage needs it.

        Args:
            device: Requested torch device. Forced to ``"cpu"`` when CUDA is
                not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        # NOTE(review): ``self.cache`` is not read by any method of this class.
        self.cache = ModelCache()
        self.fallback_manager = FallbackManager()
        # model key -> currently loaded wrapper instance
        self.models = {}
        # model key -> whether that wrapper is currently loaded
        self.model_loaded = {
            'stt': False,
            'text_gen': False,
            'image_gen': False,
            '3d_gen': False,
            'rigging': False
        }
        # Pipeline configuration. Only 'enable_rigging' is read inside this
        # class; the remaining keys are not used by any method here.
        self.config = {
            'max_retries': 3,
            'timeout': 180,
            'enable_caching': True,
            'low_vram_mode': True,
            # Stage 5 (rigging) is skipped unless this is set to True.
            'enable_rigging': False,
        }

    def _cleanup_memory(self):
        """Run the garbage collector and release cached CUDA memory.

        ``gc.collect()`` runs first so objects kept alive only by reference
        cycles are freed before ``empty_cache`` hands unused blocks back.
        """
        gc.collect()
        # startswith() also covers explicit devices such as "cuda:0".
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            # Let queued GPU work finish before releasing cached blocks.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

    def _lazy_load_model(self, model_type: str):
        """Return the wrapper for ``model_type``, loading it on first use.

        Args:
            model_type: One of 'stt', 'text_gen', 'image_gen', '3d_gen',
                'rigging'.

        Returns:
            The model wrapper, or ``None`` if the type is unknown or its
            constructor raised (the error is printed, not re-raised).
        """
        if self.model_loaded.get(model_type):
            return self.models[model_type]

        # Resolved at call time so the module-level names are looked up lazily.
        factories = {
            'stt': KyutaiSTTProcessor,
            'text_gen': QwenTextGenerator,
            'image_gen': OmniGenImageGenerator,
            '3d_gen': Hunyuan3DGenerator,
            'rigging': UniRigProcessor,
        }
        factory = factories.get(model_type)
        if factory is None:
            print(f"Failed to load {model_type}: unknown model type")
            return None

        # Clear memory before loading new model
        self._cleanup_memory()
        try:
            model = factory(device=self.device)
        except Exception as e:
            print(f"Failed to load {model_type}: {e}")
            return None

        self.models[model_type] = model
        self.model_loaded[model_type] = True
        return model

    def _unload_model(self, model_type: str):
        """Unload ``model_type`` (if loaded) and free memory.

        Never raises: it is called from ``finally`` blocks, where an error
        would otherwise abort the whole pipeline.
        """
        if not self.model_loaded.get(model_type) or model_type not in self.models:
            return
        model = self.models.pop(model_type)
        self.model_loaded[model_type] = False
        if hasattr(model, 'to'):
            try:
                model.to('cpu')
            except Exception as e:
                print(f"⚠️ Could not move {model_type} to CPU: {e}")
        # Drop the last local reference before collecting.
        del model
        self._cleanup_memory()

    @spaces.GPU(duration=300)
    def generate_monster(self,
                         audio_input: Optional[str] = None,
                         text_input: Optional[str] = None,
                         reference_images: Optional[List] = None,
                         user_id: Optional[str] = None) -> Dict[str, Any]:
        """Main monster generation pipeline.

        Args:
            audio_input: Path to an audio file; transcribed when it exists.
            text_input: Description used when no usable audio is given (and
                as the fallback when transcription fails).
            reference_images: Passed unchanged to the image generator.
            user_id: Used (sanitized) in output file names; defaults to
                "anonymous".

        Returns:
            On success a dict with keys 'description', 'traits', 'dialogue',
            'image', 'model_3d' (saved model path or None), 'download_files',
            'generation_log' and 'status' ('success'). If an error escapes
            the per-stage fallbacks, whatever
            ``fallback_manager.complete_fallback_generation`` returns.
        """
        generation_log = {
            'user_id': user_id,
            'timestamp': datetime.now().isoformat(),
            'stages_completed': [],
            'fallbacks_used': [],
            'success': False,
            'errors': []
        }
        # Bound before the try so the outer handler can always read it.
        description = ""
        try:
            print("🚀 Starting monster generation pipeline...")

            # Stage 1: Speech to Text (if audio provided)
            if audio_input and os.path.exists(audio_input):
                description = self._run_stt(audio_input, text_input, generation_log)
            else:
                description = text_input or "Create a friendly digital monster"
                print(f"📝 Using text input: {description}")

            # Stage 2: Generate monster characteristics
            monster_traits, monster_dialogue = self._run_text_gen(description, generation_log)

            # Stage 3: Generate monster image
            monster_image = self._run_image_gen(
                description, monster_traits, reference_images, generation_log
            )

            # Stage 4: Convert to 3D model
            model_3d, model_3d_path = self._run_3d_gen(monster_image, user_id, generation_log)

            # Stage 5: Add rigging (optional, can be skipped for performance)
            rigged_model = model_3d
            if model_3d is not None and self.config.get('enable_rigging', False):
                rigged_model = self._run_rigging(model_3d, generation_log)

            # Prepare download files
            download_files = self._prepare_download_files(
                rigged_model if rigged_model is not None else model_3d,
                monster_image,
                user_id
            )

            generation_log['success'] = True
            print("🎉 Monster generation pipeline completed successfully!")
            return {
                'description': description,
                'traits': monster_traits,
                'dialogue': monster_dialogue,
                'image': monster_image,
                'model_3d': model_3d_path,
                'download_files': download_files,
                'generation_log': generation_log,
                'status': 'success'
            }
        except Exception as e:
            print(f"💥 Pipeline error: {e}")
            generation_log['error'] = str(e)
            generation_log['errors'].append(f"Pipeline error: {str(e)}")
            return self.fallback_generation(description or "digital monster", generation_log)

    def _run_stt(self, audio_input: str, text_input: Optional[str],
                 generation_log: Dict[str, Any]) -> str:
        """Stage 1: transcribe ``audio_input``; fall back to ``text_input``.

        An empty transcription counts as a failure.
        """
        stt_model = None
        try:
            print("🎤 Processing audio input...")
            stt_model = self._lazy_load_model('stt')
            if stt_model is None:
                raise Exception("STT model failed to load")
            description = stt_model.transcribe(audio_input)
            if not description or not str(description).strip():
                raise Exception("STT returned an empty transcription")
            generation_log['stages_completed'].append('stt')
            print(f"✅ STT completed: {description[:100]}...")
            return description
        except Exception as e:
            print(f"❌ STT failed: {e}")
            generation_log['fallbacks_used'].append('stt')
            generation_log['errors'].append(f"STT error: {str(e)}")
            return text_input or "Create a friendly digital monster"
        finally:
            # Release our reference first so unloading can actually free it.
            stt_model = None
            self._unload_model('stt')

    def _run_text_gen(self, description: str, generation_log: Dict[str, Any]) -> tuple:
        """Stage 2: return ``(traits, dialogue)``, using the fallback on error."""
        text_gen = None
        try:
            print("🧠 Generating monster traits and dialogue...")
            text_gen = self._lazy_load_model('text_gen')
            if text_gen is None:
                raise Exception("Text generation model failed to load")
            traits = text_gen.generate_traits(description)
            dialogue = text_gen.generate_dialogue(traits)
            generation_log['stages_completed'].append('text_gen')
            print(f"✅ Text generation completed: {traits.get('name', 'Unknown')}")
            return traits, dialogue
        except Exception as e:
            print(f"❌ Text generation failed: {e}")
            traits, dialogue = self.fallback_manager.handle_text_gen_failure(description)
            generation_log['fallbacks_used'].append('text_gen')
            generation_log['errors'].append(f"Text generation error: {str(e)}")
            return traits, dialogue
        finally:
            text_gen = None
            self._unload_model('text_gen')

    def _run_image_gen(self, description: str, traits: Dict,
                       reference_images: Optional[List],
                       generation_log: Dict[str, Any]):
        """Stage 3: generate a 512x512 image, using the fallback on error."""
        image_gen = None
        try:
            print("🎨 Generating monster image...")
            image_gen = self._lazy_load_model('image_gen')
            if image_gen is None:
                raise Exception("Image generation model failed to load")
            # Create enhanced prompt from traits
            image_prompt = self._create_image_prompt(description, traits)
            image = image_gen.generate(
                prompt=image_prompt,
                reference_images=reference_images,
                width=512,
                height=512
            )
            generation_log['stages_completed'].append('image_gen')
            print("✅ Image generation completed")
            return image
        except Exception as e:
            print(f"❌ Image generation failed: {e}")
            image = self.fallback_manager.handle_image_gen_failure(description)
            generation_log['fallbacks_used'].append('image_gen')
            generation_log['errors'].append(f"Image generation error: {str(e)}")
            return image
        finally:
            image_gen = None
            self._unload_model('image_gen')

    def _run_3d_gen(self, monster_image, user_id: Optional[str],
                    generation_log: Dict[str, Any]) -> tuple:
        """Stage 4: return ``(model_3d, saved_path)``; path is None on fallback."""
        model_3d_gen = None
        try:
            print("🔲 Converting to 3D model...")
            model_3d_gen = self._lazy_load_model('3d_gen')
            # Explicit None checks: truthiness of a numpy image raises.
            if model_3d_gen is None or monster_image is None:
                raise Exception("3D generation failed - no model or image")
            model_3d = model_3d_gen.image_to_3d(monster_image)
            # Save 3D model
            model_3d_path = self._save_3d_model(model_3d, user_id)
            generation_log['stages_completed'].append('3d_gen')
            print("✅ 3D generation completed")
            return model_3d, model_3d_path
        except Exception as e:
            print(f"❌ 3D generation failed: {e}")
            model_3d = self.fallback_manager.handle_3d_gen_failure(monster_image)
            generation_log['fallbacks_used'].append('3d_gen')
            generation_log['errors'].append(f"3D generation error: {str(e)}")
            # NOTE(review): the fallback model is not persisted, so the
            # result's 'model_3d' path stays None in this case.
            return model_3d, None
        finally:
            model_3d_gen = None
            self._unload_model('3d_gen')

    def _run_rigging(self, model_3d, generation_log: Dict[str, Any]):
        """Stage 5: return the rigged model, or ``model_3d`` unchanged on error."""
        rigging_proc = None
        try:
            print("🦴 Adding rigging...")
            rigging_proc = self._lazy_load_model('rigging')
            if rigging_proc is None:
                raise Exception("Rigging model failed to load")
            rigged = rigging_proc.rig_mesh(model_3d)
            generation_log['stages_completed'].append('rigging')
            print("✅ Rigging completed")
            return rigged
        except Exception as e:
            print(f"❌ Rigging failed: {e}")
            generation_log['fallbacks_used'].append('rigging')
            generation_log['errors'].append(f"Rigging error: {str(e)}")
            return model_3d
        finally:
            rigging_proc = None
            self._unload_model('rigging')

    def _create_image_prompt(self, base_description: str, traits: Dict) -> str:
        """Build the comma-separated image-generation prompt.

        Uses the optional trait keys 'appearance', 'personality' and
        'color_scheme', then appends a fixed list of style keywords. An empty
        description is skipped so the prompt never starts with ", ".
        """
        prompt_parts = [str(base_description)] if base_description else []
        if traits:
            if 'appearance' in traits:
                prompt_parts.append(str(traits['appearance']))
            if 'personality' in traits:
                prompt_parts.append(f"with {traits['personality']} personality")
            if 'color_scheme' in traits:
                prompt_parts.append(f"featuring {traits['color_scheme']} colors")
        prompt_parts.extend([
            "digital monster",
            "creature design",
            "game character",
            "high quality",
            "detailed"
        ])
        return ", ".join(prompt_parts)

    @staticmethod
    def _file_stem(user_id: Optional[str]) -> str:
        """Return ``monster_<user>_<timestamp>`` safe for use as a file name.

        Characters other than alphanumerics, '-' and '_' in ``user_id`` are
        replaced with '_' so the id cannot inject path separators or "..".
        Microseconds are included so concurrent runs do not overwrite each
        other's files.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        raw_id = str(user_id or "anonymous")
        safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in raw_id)
        return f"monster_{safe_id}_{timestamp}"

    def _save_3d_model(self, model_3d, user_id: Optional[str]) -> Optional[str]:
        """Save the 3D model as a ``.glb`` file in persistent storage.

        Writes under ``/data/models`` when ``/data`` exists (HuggingFace
        Spaces persistent storage), else ``./data/models``.

        Returns:
            The written path, or ``None`` when ``model_3d`` is ``None``.

        Raises:
            Whatever ``model_3d.export`` or file I/O raises; any partially
            written file is removed first.
        """
        if model_3d is None:
            return None
        base_dir = "/data/models" if os.path.exists("/data") else "./data/models"
        os.makedirs(base_dir, exist_ok=True)
        filepath = os.path.join(base_dir, f"{self._file_stem(user_id)}.glb")
        try:
            with open(filepath, 'wb') as f:
                if hasattr(model_3d, 'export'):
                    model_3d.export(f)
                else:
                    # Placeholder fallback: serialize the object's string form.
                    f.write(str(model_3d).encode())
        except Exception:
            # Don't leave a truncated or empty model file behind.
            if os.path.exists(filepath):
                os.remove(filepath)
            raise
        return filepath

    def _prepare_download_files(self, model_3d, image, user_id: Optional[str]) -> List[str]:
        """Write downloadable files to the temp directory and return their paths.

        A PNG is written for PIL or numpy images. GLB and OBJ files are
        written when ``model_3d`` has an ``export`` method. Only files that
        were actually written are returned; write failures are printed and
        skipped so a finished generation is not discarded.
        """
        files: List[str] = []
        out_dir = tempfile.gettempdir()
        stem = self._file_stem(user_id)

        # Save image (explicit None check: ndarray truthiness raises)
        if image is not None:
            image_path = os.path.join(out_dir, f"{stem}.png")
            try:
                if isinstance(image, Image.Image):
                    image.save(image_path)
                    files.append(image_path)
                elif isinstance(image, np.ndarray):
                    Image.fromarray(image).save(image_path)
                    files.append(image_path)
            except Exception as e:
                print(f"⚠️ Could not save image file: {e}")

        # Save 3D model in multiple formats if available
        if model_3d is not None and hasattr(model_3d, 'export'):
            for ext in ("glb", "obj"):
                mesh_path = os.path.join(out_dir, f"{stem}.{ext}")
                try:
                    # Format is expected to follow the path's extension.
                    model_3d.export(mesh_path)
                except Exception as e:
                    print(f"⚠️ Could not export {ext.upper()} file: {e}")
                    continue
                if os.path.exists(mesh_path):
                    files.append(mesh_path)
        return files

    def fallback_generation(self, description: str, generation_log: Dict) -> Dict[str, Any]:
        """Complete fallback generation when pipeline fails."""
        return self.fallback_manager.complete_fallback_generation(description, generation_log)

    def cleanup(self):
        """Unload every loaded model and free memory."""
        for model_type in list(self.models):
            self._unload_model(model_type)
        self._cleanup_memory()