Yilin0601 committed
Commit 2597334 · verified · 1 Parent(s): 141103b

Update app.py

Files changed (1): app.py (+37 -29)
app.py CHANGED
@@ -4,61 +4,69 @@ import numpy as np
 from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
 import librosa
 
-# -------------------------------
-# Configuration – Modify as Needed
-# -------------------------------
-# Number of labels for your classification task.
-# For example, if you want levels 3 through 10, that's 8 labels.
-num_labels = 8
+# --------------------------------------------------
+# Configuration
+# --------------------------------------------------
+# We have 3 classes: 0 = "low", 1 = "medium", 2 = "high"
+num_labels = 3
 
-# Pre-trained model: We're using facebook's wav2vec2-base-960h.
-# Since we are not fine-tuning, we are simply adding a classification head with random weights.
-model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=8)
+# Load a base Wav2Vec2 model for classification with 3 labels.
+# The classification head will be randomly initialized.
+model = Wav2Vec2ForSequenceClassification.from_pretrained(
+    "facebook/wav2vec2-base-960h",
+    num_labels=num_labels
+)
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
 
-# -------------------------------
+# Map integer predictions to textual labels
+label_map = {0: "low", 1: "medium", 2: "high"}
+
+# --------------------------------------------------
 # Prediction Function
-# -------------------------------
+# --------------------------------------------------
 def predict(audio):
     if audio is None:
         return "No audio provided."
-
-    # Gradio returns audio as a tuple: (sample_rate, np.array)
+
+    # Gradio provides audio as (sample_rate, np.array)
     sample_rate, audio_data = audio
 
-    # Ensure the audio is at 16 kHz (the expected sampling rate)
+    # Convert stereo to mono if needed
+    if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
+        audio_data = np.mean(audio_data, axis=1)
+
+    # Resample to 16 kHz if not already
     if sample_rate != 16000:
-        audio_data = librosa.resample(np.asarray(audio_data), orig_sr=sample_rate, target_sr=16000)
+        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
 
-    # Process the audio input using the feature extractor
+    # Extract features
     inputs = feature_extractor(audio_data, sampling_rate=16000, return_tensors="pt", padding=True)
 
-    # Set the model to evaluation mode and run inference
+    # Model inference
     model.eval()
     with torch.no_grad():
         logits = model(**inputs).logits
 
-    # Obtain the predicted class (index)
+    # Argmax over logits -> integer class
    pred_class = torch.argmax(logits, dim=-1).item()
 
-    # If you want to map from 0..7 to your intended label range (e.g., 3..10),
-    # simply add an offset. Here we add 3.
-    predicted_level = pred_class + 3
+    # Convert integer class to textual label
+    predicted_label = label_map.get(pred_class, "Unknown")
 
-    # Return a string with the predicted level
-    return f"Predicted L2 English Accuracy Level: {predicted_level}"
+    return f"Predicted Level: {predicted_label}"
 
-# -------------------------------
-# Gradio Interface Setup
-# -------------------------------
+# --------------------------------------------------
+# Gradio Interface
+# --------------------------------------------------
 iface = gr.Interface(
     fn=predict,
     inputs=gr.Audio(type="numpy", label="Record or Upload Audio"),
     outputs="text",
-    title="L2 English Accuracy Predictor Demo",
+    title="3-Class Audio Classification Demo (Random)",
     description=(
-        "This demo uses Wav2Vec2ForSequenceClassification without fine-tuning. "
-        "the prediction results are random and for demonstration purposes only."
+        "This demo uses Wav2Vec2ForSequenceClassification with 3 classes (low, medium, high) "
+        "but has not been fine-tuned, so the classification head is random. The predictions "
+        "are not meaningful, but the pipeline demonstrates how a 3-class audio classifier can be set up."
     ),
     allow_flagging="never"
 )
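
Note on the resampling path: gr.Audio(type="numpy") typically hands back integer PCM samples (e.g. int16), and librosa.resample validates its input and rejects non-floating-point arrays. A minimal sketch of a conversion step that predict() could apply right after unpacking the tuple; the helper name to_float_mono is hypothetical, not part of this commit:

import numpy as np

def to_float_mono(audio_data):
    # Hypothetical helper (not in the commit): normalize Gradio PCM audio
    # to float32 in [-1, 1] and collapse stereo to mono, so that
    # librosa.resample and the feature extractor receive floating-point input.
    audio_data = np.asarray(audio_data)
    if np.issubdtype(audio_data.dtype, np.integer):
        # Scale integer PCM (e.g. int16 from the browser) into [-1, 1];
        # the RHS reads the dtype before audio_data is rebound.
        audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
    if audio_data.ndim > 1 and audio_data.shape[1] > 1:
        # Average channels to mono, matching the check used in predict()
        audio_data = audio_data.mean(axis=1)
    return audio_data.astype(np.float32)

With a call like audio_data = to_float_mono(audio_data) after unpacking, the stereo branch inside predict() becomes redundant and the resample call works for recorded as well as uploaded clips.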
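A related extension: since the classification head is randomly initialized, exposing the softmax scores makes that randomness visible in the UI. A sketch reusing the model, feature_extractor, label_map, and num_labels defined above; the function name predict_proba is ours, not part of the commit:

import torch

def predict_proba(inputs):
    # Sketch: same forward pass as predict(), but returning softmax
    # probabilities for all three labels instead of only the argmax.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, num_labels)
    probs = torch.softmax(logits, dim=-1).squeeze(0)
    return {label_map[i]: float(probs[i]) for i in range(num_labels)}

A dict of label -> confidence in this shape is what gr.Label() renders as confidence bars, so swapping outputs="text" for outputs=gr.Label() would display all three scores at once.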