Yilin0601 committed on
Commit c076438 · verified · 1 Parent(s): 96e64e2

Update app.py

Files changed (1):
  1. app.py +12 -12
app.py CHANGED
@@ -11,7 +11,7 @@ from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassifica
 # ------------------------------------------------
 model_name = "facebook/wav2vec2-base-960h"
 
-# We specify num_labels=8 to create a random classification head on top
+# Specify num_labels=8 to create a random classification head on top.
 model = Wav2Vec2ForSequenceClassification.from_pretrained(
     model_name,
     num_labels=8
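
Note: for context, the setup this hunk touches looks roughly like the sketch below. The feature-extractor line is not shown in the hunk, so its exact form is an assumption; what the comment means by a "random" head is that `num_labels=8` makes transformers attach a freshly (randomly) initialized classification layer, since the base checkpoint was trained for ASR, not classification.

```python
# Sketch of the surrounding setup; the feature-extractor line is assumed,
# not shown in this hunk.
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

model_name = "facebook/wav2vec2-base-960h"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

# The base checkpoint has no classifier, so the 8-way head added here starts
# with random weights (transformers logs a "newly initialized" warning).
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, num_labels=8)
model.eval()
```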
@@ -26,24 +26,24 @@ model.eval()
 def classify_accuracy(audio):
     """
     Receives a tuple (sample_rate, data) from Gradio when type='numpy'.
-    We'll resample if needed, run a forward pass, and return a 'level'.
+    Resamples if needed, runs a forward pass, and returns a 'level'.
     """
     if audio is None:
         return "No audio provided."
 
     sample_rate, data = audio
 
-    # Ensure we have a NumPy array
+    # Ensure data is a NumPy array.
     if not isinstance(data, np.ndarray):
         data = np.array(data)
 
-    # Resample if the model expects 16kHz
+    # Resample to 16kHz if needed.
     target_sr = 16000
     if sample_rate != target_sr:
         data = librosa.resample(data, orig_sr=sample_rate, target_sr=target_sr)
         sample_rate = target_sr
 
-    # Extract features
+    # Extract features from the audio data.
     inputs = feature_extractor(
         data,
         sampling_rate=sample_rate,
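
Note: Gradio's numpy audio typically arrives as int16 PCM, and stereo recordings come through as a 2-D array, while `librosa.resample` and the feature extractor expect a mono float waveform. The helper below is hypothetical (not part of app.py), sketching one way to normalize the buffer before the resampling step in this hunk.

```python
import numpy as np

def to_mono_float32(data: np.ndarray) -> np.ndarray:
    """Hypothetical helper: normalize Gradio's raw buffer to mono float32."""
    if np.issubdtype(data.dtype, np.integer):
        # Scale integer PCM (e.g. int16) into [-1.0, 1.0].
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    if data.ndim > 1:
        # (samples, channels) -> average channels down to mono.
        data = data.mean(axis=1)
    return data.astype(np.float32)
```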
@@ -51,14 +51,14 @@ def classify_accuracy(audio):
         padding=True
     )
 
+    # Run model inference.
     with torch.no_grad():
         outputs = model(**inputs)
     logits = outputs.logits
     predicted_id = torch.argmax(logits, dim=-1).item()
 
-    # Map 0..7 -> 3..10 if you want a "level" in that range
+    # Map predicted id (0..7) to the final level (3..10).
     accuracy_level = predicted_id + 3
-
     return f"Predicted Accuracy Level: {accuracy_level}"
 
 # ------------------------------------------------
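
Note: the context line 50 elided by this hunk presumably passes `return_tensors="pt"` to the feature extractor (that is an assumption, not visible in the diff). Under that assumption, the inference path the new comments describe reads end to end as:

```python
import torch

# Assumes the hunk's hidden line 50 includes return_tensors="pt".
inputs = feature_extractor(
    data,
    sampling_rate=sample_rate,
    return_tensors="pt",
    padding=True,
)

# Run model inference.
with torch.no_grad():
    logits = model(**inputs).logits  # shape (1, 8): one score per label

predicted_id = torch.argmax(logits, dim=-1).item()  # 0..7
accuracy_level = predicted_id + 3                   # mapped to 3..10
```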
@@ -66,15 +66,15 @@ def classify_accuracy(audio):
 # ------------------------------------------------
 title = "Speech Accuracy Classifier (Base Wav2Vec2)"
 description = (
-    "Upload an audio file (or record audio) on the left. "
-    "The base model is NOT fine-tuned for classification, so results may be random. "
-    "This demo simply illustrates how to attach a classification head."
+    "Record audio using your microphone or upload an audio file (left). "
+    "The model (not fine-tuned) will classify the audio into an accuracy level (right)."
 )
 
+# Using source="microphone" allows for direct recording, while recent versions also enable file upload.
 demo = gr.Interface(
     fn=classify_accuracy,
-    inputs=gr.Audio(source="upload", type="numpy", label="Record/Upload Audio"),
-    outputs="text",
+    inputs=gr.Audio(source="microphone", type="numpy", label="Record/Upload Audio"),
+    outputs=gr.Textbox(label="Classification Result"),
     title=title,
     description=description,
     allow_flagging="never"
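
Note on the new comment about "recent versions": `source=` is the Gradio 3.x signature; Gradio 4.x replaced it with a `sources` list, which offers microphone recording and file upload side by side. Which version this Space pins is not visible in the diff, so the sketch below is an assumption for 4.x, not a drop-in fix.

```python
import gradio as gr

# Gradio 4.x sketch: `sources` (a list) replaces the 3.x `source` argument,
# so recording and upload can both be offered in one component.
demo = gr.Interface(
    fn=classify_accuracy,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="numpy",
        label="Record/Upload Audio",
    ),
    outputs=gr.Textbox(label="Classification Result"),
    title=title,
    description=description,
    allow_flagging="never",  # still accepted on 4.x, deprecated in later releases
)
```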
 