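# Gradio demo: bilingual emotion recognition from written text or from
# audio transcribed with Whisper, classified by a Mamba model (see model.py).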
import logging

import gradio as gr
import pandas as pd
import plotly.express as px
import torch
import whisper

from model import Mamba
logging.basicConfig(level=logging.INFO)
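# Plot the per-emotion probability distribution for a piece of text and
# report the dominant (highest-probability) emotion.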
def plotly_plot_text(text):
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    # predict_proba is provided by the Mamba classifier wrapper in model.py.
    data['Probability'] = model.predict_proba([text])[0].tolist()
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    # The text tab wires exactly two outputs: the plot and the headline.
    return (
        p,
        f"## 🏆 Dominant Emotion: {data.loc[data['Probability'].idxmax(), 'Emotion']}"
    )
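# Transcribe an audio file to text with OpenAI Whisper. The model is loaded
# lazily and cached in a module-level variable so it is not reloaded on
# every request.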
_whisper_model = None  # cached Whisper model, loaded on first use

def transcribe_audio(audio_path):
    global _whisper_model
    if not audio_path:
        # change() fires with None when the audio input is cleared.
        return ""
    try:
        if _whisper_model is None:
            _whisper_model = whisper.load_model("base")
        # fp16=False avoids half-precision warnings on CPU-only machines.
        result = _whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""
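# Transcribe the uploaded or recorded audio, then score the transcription
# with the same text classifier. Returns (plot, transcription, headline).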
def plotly_plot_audio(audio_path):
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    try:
        text = transcribe_audio(audio_path)
        # An empty transcription maps to an all-zero distribution rather
        # than scoring an empty string.
        data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            f"🎤 Transcription:\n{text}",
            f"## ☝️ Dominant Emotion: {data.loc[data['Probability'].idxmax(), 'Emotion']}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "## ⚠️ Processing Error"
        )
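# Build the text tab: a textbox whose contents are scored live on change.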
def create_demo_text():
with gr.Blocks(theme='Nymbo/rounded-gradient', css=".gradio-container {background-color: #F0F8FF}", title="Emotion Detection") as demo:
gr.Markdown("# Text-based bilingual emotion recognition")
with gr.Row():
text_input = gr.Textbox(label="Write Text")
with gr.Row():
            top_emotion = gr.Markdown("## ☝️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
with gr.Row():
text_plot = gr.Plot(label="Text Analysis")
text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, top_emotion])
return demo
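# Build the audio tab: record or upload audio, show the Whisper
# transcription and the predicted emotion distribution.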
def create_demo_audio():
with gr.Blocks(theme='Nymbo/rounded-gradient', css=".gradio-container {background-color: #F0F8FF}", title="Emotion Detection") as demo:
gr.Markdown("# Text-based bilingual emotion recognition with audio transcription")
with gr.Row():
audio_input = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Record or Upload Audio",
format="wav",
interactive=True
)
with gr.Row():
            top_emotion = gr.Markdown("## ☝️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
with gr.Row():
text_plot = gr.Plot(label="Text Analysis")
transcription = gr.Textbox(
                label="📝 Transcription Results",
placeholder="Transcribed text will appear here...",
lines=3,
max_lines=6
)
audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion])
return demo
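# Combine the text and audio tabs into a single tabbed interface.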
def create_demo():
text = create_demo_text()
audio = create_demo_audio()
demo = gr.TabbedInterface(
[text, audio],
["Text Prediction", "Transcribed Audio Prediction"],
)
return demo
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Hyperparameters must match those used to train the checkpoint.
    # The prediction functions above reference this module-level `model`.
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7, model_name='jina', pooling=None).to(device)
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    demo = create_demo()
    demo.launch()