Yilin0601 commited on
Commit
227aa4c
·
verified ·
1 Parent(s): 20b8be9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -5,10 +5,10 @@ import librosa
5
  from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
6
 
7
  # --------------------------------------------------
8
- # Load Your Fine-Tuned Model
9
  # --------------------------------------------------
10
  # This model was fine-tuned with labels remapped from [3..10] to [0..7].
11
- # Make sure the model repo name below is correct and accessible.
12
  model = Wav2Vec2ForSequenceClassification.from_pretrained(
13
  "Yilin0601/wav2vec2-fluency-checkpoints"
14
  )
@@ -22,11 +22,11 @@ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
22
  def predict(audio):
23
  if audio is None:
24
  return "No audio provided."
25
-
26
- # Gradio provides audio as (sample_rate, np.array)
27
  sample_rate, audio_data = audio
28
 
29
- # Ensure the audio is floating-point (librosa requires float32 or float64)
30
  if audio_data.dtype not in [np.float32, np.float64]:
31
  audio_data = audio_data.astype(np.float32)
32
 
@@ -34,11 +34,11 @@ def predict(audio):
34
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
35
  audio_data = np.mean(audio_data, axis=1)
36
 
37
- # Resample to 16 kHz if needed
38
  if sample_rate != 16000:
39
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
40
 
41
- # Extract features
42
  inputs = feature_extractor(
43
  audio_data,
44
  sampling_rate=16000,
@@ -51,11 +51,11 @@ def predict(audio):
51
  with torch.no_grad():
52
  logits = model(**inputs).logits
53
 
54
- # The model output is an 8-class prediction (0..7), corresponding to original labels 3..10
55
  pred_class = torch.argmax(logits, dim=-1).item()
56
  predicted_level = pred_class + 3 # Map back to [3..10]
57
 
58
- return f"Predicted Level: {predicted_level}"
59
 
60
  # --------------------------------------------------
61
  # Gradio Interface
@@ -66,8 +66,10 @@ iface = gr.Interface(
66
  outputs="text",
67
  title="L2 English Fluency Predictor",
68
  description=(
69
- "This demo uses a fine-tuned Wav2Vec2ForSequenceClassification model with labels for accuracy evaluation "
70
- "mapped from 0 to 10. Record or upload audio to see the predicted level."
 
 
71
  ),
72
  allow_flagging="never"
73
  )
 
5
  from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
6
 
7
  # --------------------------------------------------
8
+ # Load Your Fine-Tuned Model for Fluency Prediction
9
  # --------------------------------------------------
10
  # This model was fine-tuned with labels remapped from [3..10] to [0..7].
11
+ # Ensure that "Yilin0601/wav2vec2-fluency-checkpoints" is your correct repo.
12
  model = Wav2Vec2ForSequenceClassification.from_pretrained(
13
  "Yilin0601/wav2vec2-fluency-checkpoints"
14
  )
 
22
  def predict(audio):
23
  if audio is None:
24
  return "No audio provided."
25
+
26
+ # Gradio returns audio as (sample_rate, np.array)
27
  sample_rate, audio_data = audio
28
 
29
+ # Ensure audio is in floating-point (librosa requires float32 or float64)
30
  if audio_data.dtype not in [np.float32, np.float64]:
31
  audio_data = audio_data.astype(np.float32)
32
 
 
34
  if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
35
  audio_data = np.mean(audio_data, axis=1)
36
 
37
+ # Resample to 16 kHz if necessary
38
  if sample_rate != 16000:
39
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
40
 
41
+ # Extract features using the feature extractor
42
  inputs = feature_extractor(
43
  audio_data,
44
  sampling_rate=16000,
 
51
  with torch.no_grad():
52
  logits = model(**inputs).logits
53
 
54
+ # The model outputs an 8-class prediction (0..7), corresponding to original fluency scores [3..10]
55
  pred_class = torch.argmax(logits, dim=-1).item()
56
  predicted_level = pred_class + 3 # Map back to [3..10]
57
 
58
+ return f"Predicted Fluency Level: {predicted_level}"
59
 
60
  # --------------------------------------------------
61
  # Gradio Interface
 
66
  outputs="text",
67
  title="L2 English Fluency Predictor",
68
  description=(
69
+ "This demo uses a fine-tuned Wav2Vec2ForSequenceClassification model for fluency prediction. "
70
+ "The model was fine-tuned with fluency scores remapped from [3..10] to [0..7]. "
71
+ "Record or upload audio to see the predicted fluency level. "
72
+ "If the predicted level is always the same (e.g., 8), it might indicate that the model needs further fine-tuning or calibration."
73
  ),
74
  allow_flagging="never"
75
  )