import gradio as gr
import torch
import numpy as np
import librosa
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
# --------------------------------------------------
# Load Your Fine-Tuned Model for Fluency Prediction
# --------------------------------------------------
# This model was fine-tuned with labels remapped from [3..10] to [0..7].
# Ensure that "Yilin0601/wav2vec2-fluency-checkpoints" is your correct repo.
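# For example, a raw fluency score of 3 maps to class 0 and a raw score of
# 10 to class 7; predict() below adds 3 to undo that shift.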
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "Yilin0601/wav2vec2-fluency-checkpoints"
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
    "Yilin0601/wav2vec2-fluency-checkpoints"
)
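
# Optional sanity check (a sketch, assuming the fine-tuned checkpoint stores
# its 8-way head size in the config): fail fast if the label count doesn't
# match the [0..7] remapping described above.
assert model.config.num_labels == 8, "expected 8 classes (fluency scores 3..10)"
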
# --------------------------------------------------
# Prediction Function
# --------------------------------------------------
def predict(audio):
    if audio is None:
        return "No audio provided."

    # Gradio returns audio as a (sample_rate, np.ndarray) tuple
    sample_rate, audio_data = audio
    # Convert to float32; integer PCM (e.g. int16) is rescaled to [-1, 1] so
    # the amplitudes match what librosa and the feature extractor expect
    if np.issubdtype(audio_data.dtype, np.integer):
        audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
    elif audio_data.dtype != np.float32:
        audio_data = audio_data.astype(np.float32)
    # Convert stereo to mono if needed
    if audio_data.ndim > 1 and audio_data.shape[1] > 1:
        audio_data = np.mean(audio_data, axis=1)
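    # Gradio delivers multi-channel audio as shape (samples, channels), so
    # averaging over axis=1 preserves the sample count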
    # Resample to 16 kHz if necessary (wav2vec2 models are pretrained on
    # 16 kHz speech)
    if sample_rate != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
    # Extract features using the feature extractor
    inputs = feature_extractor(
        audio_data,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True
    )
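    # `inputs` carries the waveform as `input_values` (plus an
    # `attention_mask` when the extractor is configured to return one)
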
    # Model inference
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
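    # A confidence estimate could be surfaced here too, e.g. (sketch, unused
    # by this demo): torch.softmax(logits, dim=-1).max().item()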
    # The model outputs an 8-class prediction (0..7) corresponding to the
    # original fluency scores [3..10]
    pred_class = torch.argmax(logits, dim=-1).item()
    predicted_level = pred_class + 3  # map back to [3..10]
    return f"Predicted Fluency Level: {predicted_level}"

# --------------------------------------------------
# Gradio Interface
# --------------------------------------------------
iface = gr.Interface(
    fn=predict,
    inputs=gr.Audio(type="numpy", label="Record or Upload Audio"),
    outputs="text",
    title="L2 English Fluency Predictor",
    description=(
        "<p style='font-size:16px;'>"
        "This demo predicts your English fluency level on a scale from 3 to 10. "
        "It uses a fine-tuned <b>facebook/wav2vec2-base-960h</b> model trained on the "
        "<b>DynamicSuperb/L2EnglishAccuracy_speechocean762-Scoring</b> dataset, which contains "
        "745 labeled audio recordings of non-native English speakers. "
        "To get your fluency score, simply record or upload an audio file. "
        "<br><br>"
        "<b>Note:</b> This prediction is for demo purposes only and should be interpreted with caution."
        "</p>"
    ),
    allow_flagging="never"
)

if __name__ == "__main__":
    iface.launch()
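
# Run locally with `python app.py`; Gradio serves the UI at
# http://127.0.0.1:7860 by default.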