|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import librosa |
|
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Single checkpoint identifier shared by the classifier weights and the
# matching feature-extractor configuration, so the two cannot drift apart.
_CHECKPOINT = "Yilin0601/wav2vec2-fluency-checkpoints"

# Sequence-classification model used by the demo's prediction function.
model = Wav2Vec2ForSequenceClassification.from_pretrained(_CHECKPOINT)

# Turns raw waveforms into the tensor inputs expected by ``model``.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(_CHECKPOINT)
|
|
|
|
|
|
|
|
|
# Sampling rate (Hz) that the waveform is resampled to before feature extraction.
_TARGET_SAMPLE_RATE = 16000

# Offset added to the arg-max class index to obtain the reported level.
# NOTE(review): presumably the lowest label in the training data was 3 --
# confirm against the checkpoint's label mapping.
_LEVEL_OFFSET = 3


def _to_mono_float32(audio_data):
    """Return *audio_data* as a 1-D float32 waveform."""
    samples = np.asarray(audio_data)

    if np.issubdtype(samples.dtype, np.signedinteger):
        # Integer PCM: divide by the dtype's full-scale magnitude so the
        # samples land in [-1.0, 1.0) instead of e.g. +/-32768.
        full_scale = float(abs(np.iinfo(samples.dtype).min))
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32, copy=False)

    if samples.ndim > 1:
        # Average every non-leading axis. This down-mixes (samples, channels)
        # input and also flattens a (samples, 1) array, which must not be
        # passed on as 2-D data.
        samples = samples.mean(axis=tuple(range(1, samples.ndim)))

    return samples


def predict(audio):
    """Predict a fluency level for one audio clip.

    Args:
        audio: ``None`` or a ``(sample_rate, samples)`` tuple, as delivered by
            a ``gr.Audio(type="numpy")`` input. ``samples`` may be integer or
            floating point, 1-D or with a trailing channel axis.

    Returns:
        ``"No audio provided."`` when there is no audio or the clip contains
        no samples; otherwise ``"Predicted Fluency Level: <n>"`` where ``n``
        is the arg-max class index plus ``_LEVEL_OFFSET``.
    """
    if audio is None:
        return "No audio provided."

    sample_rate, audio_data = audio
    waveform = _to_mono_float32(audio_data)

    # An empty clip cannot be resampled or classified; report it like a
    # missing recording instead of failing inside the model.
    if waveform.size == 0:
        return "No audio provided."

    if sample_rate != _TARGET_SAMPLE_RATE:
        waveform = librosa.resample(
            waveform,
            orig_sr=sample_rate,
            target_sr=_TARGET_SAMPLE_RATE,
        )

    inputs = feature_extractor(
        waveform,
        sampling_rate=_TARGET_SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: evaluation mode and no gradient tracking.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    pred_class = torch.argmax(logits, dim=-1).item()
    predicted_level = pred_class + _LEVEL_OFFSET

    return f"Predicted Fluency Level: {predicted_level}"
|
|
|
|
|
|
|
|
|
# HTML blurb rendered under the title of the demo page. The fragments are
# joined without separators, so each one carries its own trailing space.
_DESCRIPTION_HTML = "".join(
    [
        "<p style='font-size:16px;'>",
        "This demo predicts your English fluency level on a scale from 0 to 10. ",
        "It uses a fine-tuned <b>facebook/wav2vec2-base-960h</b> model trained on the ",
        "<b>DynamicSuperb/L2EnglishAccuracy_speechocean762-Scoring</b> dataset, which contains ",
        "745 labeled audio recordings of non-native English speakers. ",
        "To get your fluency score, simply record or upload an audio file. ",
        "<br><br>",
        "<b>Note:</b> This prediction is for demo purposes and should be interpreted with caution. ",
        "</p>",
    ]
)

# Audio input component; type="numpy" makes ``predict`` receive a
# (sample_rate, samples) tuple.
_audio_input = gr.Audio(type="numpy", label="Record or Upload Audio")

# The Gradio app: one audio input, one plain-text output, flagging disabled.
iface = gr.Interface(
    fn=predict,
    inputs=_audio_input,
    outputs="text",
    title="L2 English Fluency Predictor",
    description=_DESCRIPTION_HTML,
    allow_flagging="never",
)


# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|
|