Gapeleon committed on
Commit
eb92e9b
·
verified ·
1 Parent(s): 8bce1e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -47,7 +47,10 @@ def transcribe_audio(audio_input):
47
  # Microphone input: (sample_rate, numpy_array)
48
  logs.append("Processing microphone input")
49
  sr, wav_np = audio_input
50
- wav = torch.from_numpy(wav_np).float().unsqueeze(0)
 
 
 
51
  else:
52
  # File input: filepath string
53
  logs.append(f"Processing file input: {audio_input}")
@@ -68,6 +71,11 @@ def transcribe_audio(audio_input):
68
 
69
  logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, min: {wav.min().item()}, max: {wav.max().item()}")
70
 
 
 
 
 
 
71
  # Create text prompt
72
  chat = [
73
  {
@@ -88,11 +96,15 @@ def transcribe_audio(audio_input):
88
  logs.append("Preparing model inputs")
89
  model_inputs = speech_granite_processor(
90
  text=text,
91
- audio=wav.numpy().squeeze(), # Convert to numpy and squeeze
92
  sampling_rate=sr,
93
  return_tensors="pt",
94
  ).to(device)
95
 
 
 
 
 
96
  # Generate transcription
97
  logs.append("Generating transcription")
98
  model_outputs = speech_granite.generate(
@@ -105,21 +117,16 @@ def transcribe_audio(audio_input):
105
  repetition_penalty=3.0,
106
  length_penalty=1.0,
107
  temperature=1.0,
108
- bos_token_id=tokenizer.bos_token_id,
109
- eos_token_id=tokenizer.eos_token_id,
110
- pad_token_id=tokenizer.pad_token_id,
111
  )
112
 
113
  # Extract the generated text (skipping input tokens)
114
  logs.append("Processing output")
115
  num_input_tokens = model_inputs["input_ids"].shape[-1]
116
- new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
117
 
118
- output_text = tokenizer.batch_decode(
119
- new_tokens, add_special_tokens=False, skip_special_tokens=True
120
- )
121
 
122
- transcription = output_text[0].strip().upper()
123
  logs.append(f"Transcription complete: {transcription[:50]}...")
124
 
125
  except Exception as e:
 
47
  # Microphone input: (sample_rate, numpy_array)
48
  logs.append("Processing microphone input")
49
  sr, wav_np = audio_input
50
+ wav = torch.from_numpy(wav_np).float()
51
+ # Make sure we have the right dimensions [channels, time]
52
+ if len(wav.shape) == 1:
53
+ wav = wav.unsqueeze(0)
54
  else:
55
  # File input: filepath string
56
  logs.append(f"Processing file input: {audio_input}")
 
71
 
72
  logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, min: {wav.min().item()}, max: {wav.max().item()}")
73
 
74
+ # Convert to numpy array as expected by the processor
75
+ # Make sure it's in the format [time]
76
+ wav_np = wav.squeeze().numpy()
77
+ logs.append(f"Audio array shape for processor: {wav_np.shape}")
78
+
79
  # Create text prompt
80
  chat = [
81
  {
 
96
  logs.append("Preparing model inputs")
97
  model_inputs = speech_granite_processor(
98
  text=text,
99
+ audio=wav_np, # Pass numpy array in format [time]
100
  sampling_rate=sr,
101
  return_tensors="pt",
102
  ).to(device)
103
 
104
+ # Verify audio tokens are present
105
+ if "audio_values" not in model_inputs:
106
+ logs.append(f"WARNING: No audio_values in model inputs. Keys present: {list(model_inputs.keys())}")
107
+
108
  # Generate transcription
109
  logs.append("Generating transcription")
110
  model_outputs = speech_granite.generate(
 
117
  repetition_penalty=3.0,
118
  length_penalty=1.0,
119
  temperature=1.0,
 
 
 
120
  )
121
 
122
  # Extract the generated text (skipping input tokens)
123
  logs.append("Processing output")
124
  num_input_tokens = model_inputs["input_ids"].shape[-1]
125
+ new_tokens = model_outputs[0, num_input_tokens:]
126
 
127
+ output_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
 
 
128
 
129
+ transcription = output_text.strip().upper()
130
  logs.append(f"Transcription complete: {transcription[:50]}...")
131
 
132
  except Exception as e: