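"""Supernatural Conversational AI Agent.

A speech-to-speech Gradio app: Whisper ASR -> text emotion detection ->
templated emotional response -> Dia TTS, with per-session conversation
memory (up to 50 exchanges) and support for concurrent users.
"""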
import time
import warnings
from collections import deque

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import pipeline

from dia.model import Dia

# Suppress warnings
warnings.filterwarnings("ignore")

# Global variables for model caching
dia_model = None
asr_model = None
emotion_classifier = None

MAX_HISTORY = 50
MAX_CONCURRENT_USERS = 20
class ConversationManager:
    def __init__(self):
        self.histories = {}
        self.max_history = MAX_HISTORY

    def get_history(self, session_id):
        if session_id not in self.histories:
            self.histories[session_id] = deque(maxlen=self.max_history)
        return list(self.histories[session_id])

    def add_exchange(self, session_id, user_input, ai_response,
                     user_emotion=None, ai_emotion=None):
        if session_id not in self.histories:
            self.histories[session_id] = deque(maxlen=self.max_history)
        exchange = {
            "user": user_input,
            "ai": ai_response,
            "user_emotion": user_emotion,
            "ai_emotion": ai_emotion,
            "timestamp": time.time()
        }
        self.histories[session_id].append(exchange)

    def clear_history(self, session_id):
        if session_id in self.histories:
            del self.histories[session_id]

conversation_manager = ConversationManager()
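# Illustrative usage (session-keyed deques keep concurrent users isolated,
# and maxlen=50 silently evicts the oldest exchange once the cap is hit):
#   conversation_manager.add_exchange("user_001", "hi", "hello there")
#   conversation_manager.get_history("user_001")  # -> [{"user": "hi", ...}]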
def load_models():
    """Load all models once and cache globally"""
    global dia_model, asr_model, emotion_classifier

    # Half precision is only safe on GPU; fall back to float32 on CPU
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if dia_model is None:
        print("Loading Dia TTS model...")
        try:
            # Dia expects compute_dtype (torch_dtype is not a Dia parameter)
            dia_model = Dia.from_pretrained(
                "nari-labs/Dia-1.6B",
                compute_dtype="float16"
            )
            print("✅ Dia model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading Dia model: {e}")
            raise

    if asr_model is None:
        print("Loading ASR model...")
        try:
            # Whisper for ASR
            asr_model = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-small",
                torch_dtype=dtype,
                device=device
            )
            print("✅ ASR model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading ASR model: {e}")
            raise

    if emotion_classifier is None:
        print("Loading emotion classifier...")
        try:
            emotion_classifier = pipeline(
                "text-classification",
                model="j-hartmann/emotion-english-distilroberta-base",
                torch_dtype=dtype,
                device=device
            )
            print("✅ Emotion classifier loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading emotion classifier: {e}")
            raise
def detect_emotion(text):
    """Detect emotion from text"""
    try:
        if emotion_classifier is None:
            return "neutral"
        result = emotion_classifier(text)
        return result[0]['label'].lower() if result else "neutral"
    except Exception as e:
        print(f"Error in emotion detection: {e}")
        return "neutral"
def transcribe_audio(audio_data):
    """Transcribe audio to text with emotion detection"""
    try:
        if audio_data is None:
            return "", "neutral"

        # Handle different audio input formats
        if isinstance(audio_data, tuple):
            sample_rate, audio = audio_data
        else:
            audio = audio_data
            sample_rate = 16000

        # Gradio delivers integer PCM; scale to float32 in [-1, 1] for Whisper
        if np.issubdtype(audio.dtype, np.integer):
            audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
        else:
            audio = audio.astype(np.float32)

        # Downmix stereo to mono
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)

        # Transcribe
        result = asr_model(audio)
        text = result["text"].strip()

        # Detect emotion from transcribed text
        emotion = detect_emotion(text)
        return text, emotion
    except Exception as e:
        print(f"Error in transcription: {e}")
        return "", "neutral"
def generate_emotional_response(user_text, user_emotion, conversation_history, session_id):
    """Generate contextually aware emotional response"""
    try:
        # Build context from the last 5 exchanges (assembled for future use;
        # the template logic below does not consume it yet)
        context = ""
        if conversation_history:
            for exchange in conversation_history[-5:]:
                context += f"User: {exchange['user']}\nAI: {exchange['ai']}\n"

        # Emotional adaptation logic
        emotion_responses = {
            "joy": ["excited", "happy", "cheerful"],
            "sadness": ["empathetic", "gentle", "comforting"],
            "anger": ["calm", "understanding", "patient"],
            "fear": ["reassuring", "supportive", "confident"],
            "surprise": ["curious", "engaged", "interested"],
            "disgust": ["neutral", "diplomatic", "respectful"],
            "neutral": ["friendly", "conversational", "natural"]
        }
        ai_emotion = np.random.choice(emotion_responses.get(user_emotion, ["friendly"]))

        # Pick a response template based on topic and emotion
        if "supernatural" in user_text.lower() or "magic" in user_text.lower():
            response_templates = [
                "The mystical energies around us are quite fascinating, aren't they?",
                "I sense something extraordinary in your words...",
                "The supernatural realm holds many mysteries we're yet to understand.",
                "There's an otherworldly quality to our conversation that intrigues me."
            ]
        elif user_emotion == "sadness":
            response_templates = [
                "I understand how you're feeling, and I'm here to listen.",
                "Your emotions are valid, and it's okay to feel this way.",
                "Sometimes sharing our feelings can help lighten the burden."
            ]
        elif user_emotion == "joy":
            response_templates = [
                "Your happiness is contagious! I love your positive energy!",
                "It's wonderful to hear such joy in your voice!",
                "Your enthusiasm brightens up our conversation!"
            ]
        else:
            response_templates = [
                f"That's an interesting perspective on {user_text.split()[-1] if user_text.split() else 'that'}.",
                "I find our conversation quite engaging and thought-provoking.",
                "Your thoughts resonate with me in unexpected ways."
            ]
        response = np.random.choice(response_templates)

        # Add emotional cues for TTS
        emotion_cues = {
            "excited": "(excited)",
            "happy": "(laughs)",
            "gentle": "(sighs)",
            "empathetic": "(softly)",
            "reassuring": "(warmly)",
            "curious": "(intrigued)"
        }
        if ai_emotion in emotion_cues:
            response += f" {emotion_cues[ai_emotion]}"

        return response, ai_emotion
    except Exception as e:
        print(f"Error generating response: {e}")
        return "I'm here to listen and understand you better.", "neutral"
def generate_speech(text, emotion="neutral", speaker="S1"):
    """Generate speech with emotional conditioning"""
    try:
        if dia_model is None:
            load_models()

        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Format text for Dia model with speaker tags
        formatted_text = f"[{speaker}] {text}"

        # Set seed for consistency
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)

        print(f"Generating speech: {formatted_text[:100]}...")

        # Generate audio
        with torch.no_grad():
            audio_output = dia_model.generate(
                formatted_text,
                use_torch_compile=False,  # Disabled for stability
                verbose=False
            )

        # Convert to numpy if needed
        if isinstance(audio_output, torch.Tensor):
            audio_output = audio_output.cpu().numpy()

        # Normalize peaks above full scale
        if len(audio_output) > 0:
            max_val = np.max(np.abs(audio_output))
            if max_val > 1.0:
                audio_output = audio_output / max_val * 0.95

        return (44100, audio_output)
    except Exception as e:
        print(f"Error in speech generation: {e}")
        return None
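# generate_speech returns the (sample_rate, ndarray) tuple that
# gr.Audio(type="numpy") expects; 44100 matches Dia's 44.1 kHz output rate.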
def process_conversation(audio_input, session_id, history):
    """Main conversation processing pipeline"""
    start_time = time.time()
    try:
        # Step 1: Transcribe audio (Target: <100ms)
        transcription_start = time.time()
        user_text, user_emotion = transcribe_audio(audio_input)
        transcription_time = (time.time() - transcription_start) * 1000

        if not user_text:
            return None, "❌ Could not transcribe audio", history, "Transcription failed"

        # Step 2: Get conversation history
        conversation_history = conversation_manager.get_history(session_id)

        # Step 3: Generate response (Target: <200ms)
        response_start = time.time()
        ai_response, ai_emotion = generate_emotional_response(
            user_text, user_emotion, conversation_history, session_id
        )
        response_time = (time.time() - response_start) * 1000

        # Step 4: Generate speech (Target: <200ms)
        tts_start = time.time()
        audio_output = generate_speech(ai_response, ai_emotion, "S2")
        tts_time = (time.time() - tts_start) * 1000

        # Step 5: Update conversation history
        conversation_manager.add_exchange(
            session_id, user_text, ai_response, user_emotion, ai_emotion
        )
        history_len = len(conversation_manager.get_history(session_id))

        # Update gradio history
        history.append([user_text, ai_response])

        total_time = (time.time() - start_time) * 1000
        status = f"""✅ Processing Complete!
📝 Transcription: {transcription_time:.0f}ms
🧠 Response Generation: {response_time:.0f}ms
🎵 Speech Synthesis: {tts_time:.0f}ms
⏱️ Total Latency: {total_time:.0f}ms
😊 User Emotion: {user_emotion}
🤖 AI Emotion: {ai_emotion}
💬 History: {history_len}/{MAX_HISTORY} exchanges"""

        return audio_output, status, history, f"User: {user_text}"
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return None, error_msg, history, "Processing failed"
# Initialize models on startup
load_models()

# Create Gradio interface
with gr.Blocks(title="Supernatural AI Agent", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style="text-align: center; padding: 20px; background: linear-gradient(45deg, #1a1a2e, #16213e); color: white; border-radius: 15px; margin-bottom: 20px;">
        <h1>🔮 Supernatural Conversational AI Agent</h1>
        <p style="font-size: 18px;">Human-like emotional intelligence with &lt;500ms latency • Speech-to-Speech AI</p>
        <p style="font-size: 14px; opacity: 0.8;">Powered by Dia TTS • Emotional Recognition • 50 Exchange Memory</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Session management
            session_id = gr.Textbox(
                label="🔑 Session ID",
                value="user_001",
                info="Unique ID for conversation history"
            )

            # Audio input
            audio_input = gr.Audio(
                label="🎤 Speak to the AI",
                type="numpy",
                format="wav"
            )

            # Process button
            process_btn = gr.Button(
                "🗣️ Process Conversation",
                variant="primary",
                size="lg"
            )

            # Clear history button
            clear_btn = gr.Button(
                "🗑️ Clear History",
                variant="secondary"
            )

        with gr.Column(scale=2):
            # Chat history
            chatbot = gr.Chatbot(
                label="💬 Conversation History",
                height=400,
                show_copy_button=True
            )

            # Audio output
            audio_output = gr.Audio(
                label="🔊 AI Response",
                type="numpy",
                autoplay=True
            )

            # Status display
            status_display = gr.Textbox(
                label="📊 Processing Status",
                lines=8,
                interactive=False
            )

            # Last input display
            last_input = gr.Textbox(
                label="📝 Last Transcription",
                interactive=False
            )
    # Event handlers
    process_btn.click(
        fn=process_conversation,
        inputs=[audio_input, session_id, chatbot],
        outputs=[audio_output, status_display, chatbot, last_input],
        concurrency_limit=MAX_CONCURRENT_USERS
    )

    def clear_conversation_history(session_id_val):
        conversation_manager.clear_history(session_id_val)
        return [], "✅ Conversation history cleared!"

    clear_btn.click(
        fn=clear_conversation_history,
        inputs=[session_id],
        outputs=[chatbot, status_display]
    )
    # Usage instructions
    gr.HTML("""
    <div style="margin-top: 20px; padding: 15px; background: #f8f9fa; border-radius: 10px;">
        <h3>🎯 Usage Instructions:</h3>
        <ul>
            <li><strong>Record Audio:</strong> Click the microphone and speak naturally</li>
            <li><strong>Emotional AI:</strong> The AI detects and responds to your emotions</li>
            <li><strong>Memory:</strong> Maintains up to 50 conversation exchanges</li>
            <li><strong>Latency:</strong> Optimized for &lt;500ms response time</li>
            <li><strong>Concurrent Users:</strong> Supports up to 20 simultaneous users</li>
        </ul>
        <h3>🔮 Supernatural Features:</h3>
        <p>Try mentioning supernatural, mystical, or magical topics for specialized responses!</p>
        <h3>⚡ Performance Metrics:</h3>
        <p><strong>Target Latency:</strong> &lt;500ms | <strong>Memory:</strong> 50 exchanges | <strong>Concurrent Users:</strong> 20</p>
    </div>
    """)
# Configure queue for optimal performance
demo.queue(
    default_concurrency_limit=MAX_CONCURRENT_USERS,
    max_size=100
)
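# default_concurrency_limit caps how many requests execute simultaneously,
# while max_size bounds the waiting queue; once it is full, additional
# requests are rejected rather than queued indefinitely.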
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )