Devakumar868 committed
Commit cfde29f · verified · 1 Parent(s): e85d66b

Update app.py

Files changed (1)
  1. app.py +503 -69
app.py CHANGED
@@ -1,74 +1,508 @@
-import os, tempfile, uuid
-from fastapi import FastAPI
 import gradio as gr
-import soundfile as sf
 import torch
 import numpy as np
-import nemo.collections.asr as nemo_asr
-from speechbrain.pretrained import EncoderClassifier
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-# Initialize FastAPI and models
-app = FastAPI()
-conversation_history = {}
-
-# Model loading
-asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")  # ASR [2]
-emotion_model = EncoderClassifier.from_hparams(
-    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
-    savedir="emotion_cache"
-)  # Emotion [3]
-llm_name = "microsoft/DialoGPT-medium"
-llm_tokenizer = AutoTokenizer.from_pretrained(llm_name)
-llm_model = AutoModelForCausalLM.from_pretrained(llm_name).to("cuda" if torch.cuda.is_available() else "cpu")  # LLM [4]
-
-def transcribe_and_emote(audio_path):
-    text = asr_model.transcribe([audio_path])[0].text
-    emotion = emotion_model.classify_file(audio_path)[0]
-    return text, emotion
-
-def generate_reply(user_text, emotion, uid):
-    # Track and trim history
-    hist = conversation_history.setdefault(uid, [])
-    ctx = f"[Feeling:{emotion}] {user_text}"
-    hist.append(ctx)
-    hist = hist[-6:]
-    conversation_history[uid] = hist
-
-    prompt = " ".join(hist)
-    inputs = llm_tokenizer.encode(prompt, return_tensors="pt").to(llm_model.device)
-    out = llm_model.generate(inputs, max_new_tokens=100, pad_token_id=llm_tokenizer.eos_token_id)
-    reply = llm_tokenizer.decode(out[0], skip_special_tokens=True)[len(prompt):].strip()
-    hist.append(reply)
-    return reply or "I’m here to help!"
-
-def process(audio, uid):
-    if not audio:
-        return "", "", "", uid
-    # Save temp file
-    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
-    data, sr = audio
-    sf.write(tmp.name, data, sr)
-    # ASR + Emotion
-    text, emo = transcribe_and_emote(tmp.name)
-    # LLM response
-    reply = generate_reply(text, emo, uid)
-    # Clean up
-    os.unlink(tmp.name)
-    return text, emo, reply, uid
-
-# Gradio interface
-with gr.Blocks() as demo:
-    uid_state = gr.State(value=str(uuid.uuid4()))
-    audio_in = gr.Audio(source="microphone", type="numpy")
-    txt_out = gr.Textbox(label="Transcription")
-    emo_out = gr.Textbox(label="Emotion")
-    rep_out = gr.Textbox(label="AI Reply")
-    btn = gr.Button("Process")
-    btn.click(process, inputs=[audio_in, uid_state], outputs=[txt_out, emo_out, rep_out, uid_state])
-
-app = gr.mount_gradio_app(app, demo, path="/")
+import librosa
+import soundfile as sf
+import threading
+import time
+import queue
+import warnings
+from typing import Optional, List, Dict, Tuple
+from dataclasses import dataclass
+from collections import deque
+import psutil
+import gc
+
+# Import models
+from dia.model import Dia
+from transformers import pipeline
+import webrtcvad
+
+warnings.filterwarnings("ignore", category=FutureWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+
+@dataclass
+class ConversationTurn:
+    user_audio: np.ndarray
+    user_text: str
+    ai_response_text: str
+    ai_response_audio: np.ndarray
+    timestamp: float
+    emotion: str
+    speaker_id: str
+
+class EmotionRecognizer:
+    def __init__(self):
+        self.emotion_pipeline = pipeline(
+            "audio-classification",
+            model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
+            device=0 if torch.cuda.is_available() else -1
+        )
+
+    def detect_emotion(self, audio: np.ndarray, sample_rate: int = 16000) -> str:
+        try:
+            result = self.emotion_pipeline({"array": audio, "sampling_rate": sample_rate})
+            return result[0]["label"] if result else "neutral"
+        except Exception as e:
+            print(f"Emotion detection error: {e}")
+            return "neutral"
+
+class VADProcessor:
+    def __init__(self, aggressiveness: int = 2):
+        self.vad = webrtcvad.Vad(aggressiveness)
+        self.sample_rate = 16000
+        self.frame_duration = 30  # ms
+        self.frame_size = int(self.sample_rate * self.frame_duration / 1000)
+
+    def is_speech(self, audio: np.ndarray) -> bool:
+        try:
+            # Convert to 16-bit PCM
+            audio_int16 = (audio * 32767).astype(np.int16)
+
+            # Process in frames
+            frames = []
+            for i in range(0, len(audio_int16) - self.frame_size, self.frame_size):
+                frame = audio_int16[i:i + self.frame_size].tobytes()
+                frames.append(self.vad.is_speech(frame, self.sample_rate))
+
+            # Return True if majority of frames contain speech
+            return sum(frames) > len(frames) * 0.3
+        except Exception:
+            return True  # Default to treating as speech
+
+class ConversationManager:
+    def __init__(self, max_exchanges: int = 50):
+        self.conversations: Dict[str, deque] = {}
+        self.max_exchanges = max_exchanges
+        self.lock = threading.RLock()
+
+    def add_turn(self, session_id: str, turn: ConversationTurn):
+        with self.lock:
+            if session_id not in self.conversations:
+                self.conversations[session_id] = deque(maxlen=self.max_exchanges)
+            self.conversations[session_id].append(turn)
+
+    def get_context(self, session_id: str, last_n: int = 5) -> List[ConversationTurn]:
+        with self.lock:
+            if session_id not in self.conversations:
+                return []
+            return list(self.conversations[session_id])[-last_n:]
+
+    def clear_session(self, session_id: str):
+        with self.lock:
+            if session_id in self.conversations:
+                del self.conversations[session_id]
+
+class SupernaturalAI:
+    def __init__(self):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.models_loaded = False
+        self.processing_queue = queue.Queue()
+        self.conversation_manager = ConversationManager()
+        self.emotion_recognizer = None
+        self.vad_processor = VADProcessor()
+
+        # Models
+        self.ultravox_model = None
+        self.dia_model = None
+
+        # Performance tracking
+        self.active_sessions = set()
+        self.processing_times = deque(maxlen=100)
+
+        print("Initializing Supernatural AI...")
+        self._initialize_models()
+
+    def _initialize_models(self):
+        try:
+            print("Loading Ultravox model...")
+            self.ultravox_model = pipeline(
+                'automatic-speech-recognition',
+                model='fixie-ai/ultravox-v0_2',
+                trust_remote_code=True,
+                device=0 if torch.cuda.is_available() else -1,
+                torch_dtype=torch.float16
+            )
+
+            print("Loading Dia TTS model...")
+            self.dia_model = Dia.from_pretrained(
+                "nari-labs/Dia-1.6B",
+                compute_dtype="float16"
+            )
+
+            print("Loading emotion recognition...")
+            self.emotion_recognizer = EmotionRecognizer()
+
+            self.models_loaded = True
+            print("✅ All models loaded successfully!")
+
+            # Memory cleanup
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        except Exception as e:
+            print(f"❌ Error loading models: {e}")
+            self.models_loaded = False
+
+    def _get_memory_usage(self) -> Dict[str, float]:
+        """Get current memory usage statistics"""
+        memory = psutil.virtual_memory()
+        gpu_memory = {}
+
+        if torch.cuda.is_available():
+            for i in range(torch.cuda.device_count()):
+                gpu_memory[f"GPU_{i}"] = {
+                    "allocated": torch.cuda.memory_allocated(i) / 1024**3,
+                    "cached": torch.cuda.memory_reserved(i) / 1024**3
+                }
+
+        return {
+            "RAM": memory.percent,
+            "GPU": gpu_memory
+        }
+
+    def _generate_contextual_prompt(self,
+                                    user_text: str,
+                                    emotion: str,
+                                    context: List[ConversationTurn]) -> str:
+        """Generate contextual prompt with emotion and conversation history"""
+
+        # Build context from previous turns
+        context_text = ""
+        if context:
+            for turn in context[-3:]:  # Last 3 exchanges
+                context_text += f"[S1] {turn.user_text} [S2] {turn.ai_response_text} "
+
+        # Emotion-aware response generation
+        emotion_modifiers = {
+            "happy": "(cheerful)",
+            "sad": "(sympathetic)",
+            "angry": "(calming)",
+            "fear": "(reassuring)",
+            "surprise": "(excited)",
+            "neutral": ""
+        }
+
+        modifier = emotion_modifiers.get(emotion.lower(), "")
+
+        # Create supernatural AI personality
+        prompt = f"{context_text}[S1] {user_text} [S2] {modifier} As a supernatural AI with deep emotional understanding, I sense your {emotion} energy. "
+
+        return prompt
+
+    def process_audio_input(self,
+                            audio_data: Tuple[int, np.ndarray],
+                            session_id: str) -> Tuple[Optional[Tuple[int, np.ndarray]], str, str]:
+        """Main processing pipeline for audio input"""
+
+        if not self.models_loaded:
+            return None, "❌ Models not loaded", "Please wait for initialization"
+
+        if audio_data is None:
+            return None, "❌ No audio received", "Please record some audio"
+
+        start_time = time.time()
+
+        try:
+            sample_rate, audio = audio_data
+
+            # Ensure audio is mono and proper format
+            if len(audio.shape) > 1:
+                audio = np.mean(audio, axis=1)
+
+            # Normalize audio
+            audio = audio.astype(np.float32)
+            if np.max(np.abs(audio)) > 0:
+                audio = audio / np.max(np.abs(audio)) * 0.95
+
+            # Voice Activity Detection
+            if not self.vad_processor.is_speech(audio):
+                return None, "🔇 No speech detected", "Please speak clearly"
+
+            # Resample if needed
+            if sample_rate != 16000:
+                audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
+                sample_rate = 16000
+
+            # Speech Recognition with Ultravox
+            try:
+                speech_result = self.ultravox_model({
+                    'array': audio,
+                    'sampling_rate': sample_rate
+                })
+                user_text = speech_result.get('text', '').strip()
+
+                if not user_text:
+                    return None, "❌ Could not understand speech", "Please speak more clearly"
+
+            except Exception as e:
+                print(f"ASR Error: {e}")
+                return None, f"❌ Speech recognition failed: {str(e)}", "Please try again"
+
+            # Emotion Recognition
+            emotion = self.emotion_recognizer.detect_emotion(audio, sample_rate)
+
+            # Get conversation context
+            context = self.conversation_manager.get_context(session_id)
+
+            # Generate contextual response
+            prompt = self._generate_contextual_prompt(user_text, emotion, context)
+
+            # Generate speech with Dia TTS
+            try:
+                with torch.no_grad():
+                    audio_output = self.dia_model.generate(
+                        prompt,
+                        use_torch_compile=False,  # Better stability
+                        verbose=False
+                    )
+
+                    # Ensure audio output is proper format
+                    if isinstance(audio_output, torch.Tensor):
+                        audio_output = audio_output.cpu().numpy()
+
+                    # Normalize output
+                    if len(audio_output) > 0:
+                        max_val = np.max(np.abs(audio_output))
+                        if max_val > 1.0:
+                            audio_output = audio_output / max_val * 0.95
+
+            except Exception as e:
+                print(f"TTS Error: {e}")
+                return None, f"❌ Speech generation failed: {str(e)}", "Please try again"
+
+            # Extract AI response text (remove speaker tags and modifiers)
+            ai_response = prompt.split('[S2]')[-1].strip()
+            ai_response = ai_response.replace('(cheerful)', '').replace('(sympathetic)', '')
+            ai_response = ai_response.replace('(calming)', '').replace('(reassuring)', '')
+            ai_response = ai_response.replace('(excited)', '').strip()
+
+            # Store conversation turn
+            turn = ConversationTurn(
+                user_audio=audio,
+                user_text=user_text,
+                ai_response_text=ai_response,
+                ai_response_audio=audio_output,
+                timestamp=time.time(),
+                emotion=emotion,
+                speaker_id=session_id
+            )
+
+            self.conversation_manager.add_turn(session_id, turn)
+
+            # Track performance
+            processing_time = time.time() - start_time
+            self.processing_times.append(processing_time)
+
+            # Memory cleanup
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
+
+            status = f"✅ Processed in {processing_time:.2f}s | Emotion: {emotion} | Users: {len(self.active_sessions)}"
+
+            return (44100, audio_output), status, f"**You said:** {user_text}\n\n**AI Response:** {ai_response}"
+
+        except Exception as e:
+            print(f"Processing error: {e}")
+            return None, f"❌ Processing failed: {str(e)}", "Please try again"
+
+    def get_conversation_history(self, session_id: str) -> str:
+        """Get formatted conversation history"""
+        context = self.conversation_manager.get_context(session_id, last_n=10)
+        if not context:
+            return "No conversation history yet."
+
+        history = "## Conversation History\n\n"
+        for i, turn in enumerate(context, 1):
+            history += f"**Turn {i}:**\n"
+            history += f"- **You:** {turn.user_text}\n"
+            history += f"- **AI:** {turn.ai_response_text}\n"
+            history += f"- **Emotion Detected:** {turn.emotion}\n\n"
+
+        return history
+
+    def clear_conversation(self, session_id: str) -> str:
+        """Clear conversation history for session"""
+        self.conversation_manager.clear_session(session_id)
+        return "Conversation history cleared."
+
+    def get_system_status(self) -> str:
+        """Get system status information"""
+        memory = self._get_memory_usage()
+        avg_processing = np.mean(self.processing_times) if self.processing_times else 0
+
+        status = f"""## System Status
+
+**Performance:**
+- Average Processing Time: {avg_processing:.2f}s
+- Active Sessions: {len(self.active_sessions)}
+- Total Conversations: {len(self.conversation_manager.conversations)}
+
+**Memory Usage:**
+- RAM: {memory['RAM']:.1f}%
+- GPU Memory: {memory.get('GPU', {})}
+
+**Models Status:**
+- Models Loaded: {"✅" if self.models_loaded else "❌"}
+- Device: {self.device}
+"""
+        return status
+
+# Initialize the AI system
+print("Starting Supernatural AI system...")
+ai_system = SupernaturalAI()
+
+# Gradio Interface
+def process_audio_interface(audio, session_id):
+    """Interface function for Gradio"""
+    if not session_id:
+        session_id = f"user_{int(time.time())}"
+
+    ai_system.active_sessions.add(session_id)
+    result = ai_system.process_audio_input(audio, session_id)
+    return result + (session_id,)
+
+def get_history_interface(session_id):
+    """Get conversation history interface"""
+    if not session_id:
+        return "No session ID provided"
+    return ai_system.get_conversation_history(session_id)
+
+def clear_history_interface(session_id):
+    """Clear history interface"""
+    if not session_id:
+        return "No session ID provided"
+    return ai_system.clear_conversation(session_id)
+
+# Create Gradio interface
+with gr.Blocks(title="Supernatural Conversational AI", theme=gr.themes.Soft()) as demo:
+    gr.HTML("""
+    <div style="text-align: center; padding: 20px;">
+        <h1>🧙‍♂️ Supernatural Conversational AI</h1>
+        <p style="font-size: 18px; color: #666;">
+            Advanced Speech-to-Speech AI with Emotional Intelligence
+        </p>
+        <p style="color: #888;">
+            Powered by Ultravox + Dia TTS | Optimized for 4x L4 GPUs
+        </p>
+    </div>
+    """)
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            # Audio input/output
+            audio_input = gr.Audio(
+                label="🎤 Speak to the AI",
+                sources=["microphone"],
+                type="numpy",
+                streaming=False
+            )
+
+            audio_output = gr.Audio(
+                label="🔊 AI Response",
+                type="numpy",
+                autoplay=True
+            )
+
+            # Session management
+            session_id = gr.Textbox(
+                label="Session ID",
+                placeholder="Auto-generated if empty",
+                value="",
+                interactive=True
+            )
+
+            # Process button
+            process_btn = gr.Button("🎯 Process Audio", variant="primary", size="lg")
+
+        with gr.Column(scale=1):
+            # Status and conversation
+            status_display = gr.Textbox(
+                label="📊 Status",
+                interactive=False,
+                lines=3
+            )
+
+            conversation_display = gr.Markdown(
+                label="💬 Conversation",
+                value="Start speaking to begin..."
+            )
+
+    # History management
+    with gr.Row():
+        history_btn = gr.Button("📜 Show History", size="sm")
+        clear_btn = gr.Button("🗑️ Clear History", size="sm")
+        status_btn = gr.Button("⚡ System Status", size="sm")
+
+    # History and status display
+    history_display = gr.Markdown(
+        label="📚 Conversation History",
+        value="No history yet."
+    )
+
+    # Event handlers
+    process_btn.click(
+        fn=process_audio_interface,
+        inputs=[audio_input, session_id],
+        outputs=[audio_output, status_display, conversation_display, session_id]
+    )
+
+    history_btn.click(
+        fn=get_history_interface,
+        inputs=[session_id],
+        outputs=[history_display]
+    )
+
+    clear_btn.click(
+        fn=clear_history_interface,
+        inputs=[session_id],
+        outputs=[history_display]
+    )
+
+    status_btn.click(
+        fn=lambda: ai_system.get_system_status(),
+        outputs=[history_display]
+    )
+
+    # Auto-process on audio input
+    audio_input.change(
+        fn=process_audio_interface,
+        inputs=[audio_input, session_id],
+        outputs=[audio_output, status_display, conversation_display, session_id]
+    )
+
+    # Usage instructions
+    gr.HTML("""
+    <div style="margin-top: 20px; padding: 15px; background: #f0f8ff; border-radius: 8px;">
+        <h3>💡 Usage Instructions:</h3>
+        <ul>
+            <li><strong>Record Audio:</strong> Click the microphone and speak naturally</li>
+            <li><strong>Emotional AI:</strong> The AI detects and responds to your emotions</li>
+            <li><strong>Conversation Memory:</strong> Up to 50 exchanges are remembered</li>
+            <li><strong>Session Management:</strong> Use Session ID to maintain separate conversations</li>
+            <li><strong>Performance:</strong> Optimized for sub-500ms latency</li>
+        </ul>
+
+        <p><strong>Supported Features:</strong> Emotion recognition, voice activity detection,
+        contextual responses, conversation history, concurrent users (15-20), memory management</p>
+    </div>
+    """)
+
+# Configure for optimal performance
+demo.queue(
+    concurrency_count=20,  # Support 20 concurrent users
+    max_size=100,
+    api_open=False
+)
 
 if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        show_error=True,
+        quiet=False,
+        enable_queue=True,
+        max_threads=40
+    )
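
Note on Gradio versions: the updated file mixes API generations. gr.Audio(sources=["microphone"]) is Gradio 4.x syntax, while demo.queue(concurrency_count=...) and demo.launch(enable_queue=True) are Gradio 3.x arguments that Gradio 4 no longer accepts, so the Space needs a matching gradio pin in its requirements. For reference, a minimal sketch of the same queue/launch configuration using Gradio 4.x keyword names (an assumption for illustration, not part of this commit):

# Sketch only, assuming gradio>=4.0; the commit itself keeps the 3.x-style arguments above.
demo.queue(
    default_concurrency_limit=20,  # Gradio 4.x replacement for concurrency_count
    max_size=100,
    api_open=False
)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False,
        max_threads=40  # enable_queue is no longer a launch() argument in Gradio 4
    )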