import gradio as gr
import plotly.express as px
import pandas as pd
import logging
import whisper
import torch
import numpy as np

from model import Mamba

logging.basicConfig(level=logging.INFO)

# Emotion labels must match the class order used when training the Mamba classifier.
EMOTIONS = ['😠 anger', '🤢 disgust', '😨 fear', '😄 joy/happiness',
            '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']


def plotly_plot_text(text):
    data = pd.DataFrame()
    data['Emotion'] = EMOTIONS
    data['Probability'] = model.predict_proba([text])[0].tolist()
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    # Return order matches the Gradio outputs: [text_plot, top_emotion].
    return (
        p,
        f"## 🏆 Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
    )


def transcribe_audio(audio_path):
    # Note: the Whisper model is reloaded on every call; cache it if latency matters.
    whisper_model = whisper.load_model("base")
    try:
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""


def plotly_plot_audio(audio_path):
    data = pd.DataFrame()
    data['Emotion'] = EMOTIONS
    try:
        text = transcribe_audio(audio_path)
        data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        # Return order matches the Gradio outputs: [text_plot, transcription, top_emotion].
        return (
            p,
            f"🎤 Transcription:\n{text}",
            f"## ✔️ Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "⚠️ Processing Error"
        )


def create_demo_text():
    with gr.Blocks(theme='Nymbo/rounded-gradient',
                   css=".gradio-container {background-color: #F0F8FF}",
                   title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition")
        with gr.Row():
            text_input = gr.Textbox(label="Write Text")
        with gr.Row():
            top_emotion = gr.Markdown("## ✔️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")
        text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, top_emotion])
    return demo


def create_demo_audio():
    with gr.Blocks(theme='Nymbo/rounded-gradient',
                   css=".gradio-container {background-color: #F0F8FF}",
                   title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition with audio transcription")
        with gr.Row():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )
        with gr.Row():
            top_emotion = gr.Markdown("## ✔️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")
        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")
            transcription = gr.Textbox(
                label="📜 Transcription Results",
                placeholder="Transcribed text will appear here...",
                lines=3,
                max_lines=6
            )
        audio_input.change(fn=plotly_plot_audio, inputs=audio_input,
                           outputs=[text_plot, transcription, top_emotion])
    return demo


def create_demo():
    text = create_demo_text()
    audio = create_demo_audio()
    demo = gr.TabbedInterface(
        [text, audio],
        ["Text Prediction", "Transcribed Audio Prediction"],
    )
    return demo


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7,
                  model_name='jina', pooling=None).to(device)
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    demo = create_demo()
    demo.launch()