dashakoryakovskaya committed · verified
Commit f4dbd19 · Parent: 538dfc3

Update app.py

Files changed (1): app.py (+59 -18)
app.py CHANGED
@@ -62,21 +62,35 @@ def plotly_plot_audio(audio_path):
             "⚠️ Processing Error"
         )
 
-def create_demo():
+def plotly_plot_audio(audio_path):
+    data = pd.DataFrame()
+    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
+    try:
+        text = transcribe_audio(audio_path)
+        data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
+        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
+        return (
+            p,
+            f"🗣️ Transcription:\n{text}",
+            f"## 🏆 Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
+        )
+
+    except Exception as e:
+        logging.error(f"Processing failed: {e}")
+        data['Probability'] = [0] * data.shape[0]
+        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
+        return (
+            p,
+            "❌ Error processing audio",
+            "⚠️ Processing Error"
+        )
+
+def create_demo_text():
     with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection") as demo:
         gr.Markdown("# Text-based bilingual emotion recognition")
 
         with gr.Row():
-            with gr.Column():
-                audio_input = gr.Audio(
-                    sources=["upload", "microphone"],
-                    type="filepath",
-                    label="Record or Upload Audio",
-                    format="wav",
-                    interactive=True
-                )
-            with gr.Column():
-                text_input = gr.Text(label="Write Text")
+            text_input = gr.Textbox(label="Write Text")
 
         with gr.Row():
             top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input ...",
@@ -85,24 +99,51 @@ def create_demo():
         with gr.Row():
             text_plot = gr.Plot(label="Text Analysis")
 
+        text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, top_emotion])
+    return demo
+
+def create_demo_audio():
+    with gr.Blocks(theme=gr.themes.Soft(), title="Emotion Detection") as demo:
+        gr.Markdown("# Text-based bilingual emotion recognition")
+
+        with gr.Row():
+            audio_input = gr.Audio(
+                sources=["upload", "microphone"],
+                type="filepath",
+                label="Record or Upload Audio",
+                format="wav",
+                interactive=True
+            )
+        with gr.Row():
+            top_emotion = gr.Markdown("## 🏆 Dominant Emotion: Waiting for input ...",
+                                      elem_classes="dominant-emotion")
+
+        with gr.Row():
+            text_plot = gr.Plot(label="Text Analysis")
+
         transcription = gr.Textbox(
             label="📜 Transcription Results",
             placeholder="Transcribed text will appear here...",
             lines=3,
             max_lines=6
         )
-
-        if text_input is not None:
-            text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, transcription, top_emotion])
-        elif audio_input:
-            audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion])
+        audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion])
     return demo
 
+def create_demo():
+    text = create_demo_text()
+    audio = create_demo_audio()
+    demo = gr.TabbedInterface(
+        [text, audio],
+        ["Text Prediction", "Transcribed Audio Prediction"],
+    )
+    return demo
+
 
 if __name__ == "__main__":
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = Mamba(num_layers = 2, d_input = 1024, d_model = 512, num_classes=7, model_name='jina', pooling=None).to(device)
-    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=torch.device('cpu'))
+    checkpoint = torch.load("models/Mamba_jina_checkpoint.pth", map_location=torch.device('cpu'))
     model.load_state_dict(checkpoint['model_state_dict'])
+
    demo = create_demo()
     demo.launch()
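
The structural change in this commit is splitting the single create_demo() layout into two self-contained gr.Blocks apps composed with gr.TabbedInterface, so each tab owns its own components and event wiring. A minimal, runnable sketch of that composition pattern (the make_tab helper, its labels, and the uppercase handler are placeholders for illustration, not from app.py):

import gradio as gr

def make_tab(label):
    # Each tab is its own gr.Blocks with its own components and events,
    # mirroring create_demo_text() / create_demo_audio() in the commit.
    with gr.Blocks() as demo:
        inp = gr.Textbox(label=label)
        out = gr.Textbox(label="Output")
        # Wiring happens inside the Blocks context, per tab.
        inp.change(fn=lambda s: s.upper(), inputs=inp, outputs=out)
    return demo

# TabbedInterface stacks the independent Blocks apps under named tabs.
demo = gr.TabbedInterface(
    [make_tab("Text"), make_tab("Audio path")],
    ["Text Prediction", "Transcribed Audio Prediction"],
)

if __name__ == "__main__":
    demo.launch()

Because each Blocks app registers its own .change() handlers, the tabs cannot accidentally share state, which is what the removed if text_input is not None / elif audio_input branching tried (and failed) to arbitrate in the old single-page layout.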
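
The text tab now wires text_input.change to plotly_plot_text with two outputs ([text_plot, top_emotion]), but that function sits outside the diff hunks. A hypothetical reconstruction, assuming it mirrors plotly_plot_audio without the transcription step; the _DummyModel stub stands in for the module-level Mamba classifier and is not part of app.py:

import logging
import numpy as np
import pandas as pd
import plotly.express as px

class _DummyModel:
    # Stand-in for app.py's module-level Mamba classifier (assumption),
    # returning a uniform distribution over the 7 emotion classes.
    def predict_proba(self, texts):
        return np.full((len(texts), 7), 1 / 7)

model = _DummyModel()

def plotly_plot_text(text):
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness',
                       '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    try:
        # Empty input maps to all-zero probabilities, as in plotly_plot_audio.
        data['Probability'] = (model.predict_proba([text])[0].tolist()
                               if text.strip() else [0.0] * data.shape[0])
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        top = data['Emotion'].values[np.argmax(np.array(data['Probability']))]
        return p, f"## 🏆 Dominant Emotion: {top}"
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return p, "⚠️ Processing Error"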