Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
import librosa | |
import soundfile as sf | |
import threading | |
import time | |
import queue | |
import warnings | |
from typing import Optional, List, Dict, Tuple | |
from dataclasses import dataclass | |
from collections import deque | |
import psutil | |
import gc | |
# Models and pipelines | |
from dia.model import Dia | |
from transformers import pipeline | |
import webrtcvad | |
warnings.filterwarnings("ignore", category=FutureWarning) | |
warnings.filterwarnings("ignore", category=UserWarning) | |
@dataclass(eq=False)
class ConversationTurn:
    """One user/AI exchange kept in a session's conversation history.

    Declared as a dataclass so that the positional constructor call made when
    a turn is recorded works; a plain class with bare annotations has no
    generated ``__init__`` and raises ``TypeError`` when given arguments.
    ``eq=False`` keeps identity-based equality and hashing, which avoids
    element-wise comparison of the NumPy array fields.
    """

    user_audio: np.ndarray  # audio samples of the user's utterance
    user_text: str  # transcription of the user's utterance
    ai_response_text: str  # text of the generated reply
    ai_response_audio: np.ndarray  # synthesised audio of the reply
    timestamp: float  # creation time; the caller in this file passes time.time()
    emotion: str  # emotion label detected for the user's utterance
    speaker_id: str  # the caller in this file passes the session id here
class EmotionRecognizer:
    """Thin wrapper around a Hugging Face audio-classification pipeline."""

    def __init__(self):
        """Create the emotion pipeline on GPU 0 when CUDA is available, else on CPU."""
        target_device = 0 if torch.cuda.is_available() else -1
        self.emotion_pipeline = pipeline(
            "audio-classification",
            model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
            device=target_device,
        )

    def detect_emotion(self, audio: np.ndarray, sample_rate: int = 16000) -> str:
        """Return the label of the first prediction for ``audio``.

        Args:
            audio: Audio samples handed to the pipeline as its ``"array"`` input.
            sample_rate: Sampling rate reported to the pipeline.

        Returns:
            The ``"label"`` of the first result, or ``"neutral"`` when the
            pipeline returns nothing or raises.
        """
        payload = {"array": audio, "sampling_rate": sample_rate}
        try:
            predictions = self.emotion_pipeline(payload)
            if predictions:
                return predictions[0]["label"]
            return "neutral"
        except Exception:
            # Best-effort: emotion is optional, so any failure degrades to neutral.
            return "neutral"
class VADProcessor:
    """Frame-based voice-activity detector built on ``webrtcvad``."""

    def __init__(self, aggressiveness: int = 2):
        """Create the detector.

        Args:
            aggressiveness: Value passed straight to ``webrtcvad.Vad``.
        """
        self.vad = webrtcvad.Vad(aggressiveness)
        self.sample_rate = 16000
        self.frame_duration = 30  # milliseconds per analysed frame
        self.frame_size = int(self.sample_rate * self.frame_duration / 1000)

    def is_speech(self, audio: np.ndarray) -> bool:
        """Report whether more than 30% of the full frames in ``audio`` are speech.

        Args:
            audio: Float samples scaled to [-1, 1]; they are interpreted at
                ``self.sample_rate``, so callers must resample beforehand.

        Returns:
            ``True`` when the share of frames flagged as speech exceeds 0.3.
            Audio shorter than one frame yields ``False``.
        """
        # Clip before scaling so out-of-range samples saturate instead of
        # wrapping around in the int16 conversion.
        pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        # The last valid start index is len - frame_size inclusive; the "+ 1"
        # keeps the final complete frame, which an exclusive bound would drop.
        last_start = len(pcm) - self.frame_size
        votes = [
            self.vad.is_speech(
                pcm[start:start + self.frame_size].tobytes(), self.sample_rate
            )
            for start in range(0, last_start + 1, self.frame_size)
        ]
        return sum(votes) > len(votes) * 0.3
class ConversationManager:
    """Lock-guarded, per-session store of recent conversation turns."""

    def __init__(self, max_exchanges: int = 50):
        """Create an empty store.

        Args:
            max_exchanges: Maximum number of turns kept per session; older
                turns are discarded by the bounded deque.
        """
        self.conversations: Dict[str, deque] = {}
        self.max_exchanges = max_exchanges
        self.lock = threading.RLock()

    def add_turn(self, session_id: str, turn: "ConversationTurn"):
        """Append ``turn`` to the session's history, creating it on first use."""
        with self.lock:
            history = self.conversations.get(session_id)
            if history is None:
                history = deque(maxlen=self.max_exchanges)
                self.conversations[session_id] = history
            history.append(turn)

    def get_context(self, session_id: str, last_n: int = 5) -> List["ConversationTurn"]:
        """Return up to ``last_n`` most recent turns, oldest first.

        Args:
            session_id: Session whose history is read.
            last_n: Number of trailing turns wanted.

        Returns:
            A new list; empty for an unknown session or a non-positive
            ``last_n``.
        """
        # A slice of [-0:] would return the whole history, so handle
        # non-positive requests explicitly.
        if last_n <= 0:
            return []
        with self.lock:
            history = self.conversations.get(session_id)
            if not history:
                return []
            return list(history)[-last_n:]

    def clear_session(self, session_id: str):
        """Forget everything stored for ``session_id``; unknown ids are ignored."""
        with self.lock:
            self.conversations.pop(session_id, None)
class SupernaturalAI:
    """Voice pipeline: VAD -> speech recognition -> emotion -> prompt -> speech synthesis."""

    # Sample rate the VAD, recogniser and emotion model are fed with.
    INPUT_SAMPLE_RATE = 16000
    # Sample rate reported alongside the synthesised audio.
    OUTPUT_SAMPLE_RATE = 44100
    # Delivery hint prepended to the reply for each detected emotion.
    # "fearful"/"surprised" are aliases for "fear"/"surprise".
    # NOTE(review): the emotion model is presumed to emit the longer
    # spellings - confirm against its label set.
    _EMOTION_STYLES = {
        "happy": "(cheerful)",
        "sad": "(sympathetic)",
        "angry": "(calming)",
        "fear": "(reassuring)",
        "fearful": "(reassuring)",
        "surprise": "(excited)",
        "surprised": "(excited)",
        "neutral": "",
    }

    def __init__(self):
        """Set up state and immediately try to load all models."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.models_loaded = False
        self.conversation_manager = ConversationManager()
        self.processing_times = deque(maxlen=100)
        self.emotion_recognizer = None
        self.vad_processor = VADProcessor()
        self.ultravox_model = None
        self.dia_model = None
        self._initialize_models()

    def _initialize_models(self):
        """Load the recogniser, synthesiser and emotion model.

        Any failure is printed and leaves ``models_loaded`` set to ``False``.
        """
        use_cuda = torch.cuda.is_available()
        try:
            # Half precision is only used on GPU; on CPU fall back to float32.
            # NOTE(review): confirm this checkpoint really works with the
            # 'automatic-speech-recognition' task.
            self.ultravox_model = pipeline(
                'automatic-speech-recognition',
                model='fixie-ai/ultravox-v0_2',
                trust_remote_code=True,
                device=0 if use_cuda else -1,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            )
            self.dia_model = Dia.from_pretrained(
                "nari-labs/Dia-1.6B",
                compute_dtype="float16" if use_cuda else "float32",
            )
            self.emotion_recognizer = EmotionRecognizer()
            self.models_loaded = True
            if use_cuda:
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Model load error: {e}")
            self.models_loaded = False

    def process_audio_input(self, audio_data: Tuple[int, np.ndarray], session_id: str):
        """Turn one recorded utterance into a spoken reply.

        Args:
            audio_data: ``(sample_rate, samples)`` pair, or ``None`` when
                nothing was recorded.
            session_id: Key under which the exchange is stored.

        Returns:
            A 3-tuple ``(audio, status, transcript)``. ``audio`` is
            ``(OUTPUT_SAMPLE_RATE, samples)`` on success and ``None`` on any
            failure, in which case the other two items describe the problem.
        """
        if not self.models_loaded:
            return None, "Models not ready", "Please wait"
        if audio_data is None:
            return None, "No audio received", "Record a message first"

        start = time.time()
        sample_rate, audio = audio_data
        audio = np.asarray(audio)
        if audio.ndim > 1:
            # Average the second axis to collapse multi-channel input to mono.
            audio = np.mean(audio, axis=1)
        audio = audio.astype(np.float32)
        if audio.size == 0:
            # np.max would raise on an empty recording.
            return None, "No speech detected", "Speak clearly"

        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak * 0.95

        # Resample BEFORE voice-activity detection: the VAD interprets its
        # input at 16 kHz, so feeding it the raw capture rate is wrong.
        if sample_rate != self.INPUT_SAMPLE_RATE:
            audio = librosa.resample(
                audio, orig_sr=sample_rate, target_sr=self.INPUT_SAMPLE_RATE
            )
            sample_rate = self.INPUT_SAMPLE_RATE

        if not self.vad_processor.is_speech(audio):
            return None, "No speech detected", "Speak clearly"

        try:
            result = self.ultravox_model({'array': audio, 'sampling_rate': sample_rate})
            user_text = result.get('text', '').strip()
        except Exception as e:
            return None, f"ASR error: {e}", "Retry"
        if not user_text:
            return None, "Could not understand", "Try again"

        emotion = self.emotion_recognizer.detect_emotion(audio, sample_rate)
        context = self.conversation_manager.get_context(session_id)
        prompt = self._build_prompt(user_text, emotion, context)

        try:
            with torch.no_grad():
                audio_out = self.dia_model.generate(prompt, use_torch_compile=False)
            if isinstance(audio_out, torch.Tensor):
                audio_out = audio_out.cpu().numpy()
        except Exception as e:
            return None, f"TTS error: {e}", "Retry"

        # The reply is whatever follows the final [S2] tag of the prompt.
        ai_text = prompt.split('[S2]')[-1].strip()
        turn = ConversationTurn(
            audio, user_text, ai_text, audio_out, time.time(), emotion, session_id
        )
        self.conversation_manager.add_turn(session_id, turn)

        elapsed = time.time() - start
        self.processing_times.append(elapsed)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        status = f"Processed in {elapsed:.2f}s | Emotion: {emotion}"
        return (self.OUTPUT_SAMPLE_RATE, audio_out), status, f"You: {user_text}\n\nAI: {ai_text}"

    def _build_prompt(self, text, emotion, context):
        """Build the synthesis prompt from recent turns and the new utterance.

        User lines are tagged ``[S1]`` and AI lines ``[S2]`` so the reply can
        be recovered by splitting on the last ``[S2]``. At most the last three
        turns of ``context`` are included.
        """
        history = "".join(
            f"[S1]{turn.user_text}[S2]{turn.ai_response_text} "
            for turn in context[-3:]
        )
        style = self._EMOTION_STYLES.get(emotion, "")
        return (
            f"{history}[S1]{text}[S2]{style} "
            f"As a supernatural AI, I sense your {emotion} energy. "
        )

    def get_history(self, session_id: str) -> str:
        """Return up to the last ten turns of a session as display text."""
        ctx = self.conversation_manager.get_context(session_id, last_n=10)
        if not ctx:
            return "No history."
        return "".join(
            f"Turn {i} — You: {t.user_text} | AI: {t.ai_response_text} | Emotion: {t.emotion}\n\n"
            for i, t in enumerate(ctx, 1)
        )

    def clear_history(self, session_id: str) -> str:
        """Drop a session's stored turns and return a confirmation message."""
        self.conversation_manager.clear_session(session_id)
        return "History cleared."
# Instantiate and launch Gradio app
ai = SupernaturalAI()

# NOTE(review): `source=` on gr.Audio and `concurrency_count=` on queue() are
# Gradio 3.x spellings - confirm the installed Gradio version accepts them.
with gr.Blocks() as demo:
    audio_in = gr.Audio(source="microphone", type="numpy", label="Speak")
    audio_out = gr.Audio(label="AI Response")
    session = gr.Textbox(label="Session ID", interactive=True)
    status = gr.Textbox(label="Status")
    chat = gr.Markdown("## Conversation")

    btn = gr.Button("Send")
    # process_audio_input returns exactly three values (audio, status,
    # transcript), so exactly three output components are wired; listing a
    # fourth makes Gradio reject the handler's return value.
    btn.click(
        fn=ai.process_audio_input,
        inputs=[audio_in, session],
        outputs=[audio_out, status, chat],
    )

    hist_btn = gr.Button("History")
    hist_btn.click(fn=ai.get_history, inputs=session, outputs=chat)

    clr_btn = gr.Button("Clear")
    clr_btn.click(fn=ai.clear_history, inputs=session, outputs=chat)

# Queueing is enabled by queue(); the deprecated `enable_queue` launch flag is
# therefore not passed.
demo.queue(concurrency_count=20, max_size=100)
demo.launch(server_name="0.0.0.0", server_port=7860)