Devakumar868 commited on
Commit
c0a635e
Β·
verified Β·
1 Parent(s): 37040da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +276 -37
app.py CHANGED
@@ -1,42 +1,281 @@
1
- from fastapi import FastAPI, UploadFile
 
 
2
  import gradio as gr
3
- from nemo.collections.asr import EncDecRNNTBPEModel
4
  from speechbrain.pretrained import EncoderClassifier
5
- from transformers import DiffusionPipeline, AutoModelForCausalLM, AutoTokenizer
6
- from dia.model import Dia
7
  import soundfile as sf
8
- # Load models
9
- asr = EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
10
- emotion = EncoderClassifier.from_hparams(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP")
11
- diffuser = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to("cuda")
12
- llm_tokenizer = AutoTokenizer.from_pretrained("Vicuna-7B")
13
- llm = AutoModelForCausalLM.from_pretrained("Vicuna-7B").half().to("cuda")
14
- tts = Dia.from_pretrained("nari-labs/Dia-1.6B")
15
 
 
16
  app = FastAPI()
17
- def process(audio_file):
18
- # Save
19
- data, sr = sf.read(audio_file)
20
- # ASR
21
- text = asr.transcribe([audio_file])[0]
22
- # Emotion
23
- emo = emotion.classify_file(audio_file)["label"]
24
- # LLM response
25
- inputs = llm_tokenizer(text, return_tensors="pt").to("cuda")
26
- resp = llm.generate(**inputs, max_new_tokens=128)
27
- reply = llm_tokenizer.decode(resp[0])
28
- # TTS
29
- wav = tts.generate(f"[S1] {reply} [S2]")
30
- sf.write("reply.wav", wav, 44100)
31
- return text, emo, reply, "reply.wav"
32
-
33
- # Gradio UI
34
- iface = gr.Interface(fn=process, inputs=gr.Audio(source="microphone"), outputs=[
35
- gr.Textbox(label="Transcript"),
36
- gr.Textbox(label="Emotion"),
37
- gr.Textbox(label="Reply"),
38
- gr.Audio(label="Audio Reply")
39
- ], live=False, enable_queue=True)
40
- app.mount("/", gr.routes.App.create_app(iface))
41
- if __name__=="__main__":
42
- import uvicorn; uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from fastapi import FastAPI, UploadFile, File
4
  import gradio as gr
5
+ import nemo.collections.asr as nemo_asr
6
  from speechbrain.pretrained import EncoderClassifier
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
8
  import soundfile as sf
9
+ import torch
10
+ import numpy as np
11
+ from typing import Dict, List, Tuple
12
+ import json
13
+ import uuid
14
+ from datetime import datetime
 
15
 
16
+ # Initialize FastAPI app
17
  app = FastAPI()
18
+
19
+ # Global variables for models
20
+ asr_model = None
21
+ emotion_model = None
22
+ llm_model = None
23
+ llm_tokenizer = None
24
+ conversation_history = {}
25
+
26
+ def load_models():
27
+ """Load all required models"""
28
+ global asr_model, emotion_model, llm_model, llm_tokenizer
29
+
30
+ try:
31
+ # Load ASR model using correct syntax
32
+ print("Loading ASR model...")
33
+ asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
34
+ print("ASR model loaded successfully")
35
+
36
+ # Load emotion recognition model
37
+ print("Loading emotion model...")
38
+ emotion_model = EncoderClassifier.from_hparams(
39
+ source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
40
+ savedir="./emotion_model_cache"
41
+ )
42
+ print("Emotion model loaded successfully")
43
+
44
+ # Load LLM for conversation
45
+ print("Loading LLM...")
46
+ model_name = "microsoft/DialoGPT-medium" # Lighter alternative to Vicuna
47
+ llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
48
+ llm_model = AutoModelForCausalLM.from_pretrained(
49
+ model_name,
50
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
51
+ device_map="auto" if torch.cuda.is_available() else None
52
+ )
53
+
54
+ # Add padding token if not present
55
+ if llm_tokenizer.pad_token is None:
56
+ llm_tokenizer.pad_token = llm_tokenizer.eos_token
57
+
58
+ print("All models loaded successfully")
59
+
60
+ except Exception as e:
61
+ print(f"Error loading models: {str(e)}")
62
+ raise e
63
+
64
+ def transcribe_audio(audio_path: str) -> Tuple[str, str]:
65
+ """Transcribe audio and detect emotion"""
66
+ try:
67
+ # ASR transcription
68
+ transcription = asr_model.transcribe([audio_path])
69
+ text = transcription[0].text if hasattr(transcription[0], 'text') else str(transcription[0])
70
+
71
+ # Emotion detection
72
+ emotion_result = emotion_model.classify_file(audio_path)
73
+ emotion = emotion_result[0] if isinstance(emotion_result, list) else str(emotion_result)
74
+
75
+ return text, emotion
76
+
77
+ except Exception as e:
78
+ print(f"Error in transcription: {str(e)}")
79
+ return f"Error: {str(e)}", "unknown"
80
+
81
+ def generate_response(user_text: str, emotion: str, user_id: str) -> str:
82
+ """Generate contextual response based on user input and emotion"""
83
+ try:
84
+ # Get conversation history
85
+ if user_id not in conversation_history:
86
+ conversation_history[user_id] = []
87
+
88
+ # Add emotion context to the input
89
+ emotional_context = f"[User is feeling {emotion}] {user_text}"
90
+
91
+ # Encode input with conversation history
92
+ conversation_history[user_id].append(emotional_context)
93
+
94
+ # Keep only last 5 exchanges to manage memory
95
+ if len(conversation_history[user_id]) > 10:
96
+ conversation_history[user_id] = conversation_history[user_id][-10:]
97
+
98
+ # Create input for the model
99
+ input_text = " ".join(conversation_history[user_id][-3:]) # Last 3 exchanges
100
+
101
+ # Tokenize and generate
102
+ inputs = llm_tokenizer.encode(input_text, return_tensors="pt")
103
+ if torch.cuda.is_available():
104
+ inputs = inputs.cuda()
105
+
106
+ with torch.no_grad():
107
+ outputs = llm_model.generate(
108
+ inputs,
109
+ max_new_tokens=100,
110
+ num_return_sequences=1,
111
+ temperature=0.7,
112
+ do_sample=True,
113
+ pad_token_id=llm_tokenizer.eos_token_id
114
+ )
115
+
116
+ # Decode response
117
+ response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
118
+
119
+ # Extract only the new part of the response
120
+ response = response[len(input_text):].strip()
121
+
122
+ # Add to conversation history
123
+ conversation_history[user_id].append(response)
124
+
125
+ return response if response else "I understand your feelings. How can I help you today?"
126
+
127
+ except Exception as e:
128
+ print(f"Error generating response: {str(e)}")
129
+ return "I'm having trouble processing that right now. Could you try again?"
130
+
131
+ def process_audio_input(audio_file, user_id: str = None) -> Tuple[str, str, str, str]:
132
+ """Main processing function for audio input"""
133
+ if user_id is None:
134
+ user_id = str(uuid.uuid4())
135
+
136
+ if audio_file is None:
137
+ return "No audio file provided", "", "", user_id
138
+
139
+ try:
140
+ # Save uploaded audio to temporary file
141
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
142
+ # Handle different audio input formats
143
+ if hasattr(audio_file, 'name'):
144
+ # File upload case
145
+ audio_path = audio_file.name
146
+ else:
147
+ # Direct audio data case
148
+ sf.write(tmp_file.name, audio_file[1], audio_file[0])
149
+ audio_path = tmp_file.name
150
+
151
+ # Process audio
152
+ transcription, emotion = transcribe_audio(audio_path)
153
+
154
+ # Generate response
155
+ response = generate_response(transcription, emotion, user_id)
156
+
157
+ # Clean up temporary file
158
+ if audio_path != (audio_file.name if hasattr(audio_file, 'name') else ''):
159
+ os.unlink(audio_path)
160
+
161
+ return transcription, emotion, response, user_id
162
+
163
+ except Exception as e:
164
+ error_msg = f"Processing error: {str(e)}"
165
+ print(error_msg)
166
+ return error_msg, "error", "I'm sorry, I couldn't process your audio.", user_id
167
+
168
+ def get_conversation_history(user_id: str) -> str:
169
+ """Get formatted conversation history for a user"""
170
+ if user_id not in conversation_history or not conversation_history[user_id]:
171
+ return "No conversation history yet."
172
+
173
+ history = conversation_history[user_id]
174
+ formatted_history = []
175
+
176
+ for i in range(0, len(history), 2):
177
+ if i + 1 < len(history):
178
+ user_msg = history[i].replace(f"[User is feeling ", "").split("] ", 1)[-1]
179
+ bot_msg = history[i + 1]
180
+ formatted_history.append(f"**You:** {user_msg}")
181
+ formatted_history.append(f"**AI:** {bot_msg}")
182
+
183
+ return "\n\n".join(formatted_history) if formatted_history else "No conversation history yet."
184
+
185
+ def clear_conversation(user_id: str) -> str:
186
+ """Clear conversation history for a user"""
187
+ if user_id in conversation_history:
188
+ conversation_history[user_id] = []
189
+ return "Conversation history cleared."
190
+
191
+ # Load models on startup
192
+ print("Initializing models...")
193
+ load_models()
194
+ print("Models initialized successfully")
195
+
196
+ # Create Gradio interface
197
+ with gr.Blocks(title="Emotional Conversational AI", theme=gr.themes.Soft()) as iface:
198
+ gr.Markdown("# 🎀 Emotional Conversational AI")
199
+ gr.Markdown("Upload audio or use your microphone to have an emotional conversation with AI")
200
+
201
+ # User ID state
202
+ user_id_state = gr.State(value=str(uuid.uuid4()))
203
+
204
+ with gr.Row():
205
+ with gr.Column(scale=2):
206
+ # Audio input
207
+ audio_input = gr.Audio(
208
+ sources=["microphone", "upload"],
209
+ type="filepath",
210
+ label="πŸŽ™οΈ Record or Upload Audio"
211
+ )
212
+
213
+ # Process button
214
+ process_btn = gr.Button("πŸš€ Process Audio", variant="primary", size="lg")
215
+
216
+ with gr.Column(scale=3):
217
+ # Output displays
218
+ transcription_output = gr.Textbox(
219
+ label="πŸ“ Transcription",
220
+ placeholder="Your speech will appear here...",
221
+ max_lines=3
222
+ )
223
+
224
+ emotion_output = gr.Textbox(
225
+ label="😊 Detected Emotion",
226
+ placeholder="Detected emotion will appear here...",
227
+ max_lines=1
228
+ )
229
+
230
+ response_output = gr.Textbox(
231
+ label="πŸ€– AI Response",
232
+ placeholder="AI response will appear here...",
233
+ max_lines=5
234
+ )
235
+
236
+ with gr.Row():
237
+ with gr.Column():
238
+ # Conversation history
239
+ history_output = gr.Textbox(
240
+ label="πŸ’¬ Conversation History",
241
+ placeholder="Your conversation history will appear here...",
242
+ max_lines=10,
243
+ interactive=False
244
+ )
245
+
246
+ with gr.Column():
247
+ # Control buttons
248
+ show_history_btn = gr.Button("πŸ“– Show History", variant="secondary")
249
+ clear_history_btn = gr.Button("πŸ—‘οΈ Clear History", variant="stop")
250
+ new_session_btn = gr.Button("πŸ†• New Session", variant="secondary")
251
+
252
+ # Event handlers
253
+ process_btn.click(
254
+ fn=process_audio_input,
255
+ inputs=[audio_input, user_id_state],
256
+ outputs=[transcription_output, emotion_output, response_output, user_id_state]
257
+ )
258
+
259
+ show_history_btn.click(
260
+ fn=get_conversation_history,
261
+ inputs=[user_id_state],
262
+ outputs=[history_output]
263
+ )
264
+
265
+ clear_history_btn.click(
266
+ fn=clear_conversation,
267
+ inputs=[user_id_state],
268
+ outputs=[history_output]
269
+ )
270
+
271
+ new_session_btn.click(
272
+ fn=lambda: (str(uuid.uuid4()), "New session started!"),
273
+ outputs=[user_id_state, history_output]
274
+ )
275
+
276
+ # Mount Gradio app to FastAPI
277
+ app = gr.mount_gradio_app(app, iface, path="/")
278
+
279
+ if __name__ == "__main__":
280
+ import uvicorn
281
+ uvicorn.run(app, host="0.0.0.0", port=7860)