Devakumar868 committed on
Commit
25ff0cb
·
verified ·
1 Parent(s): 9161761

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -263
app.py CHANGED
@@ -1,280 +1,73 @@
1
- import os
2
- import tempfile
3
- from fastapi import FastAPI, UploadFile, File
4
  import gradio as gr
5
- import nemo.collections.asr as nemo_asr
6
- from speechbrain.pretrained import EncoderClassifier
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import soundfile as sf
9
  import torch
10
  import numpy as np
11
- from typing import Dict, List, Tuple
12
- import json
13
- import uuid
14
- from datetime import datetime
15
 
16
- # Initialize FastAPI app
17
  app = FastAPI()
18
-
19
- # Global variables for models
20
- asr_model = None
21
- emotion_model = None
22
- llm_model = None
23
- llm_tokenizer = None
24
  conversation_history = {}
25
 
26
- def load_models():
27
- """Load all required models"""
28
- global asr_model, emotion_model, llm_model, llm_tokenizer
29
-
30
- try:
31
- # Load ASR model using correct syntax
32
- print("Loading ASR model...")
33
- asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
34
- print("ASR model loaded successfully")
35
-
36
- # Load emotion recognition model
37
- print("Loading emotion model...")
38
- emotion_model = EncoderClassifier.from_hparams(
39
- source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
40
- savedir="./emotion_model_cache"
41
- )
42
- print("Emotion model loaded successfully")
43
-
44
- # Load LLM for conversation
45
- print("Loading LLM...")
46
- model_name = "microsoft/DialoGPT-medium" # Lighter alternative to Vicuna
47
- llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
48
- llm_model = AutoModelForCausalLM.from_pretrained(
49
- model_name,
50
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
51
- device_map="auto" if torch.cuda.is_available() else None
52
- )
53
-
54
- # Add padding token if not present
55
- if llm_tokenizer.pad_token is None:
56
- llm_tokenizer.pad_token = llm_tokenizer.eos_token
57
-
58
- print("All models loaded successfully")
59
-
60
- except Exception as e:
61
- print(f"Error loading models: {str(e)}")
62
- raise e
63
-
64
- def transcribe_audio(audio_path: str) -> Tuple[str, str]:
65
- """Transcribe audio and detect emotion"""
66
- try:
67
- # ASR transcription
68
- transcription = asr_model.transcribe([audio_path])
69
- text = transcription[0].text if hasattr(transcription[0], 'text') else str(transcription[0])
70
-
71
- # Emotion detection
72
- emotion_result = emotion_model.classify_file(audio_path)
73
- emotion = emotion_result[0] if isinstance(emotion_result, list) else str(emotion_result)
74
-
75
- return text, emotion
76
-
77
- except Exception as e:
78
- print(f"Error in transcription: {str(e)}")
79
- return f"Error: {str(e)}", "unknown"
80
-
81
- def generate_response(user_text: str, emotion: str, user_id: str) -> str:
82
- """Generate contextual response based on user input and emotion"""
83
- try:
84
- # Get conversation history
85
- if user_id not in conversation_history:
86
- conversation_history[user_id] = []
87
-
88
- # Add emotion context to the input
89
- emotional_context = f"[User is feeling {emotion}] {user_text}"
90
-
91
- # Encode input with conversation history
92
- conversation_history[user_id].append(emotional_context)
93
-
94
- # Keep only last 5 exchanges to manage memory
95
- if len(conversation_history[user_id]) > 10:
96
- conversation_history[user_id] = conversation_history[user_id][-10:]
97
-
98
- # Create input for the model
99
- input_text = " ".join(conversation_history[user_id][-3:]) # Last 3 exchanges
100
-
101
- # Tokenize and generate
102
- inputs = llm_tokenizer.encode(input_text, return_tensors="pt")
103
- if torch.cuda.is_available():
104
- inputs = inputs.cuda()
105
-
106
- with torch.no_grad():
107
- outputs = llm_model.generate(
108
- inputs,
109
- max_new_tokens=100,
110
- num_return_sequences=1,
111
- temperature=0.7,
112
- do_sample=True,
113
- pad_token_id=llm_tokenizer.eos_token_id
114
- )
115
-
116
- # Decode response
117
- response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
118
-
119
- # Extract only the new part of the response
120
- response = response[len(input_text):].strip()
121
-
122
- # Add to conversation history
123
- conversation_history[user_id].append(response)
124
-
125
- return response if response else "I understand your feelings. How can I help you today?"
126
-
127
- except Exception as e:
128
- print(f"Error generating response: {str(e)}")
129
- return "I'm having trouble processing that right now. Could you try again?"
130
 
131
- def process_audio_input(audio_file, user_id: str = None) -> Tuple[str, str, str, str]:
132
- """Main processing function for audio input"""
133
- if user_id is None:
134
- user_id = str(uuid.uuid4())
135
-
136
- if audio_file is None:
137
- return "No audio file provided", "", "", user_id
138
-
139
- try:
140
- # Save uploaded audio to temporary file
141
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
142
- # Handle different audio input formats
143
- if hasattr(audio_file, 'name'):
144
- # File upload case
145
- audio_path = audio_file.name
146
- else:
147
- # Direct audio data case
148
- sf.write(tmp_file.name, audio_file[1], audio_file[0])
149
- audio_path = tmp_file.name
150
-
151
- # Process audio
152
- transcription, emotion = transcribe_audio(audio_path)
153
-
154
- # Generate response
155
- response = generate_response(transcription, emotion, user_id)
156
-
157
- # Clean up temporary file
158
- if audio_path != (audio_file.name if hasattr(audio_file, 'name') else ''):
159
- os.unlink(audio_path)
160
-
161
- return transcription, emotion, response, user_id
162
-
163
- except Exception as e:
164
- error_msg = f"Processing error: {str(e)}"
165
- print(error_msg)
166
- return error_msg, "error", "I'm sorry, I couldn't process your audio.", user_id
167
 
168
- def get_conversation_history(user_id: str) -> str:
169
- """Get formatted conversation history for a user"""
170
- if user_id not in conversation_history or not conversation_history[user_id]:
171
- return "No conversation history yet."
172
-
173
- history = conversation_history[user_id]
174
- formatted_history = []
175
-
176
- for i in range(0, len(history), 2):
177
- if i + 1 < len(history):
178
- user_msg = history[i].replace(f"[User is feeling ", "").split("] ", 1)[-1]
179
- bot_msg = history[i + 1]
180
- formatted_history.append(f"**You:** {user_msg}")
181
- formatted_history.append(f"**AI:** {bot_msg}")
182
-
183
- return "\n\n".join(formatted_history) if formatted_history else "No conversation history yet."
184
 
185
- def clear_conversation(user_id: str) -> str:
186
- """Clear conversation history for a user"""
187
- if user_id in conversation_history:
188
- conversation_history[user_id] = []
189
- return "Conversation history cleared."
 
190
 
191
- # Load models on startup
192
- print("Initializing models...")
193
- load_models()
194
- print("Models initialized successfully")
 
 
 
 
 
 
 
 
 
 
195
 
196
- # Create Gradio interface
197
- with gr.Blocks(title="Emotional Conversational AI", theme=gr.themes.Soft()) as iface:
198
- gr.Markdown("# 🎤 Emotional Conversational AI")
199
- gr.Markdown("Upload audio or use your microphone to have an emotional conversation with AI")
200
-
201
- # User ID state
202
- user_id_state = gr.State(value=str(uuid.uuid4()))
203
-
204
- with gr.Row():
205
- with gr.Column(scale=2):
206
- # Audio input
207
- audio_input = gr.Audio(
208
- sources=["microphone", "upload"],
209
- type="filepath",
210
- label="🎙️ Record or Upload Audio"
211
- )
212
-
213
- # Process button
214
- process_btn = gr.Button("🚀 Process Audio", variant="primary", size="lg")
215
-
216
- with gr.Column(scale=3):
217
- # Output displays
218
- transcription_output = gr.Textbox(
219
- label="📝 Transcription",
220
- placeholder="Your speech will appear here...",
221
- max_lines=3
222
- )
223
-
224
- emotion_output = gr.Textbox(
225
- label="😊 Detected Emotion",
226
- placeholder="Detected emotion will appear here...",
227
- max_lines=1
228
- )
229
-
230
- response_output = gr.Textbox(
231
- label="🤖 AI Response",
232
- placeholder="AI response will appear here...",
233
- max_lines=5
234
- )
235
-
236
- with gr.Row():
237
- with gr.Column():
238
- # Conversation history
239
- history_output = gr.Textbox(
240
- label="💬 Conversation History",
241
- placeholder="Your conversation history will appear here...",
242
- max_lines=10,
243
- interactive=False
244
- )
245
-
246
- with gr.Column():
247
- # Control buttons
248
- show_history_btn = gr.Button("📖 Show History", variant="secondary")
249
- clear_history_btn = gr.Button("🗑️ Clear History", variant="stop")
250
- new_session_btn = gr.Button("🆕 New Session", variant="secondary")
251
-
252
- # Event handlers
253
- process_btn.click(
254
- fn=process_audio_input,
255
- inputs=[audio_input, user_id_state],
256
- outputs=[transcription_output, emotion_output, response_output, user_id_state]
257
- )
258
-
259
- show_history_btn.click(
260
- fn=get_conversation_history,
261
- inputs=[user_id_state],
262
- outputs=[history_output]
263
- )
264
-
265
- clear_history_btn.click(
266
- fn=clear_conversation,
267
- inputs=[user_id_state],
268
- outputs=[history_output]
269
- )
270
-
271
- new_session_btn.click(
272
- fn=lambda: (str(uuid.uuid4()), "New session started!"),
273
- outputs=[user_id_state, history_output]
274
- )
275
 
276
- # Mount Gradio app to FastAPI
277
- app = gr.mount_gradio_app(app, iface, path="/")
278
 
279
  if __name__ == "__main__":
280
  import uvicorn
 
1
+ import os, tempfile, uuid
2
+ from fastapi import FastAPI
 
3
  import gradio as gr
 
 
 
4
  import soundfile as sf
5
  import torch
6
  import numpy as np
7
+ import nemo.collections.asr as nemo_asr
8
+ from speechbrain.pretrained import EncoderClassifier
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
10
 
11
+ # Initialize FastAPI and models
12
  app = FastAPI()
 
 
 
 
 
 
13
  conversation_history = {}
14
 
15
# Model loading: every model is instantiated once, at import time.

# Speech-to-text model.
asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")

# Speech emotion classifier; downloaded files are cached under "emotion_cache".
# NOTE(review): confirm this checkpoint loads through EncoderClassifier rather
# than a checkpoint-specific custom interface.
emotion_model = EncoderClassifier.from_hparams(
    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    savedir="emotion_cache",
)

# Conversational language model and its tokenizer; placed on the GPU when
# one is available, otherwise on the CPU.
llm_name = "microsoft/DialoGPT-medium"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name)
llm_model = AutoModelForCausalLM.from_pretrained(llm_name)
llm_model = llm_model.to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
def transcribe_and_emote(audio_path):
    """Transcribe an audio file and classify the speaker's emotion.

    Args:
        audio_path: Path of the audio file handed to both models.

    Returns:
        A ``(text, emotion)`` tuple: the transcription and the predicted
        emotion label.
    """
    hypothesis = asr_model.transcribe([audio_path])[0]
    # Accept both hypothesis objects (with ``.text``) and plain strings,
    # so the function does not depend on one transcribe() return shape.
    text = hypothesis.text if hasattr(hypothesis, "text") else str(hypothesis)

    # classify_file() returns (out_prob, score, index, text_lab); the
    # human-readable label is in text_lab, not in the first element.
    _, _, _, labels = emotion_model.classify_file(audio_path)
    emotion = labels[0]
    return text, emotion
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
def generate_reply(user_text, emotion, uid):
    """Generate a reply for one user turn and record both turns in history.

    Args:
        user_text: What the user said.
        emotion: Emotion label embedded in the prompt context.
        uid: Key of the conversation inside ``conversation_history``.

    Returns:
        The generated reply, or a fixed fallback sentence when the model
        produces no text.
    """
    # Record the new turn, then keep only the six most recent entries.
    hist = conversation_history.setdefault(uid, [])
    hist.append(f"[Feeling:{emotion}] {user_text}")
    hist = hist[-6:]
    conversation_history[uid] = hist

    prompt = " ".join(hist)
    inputs = llm_tokenizer.encode(prompt, return_tensors="pt").to(llm_model.device)
    out = llm_model.generate(
        inputs,
        max_new_tokens=100,
        pad_token_id=llm_tokenizer.eos_token_id,
    )

    # The generated sequence begins with the prompt tokens. Strip them by
    # token count: slicing the decoded string by len(prompt) breaks whenever
    # decoding does not reproduce the prompt character-for-character.
    new_tokens = out[0][inputs.shape[-1]:]
    reply = llm_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    hist.append(reply)
    return reply or "I’m here to help!"
44
 
45
def process(audio, uid):
    """Run one recorded utterance through ASR, emotion detection and the LLM.

    Args:
        audio: Value delivered by the audio input component, a
            ``(sample_rate, samples)`` pair, or a falsy value when nothing
            was recorded.
        uid: Conversation key, passed through unchanged to the outputs.

    Returns:
        A ``(transcription, emotion, reply, uid)`` tuple; the first three
        items are empty strings when no audio was supplied.
    """
    if not audio:
        return "", "", "", uid

    # The numpy audio format puts the sample rate first and the samples second.
    sr, data = audio

    # Reserve a temp path and close the handle immediately, so the file is
    # not left open while it is written and read by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name

    try:
        sf.write(wav_path, data, sr)
        # ASR + Emotion
        text, emo = transcribe_and_emote(wav_path)
        # LLM response
        reply = generate_reply(text, emo, uid)
    finally:
        # Remove the temp file even when a model call raises.
        os.unlink(wav_path)
    return text, emo, reply, uid
59
 
60
# Gradio interface
with gr.Blocks() as demo:
    # Conversation key carried between clicks.
    # NOTE(review): the UUID is generated once when this block is built, so
    # sessions may start with the same key — confirm whether per-session
    # keys are required.
    uid_state = gr.State(value=str(uuid.uuid4()))

    # `sources` takes a list; the singular `source` keyword belongs to the
    # older Gradio API. type="numpy" delivers (sample_rate, samples).
    audio_in = gr.Audio(sources=["microphone"], type="numpy")

    txt_out = gr.Textbox(label="Transcription")
    emo_out = gr.Textbox(label="Emotion")
    rep_out = gr.Textbox(label="AI Reply")

    btn = gr.Button("Process")
    btn.click(
        process,
        inputs=[audio_in, uid_state],
        outputs=[txt_out, emo_out, rep_out, uid_state],
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ app = gr.mount_gradio_app(app, demo, path="/")
 
71
 
72
  if __name__ == "__main__":
73
  import uvicorn