import os
import streamlit as st
import tempfile
import torch
import torchaudio
import transformers
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import plotly.express as px
import logging
import warnings
import whisper
import base64
import io
import asyncio
from concurrent.futures import ThreadPoolExecutor
import streamlit.components.v1 as components

# Suppress warnings
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Streamlit config (set_page_config must be the first Streamlit call,
# so it runs before any st.write/st.title below)
st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
st.title("🎙 Voice Sentiment Analysis")
st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from voice or text.")

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.write(f"Using device: {device}")

# Global model cache
@st.cache_resource
def load_models():
    whisper_model = whisper.load_model("base")

    emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
    emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
    emotion_model = emotion_model.to(device)
    if device.type == "cuda":
        emotion_model = emotion_model.half()  # fp16 only on GPU; half precision is unreliable on CPU
    emotion_classifier = pipeline(
        "text-classification",
        model=emotion_model,
        tokenizer=emotion_tokenizer,
        top_k=None,
        device=0 if torch.cuda.is_available() else -1,
    )

    sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
    sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
    sarcasm_model = sarcasm_model.to(device)
    if device.type == "cuda":
        sarcasm_model = sarcasm_model.half()
    sarcasm_classifier = pipeline(
        "text-classification",
        model=sarcasm_model,
        tokenizer=sarcasm_tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )

    return whisper_model, emotion_classifier, sarcasm_classifier

whisper_model, emotion_classifier, sarcasm_classifier = load_models()

# Emotion detection
async def perform_emotion_detection(text):
    if not text or len(text.strip()) < 3:
        return {}, "neutral", {}, "NEUTRAL"
    try:
        results = emotion_classifier(text)[0]
        emotions_dict = {r['label']: r['score'] for r in results}
        filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
        top_emotion = max(filtered_emotions, key=filtered_emotions.get)
        positive_emotions = ["joy"]
        negative_emotions = ["anger", "disgust", "fear", "sadness"]
        sentiment = ("POSITIVE" if top_emotion in positive_emotions
                     else "NEGATIVE" if top_emotion in negative_emotions
                     else "NEUTRAL")
        emotion_map = {"joy": "😊", "anger": "😡", "disgust": "🤢",
                       "fear": "😨", "sadness": "😭", "surprise": "😲"}
        return emotions_dict, top_emotion, emotion_map, sentiment
    except Exception as e:
        st.error(f"Emotion detection failed: {str(e)}")
        return {}, "neutral", {}, "NEUTRAL"

# Sarcasm detection
async def perform_sarcasm_detection(text):
    if not text or len(text.strip()) < 3:
        return False, 0.0
    try:
        result = sarcasm_classifier(text)[0]
        is_sarcastic = result['label'] == "LABEL_1"  # LABEL_1 = irony for this model
        sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
        return is_sarcastic, sarcasm_score
    except Exception as e:
        st.error(f"Sarcasm detection failed: {str(e)}")
        return False, 0.0
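# NOTE: the two detectors above are declared async, but this section never shows
# how they are awaited. Below is a minimal, hypothetical sketch (the helper name
# `analyze_text` is not in the original) of awaiting both with asyncio.gather.
# Because the pipeline calls inside them are blocking, gather alone runs them
# back to back; true overlap would need the imported ThreadPoolExecutor via
# loop.run_in_executor.
async def analyze_text(text):
    # Await emotion and sarcasm detection on the same transcript.
    (emotions_dict, top_emotion, emotion_map, sentiment), (is_sarcastic, sarcasm_score) = \
        await asyncio.gather(
            perform_emotion_detection(text),
            perform_sarcasm_detection(text),
        )
    return emotions_dict, top_emotion, emotion_map, sentiment, is_sarcastic, sarcasm_score
# From synchronous Streamlit code this could be driven with, e.g.:
#     results = asyncio.run(analyze_text(transcript))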
# Audio validation
def validate_audio(audio_path):
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
        if waveform.abs().max() < 0.01:
            st.warning("Audio volume too low.")
            return False
        if waveform.shape[1] / sample_rate < 1:
            st.warning("Audio too short.")
            return False
        return True
    except Exception:
        st.error("Invalid audio file.")
        return False

# Audio transcription (cache is keyed on the file path string)
@st.cache_data
def transcribe_audio(audio_path):
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            torchaudio.save(temp_file.name, waveform, 16000)
        result = whisper_model.transcribe(temp_file.name, language="en")
        os.remove(temp_file.name)
        return result["text"].strip()
    except Exception as e:
        st.error(f"Transcription failed: {str(e)}")
        return ""

# Process uploaded audio
def process_uploaded_audio(audio_file):
    try:
        ext = audio_file.name.split('.')[-1].lower()
        if ext not in ['wav', 'mp3', 'ogg']:
            st.error("Unsupported format. Use WAV, MP3, or OGG.")
            return None
        with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as temp_file:
            temp_file.write(audio_file.getvalue())
            temp_file_path = temp_file.name
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None

# Process base64 audio (expects a data URL, e.g. "data:audio/wav;base64,...")
def process_base64_audio(base64_data):
    try:
        base64_binary = base64_data.split(',')[1]
        binary_data = base64.b64decode(base64_binary)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(binary_data)
            temp_file_path = temp_file.name
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio data: {str(e)}")
        return None

# Custom audio recorder
def custom_audio_recorder():
    audio_recorder_html = """