Devakumar868 committed on
Commit 5adc99b · verified · 1 Parent(s): baf7f5d

Update app.py

Files changed (1)
  1. app.py +43 -431
app.py CHANGED
@@ -1,434 +1,46 @@
  import gradio as gr
- import torch
- import numpy as np
- import soundfile as sf
- import librosa
- import warnings
- from transformers import pipeline, AutoProcessor, AutoModel
  from dia.model import Dia
- import asyncio
- import time
- from collections import deque
- import json
-
- # Suppress warnings
- warnings.filterwarnings("ignore")
-
- # Global variables for model caching
- dia_model = None
- asr_model = None
- emotion_classifier = None
- conversation_histories = {}
- MAX_HISTORY = 50
- MAX_CONCURRENT_USERS = 20
-
- class ConversationManager:
-     def __init__(self):
-         self.histories = {}
-         self.max_history = MAX_HISTORY
-
-     def get_history(self, session_id):
-         if session_id not in self.histories:
-             self.histories[session_id] = deque(maxlen=self.max_history)
-         return list(self.histories[session_id])
-
-     def add_exchange(self, session_id, user_input, ai_response, user_emotion=None, ai_emotion=None):
-         if session_id not in self.histories:
-             self.histories[session_id] = deque(maxlen=self.max_history)
-
-         exchange = {
-             "user": user_input,
-             "ai": ai_response,
-             "user_emotion": user_emotion,
-             "ai_emotion": ai_emotion,
-             "timestamp": time.time()
-         }
-         self.histories[session_id].append(exchange)
-
-     def clear_history(self, session_id):
-         if session_id in self.histories:
-             del self.histories[session_id]
-
- conversation_manager = ConversationManager()
-
- def load_models():
-     """Load all models once and cache globally"""
-     global dia_model, asr_model, emotion_classifier
-
-     if dia_model is None:
-         print("Loading Dia TTS model...")
-         try:
-             # FIXED: Remove torch_dtype parameter - only use compute_dtype
-             dia_model = Dia.from_pretrained(
-                 "nari-labs/Dia-1.6B",
-                 compute_dtype="float16"
-             )
-             print("✅ Dia model loaded successfully!")
-         except Exception as e:
-             print(f"❌ Error loading Dia model: {e}")
-             raise
-
-     if asr_model is None:
-         print("Loading ASR model...")
-         try:
-             # Using Whisper for ASR with optimizations
-             asr_model = pipeline(
-                 "automatic-speech-recognition",
-                 model="openai/whisper-small",
-                 torch_dtype=torch.float16,
-                 device="cuda" if torch.cuda.is_available() else "cpu"
-             )
-             print("✅ ASR model loaded successfully!")
-         except Exception as e:
-             print(f"❌ Error loading ASR model: {e}")
-             raise
-
-     if emotion_classifier is None:
-         print("Loading emotion classifier...")
-         try:
-             emotion_classifier = pipeline(
-                 "text-classification",
-                 model="j-hartmann/emotion-english-distilroberta-base",
-                 torch_dtype=torch.float16,
-                 device="cuda" if torch.cuda.is_available() else "cpu"
-             )
-             print("✅ Emotion classifier loaded successfully!")
-         except Exception as e:
-             print(f"❌ Error loading emotion classifier: {e}")
-             raise
-
- def detect_emotion(text):
-     """Detect emotion from text"""
-     try:
-         if emotion_classifier is None:
-             return "neutral"
-
-         result = emotion_classifier(text)
-         return result[0]['label'].lower() if result else "neutral"
-     except Exception as e:
-         print(f"Error in emotion detection: {e}")
-         return "neutral"
-
- def transcribe_audio(audio_data):
-     """Transcribe audio to text with emotion detection"""
-     try:
-         if audio_data is None:
-             return "", "neutral"
-
-         # Handle different audio input formats
-         if isinstance(audio_data, tuple):
-             sample_rate, audio = audio_data
-             audio = audio.astype(np.float32)
-         else:
-             audio = audio_data
-             sample_rate = 16000
-
-         # Ensure audio is in the right format for Whisper
-         if len(audio.shape) > 1:
-             audio = audio.mean(axis=1)
-
-         # Resample to 16kHz if needed
-         if sample_rate != 16000:
-             audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
-
-         # Transcribe
-         result = asr_model(audio)
-         text = result["text"].strip()
-
-         # Detect emotion from transcribed text
-         emotion = detect_emotion(text)
-
-         return text, emotion
-
-     except Exception as e:
-         print(f"Error in transcription: {e}")
-         return "", "neutral"
-
- def generate_emotional_response(user_text, user_emotion, conversation_history, session_id):
-     """Generate contextually aware emotional response"""
-     try:
-         # Build context from conversation history
-         context = ""
-         if conversation_history:
-             recent_exchanges = conversation_history[-5:]  # Last 5 exchanges for context
-             for exchange in recent_exchanges:
-                 context += f"User: {exchange['user']}\nAI: {exchange['ai']}\n"
-
-         # Emotional adaptation logic
-         emotion_responses = {
-             "joy": ["excited", "happy", "cheerful"],
-             "sadness": ["empathetic", "gentle", "comforting"],
-             "anger": ["calm", "understanding", "patient"],
-             "fear": ["reassuring", "supportive", "confident"],
-             "surprise": ["curious", "engaged", "interested"],
-             "disgust": ["neutral", "diplomatic", "respectful"],
-             "neutral": ["friendly", "conversational", "natural"]
-         }
-
-         ai_emotion = np.random.choice(emotion_responses.get(user_emotion, ["friendly"]))
-
-         # Generate response based on context and emotion
-         if "supernatural" in user_text.lower() or "magic" in user_text.lower():
-             response_templates = [
-                 "The mystical energies around us are quite fascinating, aren't they?",
-                 "I sense something extraordinary in your words...",
-                 "The supernatural realm holds many mysteries we're yet to understand.",
-                 "There's an otherworldly quality to our conversation that intrigues me."
-             ]
-         elif user_emotion == "sadness":
-             response_templates = [
-                 "I understand how you're feeling, and I'm here to listen.",
-                 "Your emotions are valid, and it's okay to feel this way.",
-                 "Sometimes sharing our feelings can help lighten the burden."
-             ]
-         elif user_emotion == "joy":
-             response_templates = [
-                 "Your happiness is contagious! I love your positive energy!",
-                 "It's wonderful to hear such joy in your voice!",
-                 "Your enthusiasm brightens up our conversation!"
-             ]
-         else:
-             response_templates = [
-                 f"That's an interesting perspective on {user_text.split()[-1] if user_text.split() else 'that'}.",
-                 "I find our conversation quite engaging and thought-provoking.",
-                 "Your thoughts resonate with me in unexpected ways."
-             ]
-
-         response = np.random.choice(response_templates)
-
-         # Add emotional cues for TTS
-         emotion_cues = {
-             "excited": "(excited)",
-             "happy": "(laughs)",
-             "gentle": "(sighs)",
-             "empathetic": "(softly)",
-             "reassuring": "(warmly)",
-             "curious": "(intrigued)"
-         }
-
-         if ai_emotion in emotion_cues:
-             response += f" {emotion_cues[ai_emotion]}"
-
-         return response, ai_emotion
-
-     except Exception as e:
-         print(f"Error generating response: {e}")
-         return "I'm here to listen and understand you better.", "neutral"
-
- def generate_speech(text, emotion="neutral", speaker="S1"):
-     """Generate speech with emotional conditioning"""
-     try:
-         if dia_model is None:
-             load_models()
-
-         # Clear GPU cache
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-
-         # Format text for Dia model with speaker tags
-         formatted_text = f"[{speaker}] {text}"
-
-         # Set seed for consistency
-         torch.manual_seed(42)
-         if torch.cuda.is_available():
-             torch.cuda.manual_seed(42)
-
-         print(f"Generating speech: {formatted_text[:100]}...")
-
-         # Generate audio with optimizations
-         with torch.no_grad():
-             audio_output = dia_model.generate(
-                 formatted_text,
-                 use_torch_compile=False,  # Disabled for stability
-                 verbose=False
-             )
-
-         # Convert to numpy if needed
-         if isinstance(audio_output, torch.Tensor):
-             audio_output = audio_output.cpu().numpy()
-
-         # Normalize audio
-         if len(audio_output) > 0:
-             max_val = np.max(np.abs(audio_output))
-             if max_val > 1.0:
-                 audio_output = audio_output / max_val * 0.95
-
-         return (44100, audio_output)
-
-     except Exception as e:
-         print(f"Error in speech generation: {e}")
-         return None
-
- def process_conversation(audio_input, session_id, history):
-     """Main conversation processing pipeline"""
-     start_time = time.time()
-
-     try:
-         # Step 1: Transcribe audio (Target: <100ms)
-         transcription_start = time.time()
-         user_text, user_emotion = transcribe_audio(audio_input)
-         transcription_time = (time.time() - transcription_start) * 1000
-
-         if not user_text:
-             return None, "❌ Could not transcribe audio", history, f"Transcription failed"
-
-         # Step 2: Get conversation history
-         conversation_history = conversation_manager.get_history(session_id)
-
-         # Step 3: Generate response (Target: <200ms)
-         response_start = time.time()
-         ai_response, ai_emotion = generate_emotional_response(
-             user_text, user_emotion, conversation_history, session_id
-         )
-         response_time = (time.time() - response_start) * 1000
-
-         # Step 4: Generate speech (Target: <200ms)
-         tts_start = time.time()
-         audio_output = generate_speech(ai_response, ai_emotion, "S2")
-         tts_time = (time.time() - tts_start) * 1000
-
-         # Step 5: Update conversation history
-         conversation_manager.add_exchange(
-             session_id, user_text, ai_response, user_emotion, ai_emotion
-         )
-
-         # Update gradio history
-         history.append([user_text, ai_response])
-
-         total_time = (time.time() - start_time) * 1000
-
-         status = f"""✅ Processing Complete!
- 📝 Transcription: {transcription_time:.0f}ms
- 🧠 Response Generation: {response_time:.0f}ms
- 🎵 Speech Synthesis: {tts_time:.0f}ms
- ⏱️ Total Latency: {total_time:.0f}ms
- 😊 User Emotion: {user_emotion}
- 🤖 AI Emotion: {ai_emotion}
- 💬 History: {len(conversation_history)}/50 exchanges"""
-
-         return audio_output, status, history, f"User: {user_text}"
-
-     except Exception as e:
-         error_msg = f"❌ Error: {str(e)}"
-         return None, error_msg, history, "Processing failed"
-
- # Initialize models on startup
- load_models()
-
- # Create Gradio interface
- with gr.Blocks(title="Supernatural AI Agent", theme=gr.themes.Soft()) as demo:
-     gr.HTML("""
-     <div style="text-align: center; padding: 20px; background: linear-gradient(45deg, #1a1a2e, #16213e); color: white; border-radius: 15px; margin-bottom: 20px;">
-         <h1>🔮 Supernatural Conversational AI Agent</h1>
-         <p style="font-size: 18px;">Human-like emotional intelligence with <500ms latency • Speech-to-Speech AI</p>
-         <p style="font-size: 14px; opacity: 0.8;">Powered by Dia TTS • Emotional Recognition • 50 Exchange Memory</p>
-     </div>
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             # Session management
-             session_id = gr.Textbox(
-                 label="🆔 Session ID",
-                 value="user_001",
-                 info="Unique ID for conversation history"
-             )
-
-             # Audio input
-             audio_input = gr.Audio(
-                 label="🎤 Speak to the AI",
-                 type="numpy",
-                 format="wav"
-             )
-
-             # Process button
-             process_btn = gr.Button(
-                 "🗣️ Process Conversation",
-                 variant="primary",
-                 size="lg"
-             )
-
-             # Clear history button
-             clear_btn = gr.Button(
-                 "🗑️ Clear History",
-                 variant="secondary"
-             )
-
-         with gr.Column(scale=2):
-             # Chat history
-             chatbot = gr.Chatbot(
-                 label="💬 Conversation History",
-                 height=400,
-                 show_copy_button=True
-             )
-
-             # Audio output
-             audio_output = gr.Audio(
-                 label="🔊 AI Response",
-                 type="numpy",
-                 autoplay=True
-             )
-
-             # Status display
-             status_display = gr.Textbox(
-                 label="📊 Processing Status",
-                 lines=8,
-                 interactive=False
-             )
-
-             # Last input display
-             last_input = gr.Textbox(
-                 label="📝 Last Transcription",
-                 interactive=False
-             )
-
-     # Event handlers
-     process_btn.click(
-         fn=process_conversation,
-         inputs=[audio_input, session_id, chatbot],
-         outputs=[audio_output, status_display, chatbot, last_input],
-         concurrency_limit=MAX_CONCURRENT_USERS
-     )
-
-     def clear_conversation_history(session_id_val):
-         conversation_manager.clear_history(session_id_val)
-         return [], "✅ Conversation history cleared!"
-
-     clear_btn.click(
-         fn=clear_conversation_history,
-         inputs=[session_id],
-         outputs=[chatbot, status_display]
-     )
-
-     # Usage instructions
-     gr.HTML("""
-     <div style="margin-top: 20px; padding: 15px; background: #f8f9fa; border-radius: 10px;">
-         <h3>🎯 Usage Instructions:</h3>
-         <ul>
-             <li><strong>Record Audio:</strong> Click the microphone and speak naturally</li>
-             <li><strong>Emotional AI:</strong> The AI detects and responds to your emotions</li>
-             <li><strong>Memory:</strong> Maintains up to 50 conversation exchanges</li>
-             <li><strong>Latency:</strong> Optimized for <500ms response time</li>
-             <li><strong>Concurrent Users:</strong> Supports up to 20 simultaneous users</li>
-         </ul>
-
-         <h3>🔮 Supernatural Features:</h3>
-         <p>Try mentioning supernatural, mystical, or magical topics for specialized responses!</p>
-
-         <h3>⚡ Performance Metrics:</h3>
-         <p><strong>Target Latency:</strong> <500ms | <strong>Memory:</strong> 50 exchanges | <strong>Concurrent Users:</strong> 20</p>
-     </div>
-     """)
-
- # Configure queue for optimal performance
- demo.queue(
-     default_concurrency_limit=MAX_CONCURRENT_USERS,
-     max_size=100
- )
-
- if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=False
    )
 
 
  import gradio as gr
+ from transformers import AutoProcessor, CsmForConditionalGeneration
  from dia.model import Dia
+ from pyannote.audio import Pipeline as VAD
+ import torch, numpy as np
+
+ # Load models
+ # NOTE: CsmForConditionalGeneration targets Sesame CSM checkpoints; loading the Ultravox ID
+ # through it is an assumption carried over from this commit (a more typical pipeline-based
+ # load is sketched after the diff).
+ ultra_proc = AutoProcessor.from_pretrained("fixie-ai/ultravox-v0_4")
+ ultra_model = CsmForConditionalGeneration.from_pretrained("fixie-ai/ultravox-v0_4", device_map="auto", torch_dtype=torch.float16)
+ ser = AutoProcessor.from_pretrained("r-f/wav2vec-english-speech-emotion-recognition")
+ # NOTE: torch.hub.load() resolves GitHub repos that ship a hubconf.py; these Hugging Face
+ # model IDs would normally need their own loaders (transformers, diffusers, descript-audio-codec).
+ ser_model = torch.hub.load("jonatasgrosman/wav2vec2-large-xlsr-53-english", "wav2vec2_large_xlsr", pretrained=True).to("cuda")
+ diff_pipe = torch.hub.load("teticio/audio-diffusion-instrumental-hiphop-256", "audio_diffusion").to("cuda")
+ rvq = torch.hub.load("ibm/DAC.speech.v1.0", "DAC_speech_v1_0").to("cuda")
+ vad = VAD.from_pretrained("pyannote/voice-activity-detection")
+ dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
+
+ def process(audio):
+     # gr.Audio(type="numpy") delivers a (sample_rate, waveform) tuple, not a dataset-style dict
+     sr, wav = audio
+     wav = wav.astype(np.float32)
+     # VAD (pyannote's in-memory input is a (channel, time) tensor plus the sample rate)
+     speech = vad({"waveform": torch.from_numpy(wav).unsqueeze(0), "sample_rate": sr})
+     # RVQ encode/decode
+     codes = rvq.encode(wav)
+     dec_audio = rvq.decode(codes)
+     # Emotion
+     emo_inputs = ser(wav, sampling_rate=sr, return_tensors="pt")
+     emotion = ser_model(**emo_inputs).logits.argmax(-1).item()
+     # Ultravox generation
+     inputs = ultra_proc(wav, sampling_rate=sr, return_tensors="pt").to("cuda")
+     speech_out = ultra_model.generate(**inputs, output_audio=True)
+     # Diffuse and clone voice
+     audio_diff = diff_pipe(speech_out.audio).audios[0]
+     # TTS
+     text = f"[S1][emotion={emotion}]" + " ".join(["..."])  # placeholder
+     dia_audio = dia.generate(text)
+     # Normalize (guard against an all-zero output)
+     peak = np.max(np.abs(dia_audio))
+     if peak > 0:
+         dia_audio = dia_audio / peak * 0.95
+     return 44100, dia_audio
+
+ with gr.Blocks() as demo:
+     state = gr.State([])
+     # Gradio 4.x (which the previous code targeted) uses sources=[...] instead of source=
+     audio_in = gr.Audio(sources=["microphone"], type="numpy")
+     chat = gr.Chatbot()
+     record = gr.Button("Record")
+     record.click(process, inputs=audio_in, outputs=[audio_in]).then(
+         # the original lambda took an argument and called chat.update() with no inputs or
+         # outputs wired; returning the placeholder value to the Chatbot output is the working form
+         lambda: [("User", ""), ("AI", "")], outputs=chat,
    )
+ # queue()'s per-event default in Gradio 4.x is default_concurrency_limit, not concurrency_limit
+ demo.queue(default_concurrency_limit=20, max_size=50).launch()
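For reference, a minimal, hedged sketch of how the two main checkpoints referenced above are more commonly loaded: it assumes the pipeline interface described on the Ultravox model card (trust_remote_code) and a standard audio-classification head on the wav2vec2 emotion checkpoint. The classify_emotion helper is illustrative only and is not part of the committed app.py.

# Hedged sketch, not part of this commit.
import numpy as np
import transformers

# Ultravox ships a custom pipeline via trust_remote_code (per its model card).
ultravox = transformers.pipeline(
    model="fixie-ai/ultravox-v0_4",
    trust_remote_code=True,
)

# Assumes the SER checkpoint exposes an audio-classification head.
ser_extractor = transformers.AutoFeatureExtractor.from_pretrained(
    "r-f/wav2vec-english-speech-emotion-recognition"
)
ser_model = transformers.AutoModelForAudioClassification.from_pretrained(
    "r-f/wav2vec-english-speech-emotion-recognition"
)

def classify_emotion(wav: np.ndarray, sr: int) -> int:
    # Returns the argmax class id, mirroring how process() uses `emotion`.
    inputs = ser_extractor(wav, sampling_rate=sr, return_tensors="pt")
    return int(ser_model(**inputs).logits.argmax(-1).item())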