dashakoryakovskaya's picture
Update app.py
c01aab5 verified
import gradio as gr
import plotly.express as px
import pandas as pd
import logging
import whisper
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.nn.functional import silu
from torch.nn.functional import softplus
from einops import rearrange, repeat, einsum
from transformers import AutoTokenizer, AutoModel
from torch import Tensor
from einops import rearrange
from model import Mamba
# Configure the root logger once at import time so the logging.error /
# logging.exception calls in the handlers below are emitted (INFO and above).
logging.basicConfig(level="INFO")
def plotly_plot_text(text):
    """Classify typed text and chart the predicted emotion probabilities.

    Relies on the module-global ``model`` (created in the ``__main__`` block)
    exposing ``predict_proba(list_of_texts)``.

    Args:
        text: Current contents of the ``gr.Textbox`` in the text tab.

    Returns:
        A 2-tuple ``(figure, markdown)`` matching the two outputs wired in
        ``create_demo_text`` (``[text_plot, top_emotion]``): a plotly bar chart
        of per-emotion probabilities and a Markdown heading naming the most
        probable emotion.
    """
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness',
                       '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']

    # .change() fires on every edit, including clearing the box: show the idle
    # state (zero bars, same placeholder as the initial Markdown) instead of
    # classifying an empty string.
    if not text or not text.strip():
        data['Probability'] = [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return p, "## ✔️ Dominant Emotion: Waiting for input ..."

    data['Probability'] = model.predict_proba([text])[0].tolist()
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    dominant = data['Emotion'].values[np.argmax(np.array(data['Probability']))]
    # The text tab wires exactly two outputs (plot, Markdown), so return exactly
    # two values; an extra value would shift into the wrong component.
    return p, f"## 🏆 Dominant Emotion: {dominant}"
# Whisper models already loaded by transcribe_audio, keyed by checkpoint name.
_whisper_models = {}


def _load_whisper_model(name):
    """Return the Whisper model *name*, loading it from disk only on first use."""
    asr = _whisper_models.get(name)
    if asr is None:
        asr = whisper.load_model(name)
        _whisper_models[name] = asr
    return asr


def transcribe_audio(audio_path, model_name="base"):
    """Transcribe an audio file to text with Whisper.

    The Whisper checkpoint is loaded once per ``model_name`` and then reused,
    rather than being reloaded for every request.

    Args:
        audio_path: Path of the audio file to transcribe.
        model_name: Whisper checkpoint name passed to ``whisper.load_model``.

    Returns:
        The ``'text'`` entry of Whisper's result (``""`` if absent), or ``""``
        when transcription itself raises (the error is logged).

    Raises:
        Anything raised while loading the model: loading is deliberately kept
        outside the ``try`` so the caller's own error handling reports it.
    """
    asr = _load_whisper_model(model_name)
    try:
        # fp16=False: run in full precision.
        result = asr.transcribe(audio_path, fp16=False)
    except Exception as e:
        logging.error("Transcription failed: %s", e)
        return ""
    return result.get('text', '')
def plotly_plot_audio(audio_path):
    """Transcribe an audio clip and chart the emotions detected in its text.

    Relies on ``transcribe_audio`` and on the module-global ``model`` (created
    in the ``__main__`` block) exposing ``predict_proba(list_of_texts)``.

    Args:
        audio_path: File path supplied by the ``gr.Audio`` input
            (``type="filepath"``); falsy when no clip is present.

    Returns:
        A 3-tuple ``(figure, transcription, markdown)`` matching the outputs
        wired in ``create_demo_audio`` (``[text_plot, transcription,
        top_emotion]``): a plotly bar chart of per-emotion probabilities, the
        transcription text, and a Markdown heading with the most probable
        emotion or a status message.
    """
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness',
                       '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    zeros = [0.0] * data.shape[0]

    # The Audio .change() event can also fire with no file (e.g. when the clip
    # is cleared); reset to the idle state instead of running Whisper on None.
    if not audio_path:
        data['Probability'] = zeros
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return p, "", "## ✔️ Dominant Emotion: Waiting for input ..."

    try:
        text = transcribe_audio(audio_path)
        if text.strip():
            data['Probability'] = model.predict_proba([text])[0].tolist()
            dominant = data['Emotion'].values[np.argmax(np.array(data['Probability']))]
            status = f"## ✔️ Dominant Emotion: {dominant}"
        else:
            # With all-zero probabilities argmax would return index 0 and
            # wrongly report "anger"; say that nothing was recognised instead.
            data['Probability'] = zeros
            status = "## ⚠️ No speech detected"
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return p, f"🎤 Transcription:\n{text}", status
    except Exception:
        logging.exception("Processing failed")
        data['Probability'] = zeros
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "⚠️ Processing Error"
        )
def create_demo_text():
    """Build the text tab: a textbox whose edits update a chart and a heading.

    Returns:
        A ``gr.Blocks`` app in which every change of the textbox calls
        ``plotly_plot_text`` and fills the outputs ``[plot, Markdown]``.
    """
    with gr.Blocks(theme='Nymbo/rounded-gradient', css=".gradio-container {background-color: #F0F8FF}", title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition")
        with gr.Row():
            text_input = gr.Textbox(label="Write Text")
        with gr.Row():
            # Placeholder until the first prediction (emoji re-encoded to UTF-8).
            top_emotion = gr.Markdown("## ✔️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")
        # Event listeners must be registered inside the Blocks context.
        text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, top_emotion])
    return demo
def create_demo_audio():
    """Build the audio tab: record/upload a clip, transcribe it, chart emotions.

    Returns:
        A ``gr.Blocks`` app in which every change of the audio input calls
        ``plotly_plot_audio`` and fills the outputs
        ``[plot, transcription textbox, Markdown]``.
    """
    with gr.Blocks(theme='Nymbo/rounded-gradient', css=".gradio-container {background-color: #F0F8FF}", title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition with audio transcription")
        with gr.Row():
            # type="filepath": the handler receives a path string, not raw samples.
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )
        with gr.Row():
            # Placeholder until the first prediction (emoji re-encoded to UTF-8).
            top_emotion = gr.Markdown("## ✔️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")
            transcription = gr.Textbox(
                label="📜 Transcription Results",
                placeholder="Transcribed text will appear here...",
                lines=3,
                max_lines=6
            )
        # Event listeners must be registered inside the Blocks context.
        audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion])
    return demo
def create_demo():
    """Combine the text and audio demos into one tabbed Gradio app.

    Returns:
        A ``gr.TabbedInterface`` whose tabs are, in order, the text demo and
        the audio demo.
    """
    # (tab title, demo) pairs; the tuple is evaluated left to right, so the
    # text demo is built before the audio demo.
    tabs = (
        ("Text Prediction", create_demo_text()),
        ("Transcribed Audio Prediction", create_demo_audio()),
    )
    interfaces = [demo for _, demo in tabs]
    titles = [title for title, _ in tabs]
    return gr.TabbedInterface(interfaces, titles)
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # `model` is module-global on purpose: plotly_plot_text and
    # plotly_plot_audio read it when handling requests.
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7, model_name='jina', pooling=None).to(device)
    # Load on CPU first; load_state_dict copies the weights onto `device`.
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    # Serving predictions only: switch the module to inference mode so any
    # dropout / normalization layers inside Mamba behave deterministically.
    model.eval()
    demo = create_demo()
    demo.launch()