# digiPal / app.py
# Last commit 2153bff by BladeSzaSza: "added logs" (file size 15.8 kB)
import gradio as gr
import spaces
import os
import json
import torch
import gc
from datetime import datetime
from pathlib import Path
# Suppress TorchDynamo errors globally (works around the ConstantVariable
# errors this app hit). Note: this does NOT disable Dynamo. With
# suppress_errors=True, code that fails to compile falls back to eager
# execution instead of raising.
torch._dynamo.config.suppress_errors = True
# Storage root: use the mounted /data volume when it exists, else ./data
# relative to the current working directory.
DATA_DIR = Path("./data") if not os.path.exists("/data") else Path("/data")
DATA_DIR.mkdir(exist_ok=True)
# One sub-directory per data category, created in a fixed order.
for _sub in ("users", "monsters", "models", "cache"):
    DATA_DIR.joinpath(_sub).mkdir(exist_ok=True)
del _sub
# Ensure Gradio's temp/cache directory exists and that Gradio uses it.
# A GRADIO_TEMP_DIR already present in the environment is honoured, so the
# directory created here is the one Gradio will actually write to. Otherwise
# "<system temp dir>/gradio" is used, which is /tmp/gradio on a typical Linux
# host (the previously hard-coded location).
import tempfile
gradio_cache_dir = Path(
    os.environ.get("GRADIO_TEMP_DIR") or Path(tempfile.gettempdir()) / "gradio"
)
gradio_cache_dir.mkdir(parents=True, exist_ok=True)
# Export the resolved path so Gradio reads exactly the directory created above.
os.environ["GRADIO_TEMP_DIR"] = str(gradio_cache_dir)
# Import modules (to be created)
from core.ai_pipeline import MonsterGenerationPipeline
from core.game_mechanics import GameMechanics
from core.state_manager import StateManager
from core.auth_manager import AuthManager
from ui.themes import get_cyberpunk_theme, CYBERPUNK_CSS
from ui.interfaces import create_voice_interface, create_visual_interface
# GPU-backed construction of the generation pipeline.
@spaces.GPU(duration=300)
def initialize_systems():
    """Construct and return a default ``MonsterGenerationPipeline``.

    The ``spaces.GPU(duration=300)`` decorator requests a GPU for this call
    (up to 300 seconds) so that pipeline setup can use it.
    """
    return MonsterGenerationPipeline()
# The generation pipeline is created lazily on first use rather than at import
# time, so GPU initialization is deferred until a request needs it.
pipeline = None


def get_pipeline():
    """Return the shared generation pipeline, building it on first use.

    Construction is first attempted via ``initialize_systems`` (the
    GPU-decorated path). If that raises for any reason, the error is printed
    and a CPU pipeline (``MonsterGenerationPipeline(device="cpu")``) is built
    instead. The result is cached in the module-level ``pipeline``.
    """
    global pipeline
    if pipeline is not None:
        return pipeline
    try:
        pipeline = initialize_systems()
    except Exception as err:
        print(f"GPU initialization failed, falling back to CPU: {err}")
        pipeline = MonsterGenerationPipeline(device="cpu")
    return pipeline
# Module-wide service objects shared by the request handlers below; they are
# constructed in this order: game rules, state storage, authentication.
game_mechanics, state_manager, auth_manager = (
    GameMechanics(),
    StateManager(DATA_DIR),  # rooted at the storage directory chosen at startup
    AuthManager(),
)
# Main generation function
@spaces.GPU(duration=180)
def generate_monster(oauth_profile: gr.OAuthProfile | None, audio_input=None, text_input=None, reference_images=None,
                     training_focus="balanced", care_level="normal"):
    """Generate a new monster with the AI pipeline and save it for the user.

    ``oauth_profile`` is not part of the click handler's ``inputs`` list.
    Gradio injects it because of the ``gr.OAuthProfile`` type hint, and it is
    ``None`` when the visitor is not logged in. The remaining arguments come
    from the Create Monster tab.

    Returns:
        A 5-tuple ``(message, image, model_3d, stats, dialogue)``, in the same
        order as the ``outputs`` wired to ``generate_btn``. Every path returns
        this shape because Gradio maps a tuple positionally onto the outputs.
    """
    if oauth_profile is None:
        # Must be a tuple like the other paths: a dict return is interpreted by
        # Gradio as {output_component: value}, and string keys are rejected.
        return ("๐Ÿ”’ Please log in to create monsters!", None, None, None, None)

    user_id = oauth_profile.username
    try:
        # Generate monster using AI pipeline
        current_pipeline = get_pipeline()
        result = current_pipeline.generate_monster(
            audio_input=audio_input,
            text_input=text_input,
            reference_images=reference_images,
            user_id=user_id
        )
        # Create the game monster from the AI result and the chosen settings.
        monster = game_mechanics.create_monster(result, {
            "training_focus": training_focus,
            "care_level": care_level
        }, user_id)
        # Save to persistent storage
        state_manager.save_monster(user_id, monster)
        return (
            f"โœจ {monster.name} has been created!",
            result.get('image'),
            result.get('model_3d'),
            monster.get_stats_display(),
            result.get('dialogue', "๐Ÿค–๐Ÿ’š1๏ธโƒฃ0๏ธโƒฃ0๏ธโƒฃ"),
        )
    except Exception as e:
        print(f"Error generating monster: {str(e)}")
        # Quick fallback generation. NOTE(review): if get_pipeline() itself
        # failed above, this call retries initialization; an exception raised
        # here is not caught and propagates to Gradio.
        current_pipeline = get_pipeline()
        fallback_result = current_pipeline.fallback_generation(text_input or "friendly digital creature")
        return (
            "โšก Created using quick generation mode",
            fallback_result.get('image'),
            None,  # the fallback path never produces a 3D model
            fallback_result.get('stats'),
            "๐Ÿค–โ“9๏ธโƒฃ9๏ธโƒฃ",
        )
# Training handler
def train_monster(oauth_profile: gr.OAuthProfile | None, training_type, intensity):
    """Run one training session on the logged-in user's active monster.

    Returns a 3-tuple ``(message, stats, evolution_check)``. The last two
    items are ``None`` whenever no training was applied (not logged in, no
    active monster, or the game mechanics reported failure).
    """
    if oauth_profile is None:
        return "๐Ÿ”’ Please log in to train monsters!", None, None

    username = oauth_profile.username
    active = state_manager.get_current_monster(username)
    if not active:
        return "โŒ No active monster to train!", None, None

    outcome = game_mechanics.train_monster(active, training_type, intensity)
    if not outcome['success']:
        return outcome['message'], None, None

    # Only successful sessions are persisted.
    state_manager.update_monster(username, active)
    return outcome['message'], active.get_stats_display(), outcome.get('evolution_check')
# Care handlers
def feed_monster(oauth_profile: gr.OAuthProfile | None, food_type):
    """Feed the logged-in user's active monster; return the status message."""
    if oauth_profile is None:
        return "๐Ÿ”’ Please log in to care for monsters!"

    username = oauth_profile.username
    active = state_manager.get_current_monster(username)
    if not active:
        return "โŒ No active monster to feed!"

    feeding = game_mechanics.feed_monster(active, food_type)
    # The monster is saved after every feeding; the outcome is not inspected.
    state_manager.update_monster(username, active)
    return feeding['message']
def _render_user_status(profile: gr.OAuthProfile | None) -> str:
    """Return Markdown for the ``user_display`` banner beside the login button.

    Registered with ``demo.load`` at the end of the interface definition.
    Gradio injects ``profile`` from the OAuth session because of the
    ``gr.OAuthProfile`` type hint; it is ``None`` when nobody is logged in.
    """
    if profile is None:
        return "Not connected - log in to create and care for monsters."
    return f"Connected as **{profile.username}**"


# Build the Gradio interface. The component variables created here are wired
# to the handler functions in the "Event handlers" section at the bottom.
with gr.Blocks(
    theme=get_cyberpunk_theme(),
    css=CYBERPUNK_CSS,
    title="DigiPal - Digital Monster Companion"
) as demo:
    # Header with cyberpunk styling
    gr.HTML("""
<div class="cyber-header">
<h1 class="glitch-text">๐Ÿค– DigiPal ๐Ÿค–</h1>
<p class="cyber-subtitle">Your AI-Powered Digital Monster Companion</p>
<div class="pulse-line"></div>
</div>
""")

    # Authentication
    with gr.Row():
        login_btn = gr.LoginButton("๐Ÿ” Connect to Digital World", size="lg")
        # Filled on page load by _render_user_status (see demo.load below);
        # without that hook this banner would always stay empty.
        user_display = gr.Markdown("", elem_classes=["user-status"])

    # Main interface tabs
    with gr.Tabs(elem_classes=["cyber-tabs"]):
        # Monster Creation Tab
        with gr.TabItem("๐Ÿงฌ Create Monster", elem_classes=["cyber-tab-content"]):
            with gr.Row():
                # Input Column
                with gr.Column(scale=1):
                    gr.Markdown("### ๐ŸŽ™๏ธ Voice Input")
                    audio_input = gr.Audio(
                        label="Describe your monster",
                        sources=["microphone", "upload"],
                        type="filepath",
                        elem_classes=["cyber-input"]
                    )
                    gr.Markdown("### ๐Ÿ’ฌ Text Input")
                    text_input = gr.Textbox(
                        label="Or type a description",
                        placeholder="Describe your ideal digital monster...",
                        lines=3,
                        elem_classes=["cyber-input"]
                    )
                    gr.Markdown("### ๐Ÿ–ผ๏ธ Reference Images")
                    reference_images = gr.File(
                        label="Upload reference images (optional)",
                        file_count="multiple",
                        file_types=["image"],
                        elem_classes=["cyber-input"]
                    )
                    with gr.Row():
                        training_focus = gr.Radio(
                            choices=["balanced", "strength", "defense", "speed", "intelligence"],
                            label="Training Focus",
                            value="balanced",
                            elem_classes=["cyber-radio"]
                        )
                    generate_btn = gr.Button(
                        "โšก Generate Monster",
                        variant="primary",
                        size="lg",
                        elem_classes=["cyber-button", "generate-button"]
                    )
                # Output Column
                with gr.Column(scale=1):
                    generation_message = gr.Markdown("", elem_classes=["cyber-message"])
                    monster_image = gr.Image(
                        label="Monster Appearance",
                        type="pil",
                        elem_classes=["monster-display"]
                    )
                    monster_model = gr.Model3D(
                        label="3D Model",
                        height=400,
                        elem_classes=["monster-display"]
                    )
                    monster_dialogue = gr.Textbox(
                        label="Monster Says",
                        interactive=False,
                        elem_classes=["cyber-dialogue"]
                    )
                    monster_stats = gr.JSON(
                        label="Stats",
                        elem_classes=["cyber-stats"]
                    )

        # Monster Status Tab
        with gr.TabItem("๐Ÿ“Š Monster Status", elem_classes=["cyber-tab-content"]):
            with gr.Row():
                with gr.Column():
                    current_monster_display = gr.Model3D(
                        label="Your Digital Monster",
                        height=400,
                        elem_classes=["monster-display"]
                    )
                    monster_communication = gr.Textbox(
                        label="Monster Communication",
                        placeholder="Your monster speaks in emojis and numbers...",
                        interactive=False,
                        elem_classes=["cyber-dialogue"]
                    )
                with gr.Column():
                    stats_display = gr.JSON(
                        label="Current Stats",
                        elem_classes=["cyber-stats"]
                    )
                    care_metrics = gr.JSON(
                        label="Care Status",
                        elem_classes=["cyber-stats"]
                    )
                    evolution_progress = gr.HTML(
                        elem_classes=["evolution-display"]
                    )
                    # NOTE(review): no event handler is registered for this
                    # button in this file, so clicking it does nothing yet.
                    refresh_btn = gr.Button(
                        "๐Ÿ”„ Refresh Status",
                        elem_classes=["cyber-button"]
                    )

        # Training Tab
        with gr.TabItem("๐Ÿ’ช Training", elem_classes=["cyber-tab-content"]):
            with gr.Row():
                with gr.Column():
                    training_type = gr.Radio(
                        choices=["Strength", "Defense", "Speed", "Intelligence", "Special"],
                        label="Training Type",
                        value="Strength",
                        elem_classes=["cyber-radio"]
                    )
                    training_intensity = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Training Intensity",
                        elem_classes=["cyber-slider"]
                    )
                    train_btn = gr.Button(
                        "๐Ÿ‹๏ธ Start Training",
                        variant="primary",
                        elem_classes=["cyber-button"]
                    )
                with gr.Column():
                    training_result = gr.Textbox(
                        label="Training Result",
                        interactive=False,
                        elem_classes=["cyber-output"]
                    )
                    updated_stats = gr.JSON(
                        label="Updated Stats",
                        elem_classes=["cyber-stats"]
                    )
                    evolution_check = gr.HTML(
                        elem_classes=["evolution-display"]
                    )

        # Care Tab
        with gr.TabItem("โค๏ธ Care", elem_classes=["cyber-tab-content"]):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### ๐Ÿ– Feeding")
                    food_type = gr.Radio(
                        choices=["Meat", "Fish", "Vegetable", "Treat", "Medicine"],
                        label="Select Food",
                        value="Meat",
                        elem_classes=["cyber-radio"]
                    )
                    feed_btn = gr.Button(
                        "๐Ÿฝ๏ธ Feed Monster",
                        elem_classes=["cyber-button"]
                    )
                    feeding_result = gr.Textbox(
                        label="Feeding Result",
                        interactive=False,
                        elem_classes=["cyber-output"]
                    )
                with gr.Column():
                    gr.Markdown("### ๐ŸŽฎ Interaction")
                    # NOTE(review): play/praise/scold have no event handlers
                    # registered in this file, so they do nothing yet.
                    play_btn = gr.Button(
                        "๐ŸŽพ Play",
                        elem_classes=["cyber-button"]
                    )
                    praise_btn = gr.Button(
                        "๐Ÿ‘ Praise",
                        elem_classes=["cyber-button"]
                    )
                    scold_btn = gr.Button(
                        "๐Ÿ‘Ž Scold",
                        elem_classes=["cyber-button"]
                    )
                    interaction_result = gr.Textbox(
                        label="Monster Response",
                        interactive=False,
                        elem_classes=["cyber-output"]
                    )

    # Event handlers. Handlers whose first parameter is annotated
    # gr.OAuthProfile receive it from Gradio; it is not listed in `inputs`.
    generate_btn.click(
        fn=generate_monster,
        inputs=[
            audio_input,
            text_input,
            reference_images,
            training_focus,
            gr.State("normal")  # care_level: fixed, no UI control for it
        ],
        outputs=[
            generation_message,
            monster_image,
            monster_model,
            monster_stats,
            monster_dialogue
        ]
    )
    train_btn.click(
        fn=train_monster,
        inputs=[
            training_type,
            training_intensity
        ],
        outputs=[
            training_result,
            updated_stats,
            evolution_check
        ]
    )
    feed_btn.click(
        fn=feed_monster,
        inputs=[
            food_type
        ],
        outputs=[feeding_result]
    )
    # Show who is logged in whenever the page (re)loads, e.g. after login.
    demo.load(fn=_render_user_status, inputs=None, outputs=user_display)
# Script entry point (e.g. `python app.py`).
if __name__ == "__main__":
    import warnings

    # Ignore UserWarnings raised from the gradio.mcp module, if any appear.
    warnings.filterwarnings("ignore", category=UserWarning, module="gradio.mcp")

    queued_demo = demo.queue(max_size=100, default_concurrency_limit=10)
    queued_demo.launch(
        show_error=True,
        show_api=False,
        server_port=7860,
        server_name="0.0.0.0",
    )