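# Emotion-aware voice chat demo for a Hugging Face Space (assumed entry point: app.py).
# Pipeline: NeMo Parakeet ASR -> SpeechBrain IEMOCAP emotion classifier -> DialoGPT reply,
# served through a Gradio UI mounted on FastAPI. Assumes the Space's requirements include
# nemo_toolkit[asr], speechbrain, transformers, torch, soundfile, gradio, fastapi and uvicorn.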
import os
import tempfile
import uuid
from typing import Tuple

import gradio as gr
import nemo.collections.asr as nemo_asr
import soundfile as sf
import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer

# The IEMOCAP emotion model ships its own interface class, so it is loaded through
# SpeechBrain's foreign_class helper (speechbrain.inference on >= 1.0,
# speechbrain.pretrained on older releases).
try:
    from speechbrain.inference.interfaces import foreign_class
except ImportError:
    from speechbrain.pretrained.interfaces import foreign_class
# Initialize FastAPI app
app = FastAPI()

# Global variables for models
asr_model = None
emotion_model = None
llm_model = None
llm_tokenizer = None
conversation_history = {}


def load_models():
    """Load all required models."""
    global asr_model, emotion_model, llm_model, llm_tokenizer
    try:
        # Load ASR model
        print("Loading ASR model...")
        asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
        print("ASR model loaded successfully")

        # Load emotion recognition model via its custom interface class
        print("Loading emotion model...")
        emotion_model = foreign_class(
            source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
            pymodule_file="custom_interface.py",
            classname="CustomEncoderWav2vec2Classifier",
            savedir="./emotion_model_cache",
        )
        print("Emotion model loaded successfully")

        # Load LLM for conversation
        print("Loading LLM...")
        model_name = "microsoft/DialoGPT-medium"  # Lighter alternative to Vicuna
        llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        # Add padding token if not present
        if llm_tokenizer.pad_token is None:
            llm_tokenizer.pad_token = llm_tokenizer.eos_token

        print("All models loaded successfully")
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise


def transcribe_audio(audio_path: str) -> Tuple[str, str]:
    """Transcribe audio and detect its emotion."""
    try:
        # ASR transcription (NeMo returns a list of hypotheses or strings)
        transcription = asr_model.transcribe([audio_path])
        text = transcription[0].text if hasattr(transcription[0], 'text') else str(transcription[0])

        # Emotion detection: classify_file returns (out_prob, score, index, text_lab);
        # the human-readable label is in text_lab
        out_prob, score, index, text_lab = emotion_model.classify_file(audio_path)
        emotion = text_lab[0] if isinstance(text_lab, (list, tuple)) else str(text_lab)

        return text, emotion
    except Exception as e:
        print(f"Error in transcription: {str(e)}")
        return f"Error: {str(e)}", "unknown"


def generate_response(user_text: str, emotion: str, user_id: str) -> str:
    """Generate a contextual response based on the user's input and emotion."""
    try:
        # Get conversation history
        if user_id not in conversation_history:
            conversation_history[user_id] = []

        # Add emotion context to the input
        emotional_context = f"[User is feeling {emotion}] {user_text}"
        conversation_history[user_id].append(emotional_context)

        # Keep only the last 5 exchanges (10 turns) to manage memory
        if len(conversation_history[user_id]) > 10:
            conversation_history[user_id] = conversation_history[user_id][-10:]

        # Build the model input from the last 3 turns; DialoGPT expects turns
        # separated by the EOS token
        input_text = llm_tokenizer.eos_token.join(conversation_history[user_id][-3:]) + llm_tokenizer.eos_token

        # Tokenize and generate
        encoded = llm_tokenizer(input_text, return_tensors="pt")
        if torch.cuda.is_available():
            encoded = {k: v.cuda() for k, v in encoded.items()}

        with torch.no_grad():
            outputs = llm_model.generate(
                **encoded,
                max_new_tokens=100,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (everything after the prompt)
        prompt_length = encoded["input_ids"].shape[-1]
        response = llm_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        if not response:
            response = "I understand your feelings. How can I help you today?"

        # Add to conversation history
        conversation_history[user_id].append(response)
        return response
    except Exception as e:
        print(f"Error generating response: {str(e)}")
        return "I'm having trouble processing that right now. Could you try again?"


def process_audio_input(audio_file, user_id: str = None) -> Tuple[str, str, str, str]:
    """Main processing function for audio input."""
    if user_id is None:
        user_id = str(uuid.uuid4())
    if audio_file is None:
        return "No audio file provided", "", "", user_id

    tmp_path = None
    try:
        # Resolve the audio input to a file path. With gr.Audio(type="filepath")
        # Gradio passes a path string; the other branches cover file objects and
        # raw (sample_rate, data) tuples.
        if isinstance(audio_file, str):
            audio_path = audio_file
        elif hasattr(audio_file, "name"):
            audio_path = audio_file.name
        else:
            sample_rate, data = audio_file
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                sf.write(tmp_file.name, data, sample_rate)
                tmp_path = tmp_file.name
            audio_path = tmp_path

        # Process audio
        transcription, emotion = transcribe_audio(audio_path)

        # Generate response
        response = generate_response(transcription, emotion, user_id)

        return transcription, emotion, response, user_id
    except Exception as e:
        error_msg = f"Processing error: {str(e)}"
        print(error_msg)
        return error_msg, "error", "I'm sorry, I couldn't process your audio.", user_id
    finally:
        # Clean up any temporary file we created
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)


def get_conversation_history(user_id: str) -> str:
    """Get the formatted conversation history for a user."""
    if user_id not in conversation_history or not conversation_history[user_id]:
        return "No conversation history yet."

    history = conversation_history[user_id]
    formatted_history = []
    # History alternates user turn / AI turn, so walk it in pairs
    for i in range(0, len(history), 2):
        if i + 1 < len(history):
            # Strip the "[User is feeling <emotion>] " prefix from the user turn
            user_msg = history[i].split("] ", 1)[-1]
            bot_msg = history[i + 1]
            formatted_history.append(f"**You:** {user_msg}")
            formatted_history.append(f"**AI:** {bot_msg}")

    return "\n\n".join(formatted_history) if formatted_history else "No conversation history yet."


def clear_conversation(user_id: str) -> str:
    """Clear the conversation history for a user."""
    if user_id in conversation_history:
        conversation_history[user_id] = []
    return "Conversation history cleared."


# Load models on startup
print("Initializing models...")
load_models()
print("Models initialized successfully")

# Create Gradio interface
with gr.Blocks(title="Emotional Conversational AI", theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Emotional Conversational AI")
    gr.Markdown("Upload audio or use your microphone to have an emotion-aware conversation with the AI")

    # User ID state
    user_id_state = gr.State(value=str(uuid.uuid4()))

    with gr.Row():
        with gr.Column(scale=2):
            # Audio input
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or Upload Audio",
            )
            # Process button
            process_btn = gr.Button("Process Audio", variant="primary", size="lg")

        with gr.Column(scale=3):
            # Output displays
            transcription_output = gr.Textbox(
                label="Transcription",
                placeholder="Your speech will appear here...",
                max_lines=3,
            )
            emotion_output = gr.Textbox(
                label="Detected Emotion",
                placeholder="Detected emotion will appear here...",
                max_lines=1,
            )
            response_output = gr.Textbox(
                label="AI Response",
                placeholder="AI response will appear here...",
                max_lines=5,
            )

    with gr.Row():
        with gr.Column():
            # Conversation history
            history_output = gr.Textbox(
                label="Conversation History",
                placeholder="Your conversation history will appear here...",
                max_lines=10,
                interactive=False,
            )
        with gr.Column():
            # Control buttons
            show_history_btn = gr.Button("Show History", variant="secondary")
            clear_history_btn = gr.Button("Clear History", variant="stop")
            new_session_btn = gr.Button("New Session", variant="secondary")
    # Event handlers
    process_btn.click(
        fn=process_audio_input,
        inputs=[audio_input, user_id_state],
        outputs=[transcription_output, emotion_output, response_output, user_id_state],
    )
    show_history_btn.click(
        fn=get_conversation_history,
        inputs=[user_id_state],
        outputs=[history_output],
    )
    clear_history_btn.click(
        fn=clear_conversation,
        inputs=[user_id_state],
        outputs=[history_output],
    )
    new_session_btn.click(
        fn=lambda: (str(uuid.uuid4()), "New session started!"),
        outputs=[user_id_state, history_output],
    )

# Mount the Gradio app onto FastAPI
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
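# To run locally (assuming this file is saved as app.py): `python app.py`,
# then open http://localhost:7860 in a browser.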