import gradio as gr
import spaces
import random
import json
import os
import string
from difflib import SequenceMatcher
from jiwer import wer
import torchaudio
from transformers import pipeline
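
# Note: `spaces` provides the @spaces.GPU decorator used below. On Hugging Face
# ZeroGPU hardware it requests a GPU for the decorated call; elsewhere it is
# effectively a no-op (assumption: this Space targets ZeroGPU).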

# Load metadata
with open("common_voice_en_validated_249_hf_ready.json") as f:
    data = json.load(f)
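# Each entry is expected to carry at least "path" (audio file name),
# "sentence" (reference transcript), and the "age"/"gender"/"accent"
# tags used for filtering below.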

# Prepare dropdown options
ages = sorted(set(entry["age"] for entry in data))
genders = sorted(set(entry["gender"] for entry in data))
accents = sorted(set(entry["accent"] for entry in data))

# Utility functions
def convert_to_wav(file_path):
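    """Convert an MP3 to a mono WAV next to it, reusing the WAV if it already exists."""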
    wav_path = file_path.replace(".mp3", ".wav")
    if not os.path.exists(wav_path):
        waveform, sample_rate = torchaudio.load(file_path)
        waveform = waveform.mean(dim=0, keepdim=True)
        torchaudio.save(wav_path, waveform, sample_rate)
    return wav_path

def highlight_differences(ref, hyp):
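    """Render the hypothesis with words that differ from the reference wrapped in red <span>s."""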
    hyp_words = hyp.split()
    sm = SequenceMatcher(None, ref.split(), hyp_words)
    result = []
    for opcode, i1, i2, j1, j2 in sm.get_opcodes():
        if opcode == "equal":
            result.extend(hyp_words[j1:j2])
        else:
            wrong = hyp_words[j1:j2]
            result.extend([f"<span style='color:red'>{w}</span>" for w in wrong])
    return " ".join(result)

def normalize(text):
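    """Lowercase, strip ASCII punctuation, and trim whitespace so WER compares fairly."""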
    text = text.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    return text.strip()
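# e.g. normalize("Hello, World!") -> "hello world"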

# Generate Audio
def generate_audio(age, gender, accent):
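    """Pick a random sample matching the selected demographics and return its WAV path
    twice: once for the audio player, once for the hidden path textbox."""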
    filtered = [
        entry for entry in data
        if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
    ]
    if not filtered:
        # Return an empty path (not a message) so the "No file selected." guard
        # in transcribe_audio still works.
        return None, ""
    sample = random.choice(filtered)
    file_path = os.path.join("common_voice_en_validated_249", sample["path"])
    wav_file_path = convert_to_wav(file_path)
    return wav_file_path, wav_file_path
# Transcribe & Compare (GPU Decorated) | |
def transcribe_audio(file_path): | |
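    """Run every model in `model_ids` over the audio file, compute WER against the
    reference sentence, and return the reference plus one HTML diff per model."""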
    # The click handler binds 12 outputs (reference + 11 model panels), so the
    # error paths must return 12 values.
    if not file_path:
        return ("No file selected.", *[""] * 11)
    filename_mp3 = os.path.basename(file_path).replace(".wav", ".mp3")
    gold = ""
    for entry in data:
        if entry["path"].endswith(filename_mp3):
            gold = normalize(entry["sentence"])
            break
    if not gold:
        return ("Reference not found.", *[""] * 11)
    model_ids = [
        "openai/whisper-tiny",                     # smallest, multilingual
        "openai/whisper-tiny.en",                  # tiny, English-only
        "openai/whisper-base",                     # balanced, multilingual
        "openai/whisper-base.en",                  # base, English-only
        "openai/whisper-medium",                   # medium, multilingual
        "openai/whisper-medium.en",                # medium, English-only
        "distil-whisper/distil-large-v3.5",        # distilled from Whisper large-v3; much faster at comparable accuracy
        "facebook/wav2vec2-base-960h",             # base model trained on 960h of LibriSpeech (English)
        "facebook/wav2vec2-large-960h",            # larger model, better performance (English)
        "facebook/wav2vec2-large-960h-lv60-self",  # pretrained on 60k hours of Libri-Light, fine-tuned with self-training
        "facebook/hubert-large-ls960-ft",          # HuBERT large, fine-tuned on LibriSpeech
    ]
    outputs = {}
    for model_id in model_ids:
        try:
            pipe = pipeline("automatic-speech-recognition", model=model_id)
            text = pipe(file_path)["text"].strip()
            clean = normalize(text)
            wer_score = wer(gold, clean)
            outputs[model_id] = f"<b>{model_id} (WER: {wer_score:.2f}):</b><br>{highlight_differences(gold, clean)}"
        except Exception as e:
            outputs[model_id] = f"<b>{model_id}:</b><br><span style='color:red'>Error: {str(e)}</span>"
    return (gold, *outputs.values())
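
# Building a fresh pipeline per request keeps memory low between calls but
# reloads weights on every click. A minimal caching sketch (an assumption, not
# wired in — it trades memory for speed and only makes sense if the host has
# room for all eleven models):
#
#     _PIPE_CACHE = {}
#     def get_pipe(model_id):
#         if model_id not in _PIPE_CACHE:
#             _PIPE_CACHE[model_id] = pipeline("automatic-speech-recognition", model=model_id)
#         return _PIPE_CACHE[model_id]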

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Comparing ASR Models on Diverse English Speech Samples")
    gr.Markdown("""
    This demo compares the transcription performance of several automatic speech recognition (ASR) models.
    Users can select age, gender, and accent to generate diverse English audio samples.
    The models are evaluated on their ability to transcribe those samples.
    Data is sourced from 249 validated entries in the Common Voice English Delta Segment 21.0 release.
    """)
    with gr.Row():
        accent = gr.Dropdown(choices=accents, label="Accent", interactive=True)
        gender = gr.Dropdown(choices=[], label="Gender", interactive=True)
        age = gr.Dropdown(choices=[], label="Age", interactive=True)
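
    # Cascading filters: picking an accent repopulates the gender choices, and
    # picking a gender repopulates the ages, so only combinations that actually
    # exist in the 249-entry dataset can be selected.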
    def update_gender_options(selected_accent):
        options = sorted(set(entry["gender"] for entry in data if entry["accent"] == selected_accent))
        return gr.update(choices=options, value=None)

    def update_age_options(selected_accent, selected_gender):
        options = sorted(set(
            entry["age"] for entry in data
            if entry["accent"] == selected_accent and entry["gender"] == selected_gender
        ))
        return gr.update(choices=options, value=None)

    accent.change(update_gender_options, inputs=[accent], outputs=[gender])
    gender.change(update_age_options, inputs=[accent, gender], outputs=[age])

    generate_btn = gr.Button("Get Audio")
    audio_output = gr.Audio(label="Audio", type="filepath", interactive=False)
    file_path_output = gr.Textbox(label="Audio File Path", visible=False)
    generate_btn.click(generate_audio, [age, gender, accent], [audio_output, file_path_output])
    transcribe_btn = gr.Button("Transcribe with All Models")
    gold_text = gr.Textbox(label="Reference (Gold Standard)")
    whisper_tiny_html = gr.HTML(label="Whisper Tiny")
    whisper_tiny_en_html = gr.HTML(label="Whisper Tiny English")
    whisper_base_html = gr.HTML(label="Whisper Base")
    whisper_base_en_html = gr.HTML(label="Whisper Base English")
    whisper_medium_html = gr.HTML(label="Whisper Medium")
    whisper_medium_en_html = gr.HTML(label="Whisper Medium English")
    distil_html = gr.HTML(label="Distil-Whisper Large")
    wav2vec_base_html = gr.HTML(label="Wav2Vec2 Base")
    wav2vec_large_html = gr.HTML(label="Wav2Vec2 Large")
    wav2vec_lv60_html = gr.HTML(label="Wav2Vec2 Large + Libri-Light")
    hubert_html = gr.HTML(label="HuBERT Large")
    transcribe_btn.click(
        transcribe_audio,
        inputs=[file_path_output],
        outputs=[
            gold_text,
            whisper_tiny_html,
            whisper_tiny_en_html,
            whisper_base_html,
            whisper_base_en_html,
            whisper_medium_html,
            whisper_medium_en_html,
            distil_html,
            wav2vec_base_html,
            wav2vec_large_html,
            wav2vec_lv60_html,
            hubert_html,
        ],
    )

demo.launch()