Yilin0601 committed (verified)
Commit 96e64e2 · 1 Parent(s): df4b4f2

Update app.py

Files changed (1):
  1. app.py +28 -18
app.py CHANGED
@@ -6,35 +6,44 @@ import numpy as np
 import librosa
 from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
 
-# 1. Load your fine-tuned model & feature extractor from the Hugging Face Hub or local path
-# Replace "YourUsername/YourModelRepo" with the actual repo ID where your fine-tuned model is hosted
-model_name = "YourUsername/YourModelRepo"
-model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
+# ------------------------------------------------
+# 1. Load base Wav2Vec2 model + classification head
+# ------------------------------------------------
+model_name = "facebook/wav2vec2-base-960h"
+
+# We specify num_labels=8 to create a random classification head on top
+model = Wav2Vec2ForSequenceClassification.from_pretrained(
+    model_name,
+    num_labels=8
+)
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
 
 model.eval()
 
+# ------------------------------------------------
+# 2. Define inference function
+# ------------------------------------------------
 def classify_accuracy(audio):
     """
-    audio: Gradio provides a tuple (sample_rate, data) when type='numpy'.
-    We'll convert to the correct format, run inference, and return the predicted level.
+    Receives a tuple (sample_rate, data) from Gradio when type='numpy'.
+    We'll resample if needed, run a forward pass, and return a 'level'.
     """
     if audio is None:
         return "No audio provided."
 
     sample_rate, data = audio
 
-    # Ensure the audio is a NumPy array
+    # Ensure we have a NumPy array
     if not isinstance(data, np.ndarray):
         data = np.array(data)
 
-    # Resample if needed (model expects 16kHz)
+    # Resample if needed (the model expects 16 kHz)
     target_sr = 16000
     if sample_rate != target_sr:
         data = librosa.resample(data, orig_sr=sample_rate, target_sr=target_sr)
         sample_rate = target_sr
 
-    # Convert to batch of size 1
+    # Extract features
     inputs = feature_extractor(
         data,
         sampling_rate=sample_rate,
@@ -47,28 +56,29 @@ def classify_accuracy(audio):
     logits = outputs.logits
     predicted_id = torch.argmax(logits, dim=-1).item()
 
-    # Map model output (0..7) back to your desired scale (3..10) if needed
+    # Map 0..7 -> 3..10 if you want a "level" in that range
     accuracy_level = predicted_id + 3
 
     return f"Predicted Accuracy Level: {accuracy_level}"
 
-# 2. Build Gradio Interface
-title = "Speech Accuracy Classifier"
+# ------------------------------------------------
+# 3. Build Gradio interface
+# ------------------------------------------------
+title = "Speech Accuracy Classifier (Base Wav2Vec2)"
 description = (
     "Upload an audio file (or record audio) on the left. "
-    "The model will classify the audio's accuracy level on the right."
+    "The base model is NOT fine-tuned for classification, so results may be random. "
+    "This demo simply illustrates how to attach a classification head."
 )
 
-# Gradio Interface:
 demo = gr.Interface(
     fn=classify_accuracy,
-    inputs=gr.Audio(source="upload", type="numpy"),  # left side: audio upload
-    outputs="text",  # right side: classification result
+    inputs=gr.Audio(source="upload", type="numpy"),
+    outputs="text",
     title=title,
     description=description,
-    allow_flagging="never"  # disable user flagging if you prefer
+    allow_flagging="never"
 )
 
-# 3. Launch Gradio App
 if __name__ == "__main__":
     demo.launch()
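
Notes on the new code:

One caveat the diff does not address: with type='numpy', Gradio typically hands back int16 samples, and stereo recordings arrive as a (samples, channels) array, while librosa.resample requires floating-point data and the feature extractor expects a 1-D mono waveform. A minimal guard, assuming those Gradio conventions, could sit just before the resampling step; this is a hypothetical addition, not part of the commit:

    # Hypothetical guard before librosa.resample (not in this commit)
    if np.issubdtype(data.dtype, np.integer):
        # Scale integer PCM (e.g. int16) into [-1.0, 1.0] floats
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    if data.ndim == 2:
        # Gradio returns stereo as (samples, channels); average down to mono
        data = data.mean(axis=1)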
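
Both hunks cut off inside classify_accuracy (old lines 41-46 / new lines 50-55 are unchanged, so the diff omits them), which is where inputs is completed and fed to the model. For orientation, a typical forward pass under standard transformers usage, not necessarily the file's exact lines, looks like:

    inputs = feature_extractor(
        data,
        sampling_rate=sample_rate,
        return_tensors="pt",  # assumed; the hunk ends right after sampling_rate
    )
    with torch.no_grad():  # inference only, matching model.eval() above
        outputs = model(**inputs)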
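
The interface keeps gr.Audio(source="upload", type="numpy"), which is Gradio 3.x syntax: Gradio 4.x renamed the parameter to the plural sources and made it list-valued, and later releases deprecate allow_flagging in favor of flagging_mode. If the Space is ever upgraded, the construction would look roughly like this (assuming Gradio 4.x):

    demo = gr.Interface(
        fn=classify_accuracy,
        inputs=gr.Audio(sources=["upload"], type="numpy"),  # 4.x: plural, list-valued
        outputs="text",
        title=title,
        description=description,
        allow_flagging="never",  # deprecated later in favor of flagging_mode="never"
    )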
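
Finally, because the classification head is randomly initialized, a smoke test does not need real speech; any waveform exercises the full path, and a non-16 kHz rate also covers the resampling branch. A hypothetical local check:

    import numpy as np

    sr = 22050  # deliberately not 16 kHz so the resampling branch runs
    t = np.linspace(0.0, 1.0, sr, endpoint=False)
    tone = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

    print(classify_accuracy((sr, tone)))
    # Prints "Predicted Accuracy Level: <n>" with n in 3..10; the exact value is
    # effectively arbitrary here because the head is untrained.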