import os
import time
import tempfile

import streamlit as st
import torch
import jiwer
import librosa
import soundfile
from transformers import pipeline
from pydub import AudioSegment
from datasets import load_dataset

# Page configuration
st.set_page_config(page_title="Audio-to-Text with Grammar Check", page_icon="🎤", layout="wide")

# Model configurations (three ASR models and one grammar-correction model)
MODELS = {
    "automatic-speech-recognition": {
        "whisper-tiny": "openai/whisper-tiny",
        "whisper-small": "openai/whisper-small",
        "whisper-base": "openai/whisper-base"
    },
    "text2text-generation": {
        "grammar-synthesis-small": "pszemraj/grammar-synthesis-small"
    }
}


# Cached model loading
@st.cache_resource
def load_model(model_key, task):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with st.spinner(f"Loading {model_key} model..."):
        return pipeline(task, model=MODELS[task][model_key], device=device)


def convert_audio_to_wav(audio_file):
    """Convert an uploaded audio file to WAV and return the temp-file path."""
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            audio = AudioSegment.from_file(audio_file)
            audio.export(tmp_file.name, format="wav")
            return tmp_file.name
    except Exception as e:
        st.error(f"Audio conversion failed: {str(e)}")
        return None


def evaluate_asr_accuracy(transcription, reference):
    """Calculate word error rate (WER) and character error rate (CER)."""
    ref_processed = reference.lower().strip()
    hyp_processed = transcription.lower().strip()
    if not ref_processed:
        return 0.0, 0.0
    wer = jiwer.wer(ref_processed, hyp_processed)
    cer = jiwer.cer(ref_processed, hyp_processed)
    return wer, cer


# Cached dataset loading with audio decoding
@st.cache_data(show_spinner=False)
def load_cached_dataset(num_samples=1):
    st.info("Loading dataset...")
    try:
        dataset = load_dataset(
            "librispeech_asr",
            "clean",
            split="test",
            streaming=True,
            trust_remote_code=True
        ).take(num_samples)
        return [sample for sample in dataset]
    except Exception as e:
        st.error(f"Dataset loading failed: {str(e)}")
        return None


def main():
    st.title("🎤 Audio Grammar Evaluation System for Language Learners")

    # Session state for persisting results
    if "transcription" not in st.session_state:
        st.session_state.transcription = ""
    if "grammar_feedback" not in st.session_state:
        st.session_state.grammar_feedback = ""

    # Audio processing and evaluation tabs
    tab1, tab2 = st.tabs(["Audio Processor", "Model Evaluator"])

    with tab1:
        st.subheader("Upload & Process Audio")
        audio_file = st.file_uploader("Upload audio file", type=["mp3", "wav", "ogg", "m4a"])

        if audio_file:
            st.audio(audio_file, format="audio/wav")
            wav_path = convert_audio_to_wav(audio_file)

            if wav_path:
                asr_model = load_model("whisper-tiny", "automatic-speech-recognition")
                with st.spinner("Generating transcription..."):
                    transcription = asr_model(wav_path)["text"]
                st.session_state.transcription = transcription
                st.text_area("Transcription Result", transcription, height=150)

                if st.session_state.transcription:
                    grammar_model = load_model("grammar-synthesis-small", "text2text-generation")
                    with st.spinner("Checking grammar..."):
                        grammar_feedback = grammar_model(
                            f"Correct the grammar in: {transcription}"
                        )[0]["generated_text"]
                    st.session_state.grammar_feedback = grammar_feedback
                    st.success("Grammar Corrected Text:")
                    st.write(grammar_feedback)

                os.unlink(wav_path)

    with tab2:
        st.subheader("Triple Model Evaluation with Runtime")

        # Model selection
        model_options = list(MODELS["automatic-speech-recognition"].keys())
        col1, col2, col3 = st.columns(3)
        with col1:
            selected_model1 = st.selectbox("Select Model 1", model_options, index=0)
        with col2:
            selected_model2 = st.selectbox("Select Model 2", model_options, index=1)
        with col3:
            selected_model3 = st.selectbox("Select Model 3", model_options, index=2)

        if st.button("Run Triple Evaluation"):
            dataset = load_cached_dataset(num_samples=1)
            if not dataset:
                return

            # Load the three selected ASR models
            model1 = load_model(selected_model1, "automatic-speech-recognition")
            model2 = load_model(selected_model2, "automatic-speech-recognition")
            model3 = load_model(selected_model3, "automatic-speech-recognition")

            results = []
            total_runtime_model1 = 0.0
            total_runtime_model2 = 0.0
            total_runtime_model3 = 0.0

            for i, sample in enumerate(dataset):
                with st.spinner(f"Processing sample {i + 1}..."):
                    audio_array = sample["audio"]["array"]
                    sampling_rate = sample["audio"]["sampling_rate"]
                    reference_text = sample["text"]

                    # Evaluate Model 1 (pass the sampling rate explicitly so the
                    # pipeline does not have to assume its default rate)
                    start_time = time.perf_counter()
                    transcription1 = model1({"raw": audio_array, "sampling_rate": sampling_rate})["text"]
                    end_time = time.perf_counter()
                    runtime1 = end_time - start_time
                    total_runtime_model1 += runtime1
                    wer1, cer1 = evaluate_asr_accuracy(transcription1, reference_text)

                    # Evaluate Model 2
                    start_time = time.perf_counter()
                    transcription2 = model2({"raw": audio_array, "sampling_rate": sampling_rate})["text"]
                    end_time = time.perf_counter()
                    runtime2 = end_time - start_time
                    total_runtime_model2 += runtime2
                    wer2, cer2 = evaluate_asr_accuracy(transcription2, reference_text)

                    # Evaluate Model 3
                    start_time = time.perf_counter()
                    transcription3 = model3({"raw": audio_array, "sampling_rate": sampling_rate})["text"]
                    end_time = time.perf_counter()
                    runtime3 = end_time - start_time
                    total_runtime_model3 += runtime3
                    wer3, cer3 = evaluate_asr_accuracy(transcription3, reference_text)

                    # Organize results
                    model1_result = {
                        "Model": selected_model1,
                        "Runtime": f"{runtime1:.4f}s",
                        "WER": f"{wer1 * 100:.2f}%",
                        "CER": f"{cer1 * 100:.2f}%"
                    }
                    model2_result = {
                        "Model": selected_model2,
                        "Runtime": f"{runtime2:.4f}s",
                        "WER": f"{wer2 * 100:.2f}%",
                        "CER": f"{cer2 * 100:.2f}%"
                    }
                    model3_result = {
                        "Model": selected_model3,
                        "Runtime": f"{runtime3:.4f}s",
                        "WER": f"{wer3 * 100:.2f}%",
                        "CER": f"{cer3 * 100:.2f}%"
                    }
                    results.extend([model1_result, model2_result, model3_result])

            # Display results
            st.subheader("Model Evaluation Results")
            st.table(results)


if __name__ == "__main__":
    main()
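# Usage note (a minimal sketch, assuming this script is saved as app.py and that
# ffmpeg is available on the PATH so pydub can decode non-WAV uploads):
#   pip install streamlit transformers torch pydub datasets jiwer librosa soundfile
#   streamlit run app.py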