AbirMessaoudi and ssolito committed (verified)
Commit 95d83f6 · Parent: c01ffa1

Update whisper.py (#5)

- Update whisper.py (c6669b54ffd9eb0a0b4bdfe1c33ead47c046f299)


Co-authored-by: Sarah Solito <ssolito@users.noreply.huggingface.co>

Files changed (1)
  1. whisper.py +87 -176
whisper.py CHANGED
@@ -1,35 +1,59 @@
- from pyannote.audio import Pipeline
  from pydub import AudioSegment
  import os
- from transformers import WhisperForConditionalGeneration, WhisperProcessor
  import torchaudio
  import torch
  import re
  from transformers import pipeline
  import spaces

-
  device = 0 if torch.cuda.is_available() else "cpu"
  torch_dtype = torch.float32

-
- MODEL_NAME = "openai/whisper-large-v3"
- CKPT = "projecte-aina/whisper-large-v3-tiny-caesar"
  BATCH_SIZE = 1
- model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype).to(device)
- processor = WhisperProcessor.from_pretrained(MODEL_NAME)
- pipeline_vad = Pipeline.from_pretrained("./pyannote/config.yaml")
- threshold = 10000
- segments_dir = "."

  pipe = pipeline(
      task="automatic-speech-recognition",
-     model=CKPT,
      chunk_length_s=30,
      device=device,
      token=os.getenv("HF_TOKEN")
  )

  def post_process_transcription(transcription, max_repeats=2):
      tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
 
@@ -56,151 +80,6 @@ def post_process_transcription(transcription, max_repeats=2):
      return cleaned_transcription


- def convert_forced_to_tokens(forced_decoder_ids):
-     forced_decoder_tokens = []
-     for i, (idx, token) in enumerate(forced_decoder_ids):
-         if token is not None:
-             forced_decoder_tokens.append([idx, processor.tokenizer.decode(token)])
-         else:
-             forced_decoder_tokens.append([idx, token])
-     return forced_decoder_tokens
-
- def generate_1st_chunk(audio):
-
-     input_audio, sample_rate = torchaudio.load(audio)
-     input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
-
-     input_speech = input_audio[0]
-
-     input_features = processor(input_speech,
-                                sampling_rate=16_000,
-                                return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
-
-     forced_decoder_ids = []
-     forced_decoder_ids.append([1,50270]) #[1, '<|ca|>']
-     forced_decoder_ids.append([2,50262]) #[2, '<|es|>']
-     forced_decoder_ids.append([3,50360]) #[3, '<|transcribe|>']
-
-     forced_decoder_ids_modified = forced_decoder_ids
-     idx = processor.tokenizer.all_special_tokens.index("<|startofprev|>")
-     forced_bos_token_id = processor.tokenizer.all_special_ids[idx]
-     prompt = "Antes de 'digui'm', '112'. 112, digui'm. Hola, puc parlar en castellà? Sí, digui, diga. Sí, mire: a veces al abrir la puerta de mi piso tengo una persona ahí. Vale, avisamos a la Guàrdia Urbana, ¿de acuerdo? Vale, perfecto. Gracias. Gracias. Buen día."
-     prompt_tokens = processor.tokenizer(prompt, add_special_tokens=False).input_ids
-
-     # we need to force these tokens
-     forced_decoder_ids = []
-     for idx, token in enumerate(prompt_tokens):
-         # indexing starts from 1 for forced tokens (token at position 0 is the SOS token)
-         forced_decoder_ids.append([idx + 1, token])
-
-     # now we add the SOS token at the end
-     offset = len(forced_decoder_ids)
-     forced_decoder_ids.append([offset + 1, model.generation_config.decoder_start_token_id])
-
-     # now we need to append the rest of the prefix tokens (lang, task, timestamps)
-     offset = len(forced_decoder_ids)
-     for idx, token in forced_decoder_ids_modified:
-         forced_decoder_ids.append([idx + offset, token])
-
-     model.generation_config.forced_decoder_ids = forced_decoder_ids
-
-     pred_ids = model.generate(input_features,
-                               return_timestamps=True,
-                               max_new_tokens=128,
-                               decoder_start_token_id=forced_bos_token_id)
-     # exclude prompt from output
-     forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
-     output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)
-
-     return output[1:]
-
- def generate_2nd_chuk(audio):
-
-     input_audio, sample_rate = torchaudio.load(audio)
-     input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
-
-     input_speech = input_audio[0]
-
-     input_features = processor(input_speech,
-                                sampling_rate=16_000,
-                                return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
-     forced_decoder_ids = []
-
-     forced_decoder_ids.append([1,50270]) #[1, '<|ca|>']
-     forced_decoder_ids.append([2,50262]) #[2, '<|es|>']
-     forced_decoder_ids.append([3,50360]) #[3, '<|transcribe|>']
-
-     forced_decoder_ids_modified = forced_decoder_ids
-     idx = processor.tokenizer.all_special_tokens.index("<|startofprev|>")
-     forced_bos_token_id = processor.tokenizer.all_special_ids[idx]
-
-     prompt = "112, digui'm. Hola, puc parlar en castellà? Sí, digui, diga. Sí, mire: a veces al abrir la puerta de mi piso tengo una persona ahí. Vale, avisamos a la Guàrdia Urbana, ¿de acuerdo? Vale, perfecto. Gracias. Gracias. Buen día."
-     prompt_tokens = processor.tokenizer(prompt, add_special_tokens=False).input_ids
-
-     # we need to force these tokens
-     forced_decoder_ids = []
-     for idx, token in enumerate(prompt_tokens):
-         # indexing starts from 1 for forced tokens (token at position 0 is the SOS token)
-         forced_decoder_ids.append([idx + 1, token])
-
-     # now we add the SOS token at the end
-     offset = len(forced_decoder_ids)
-     forced_decoder_ids.append([offset + 1, model.generation_config.decoder_start_token_id])
-
-     # now we need to append the rest of the prefix tokens (lang, task, timestamps)
-     offset = len(forced_decoder_ids)
-     for idx, token in forced_decoder_ids_modified:
-         forced_decoder_ids.append([idx + offset, token])
-
-     model.generation_config.forced_decoder_ids = forced_decoder_ids
-
-     pred_ids = model.generate(input_features,
-                               return_timestamps=True,
-                               max_new_tokens=128,
-                               decoder_start_token_id=forced_bos_token_id)
-     # exclude prompt from output
-     forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
-     output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)
-
-     return output[1:]
-
- def processing_vad_threshold(audio, output_vad, threshold, max_duration, concatenated_segment):
-
-     transcription_audio = ""
-     is_first_chunk = True
-     for speech in output_vad.get_timeline().support():
-         start, end = speech.start, speech.end
-         segment_duration = (end - start) * 1000
-         segment_audio = audio[start * 1000:end * 1000]
-
-         if max_duration + segment_duration < threshold:
-             concatenated_segment += audio[start * 1000:end * 1000]
-             max_duration += segment_duration
-
-         else:
-             if len(concatenated_segment) > 0:
-                 temp_segment_path = os.path.join(segments_dir, f"temp_segment.wav")
-                 concatenated_segment.export(temp_segment_path, format="wav")
-
-                 if is_first_chunk:
-                     output = generate_1st_chunk(temp_segment_path)
-                     is_first_chunk = False
-                 else:
-                     output = generate_2nd_chuk(temp_segment_path)
-                 transcription_audio = transcription_audio + output
-                 max_duration = segment_duration
-                 concatenated_segment = segment_audio
-
-     # Process any remaining audio in the concatenated_segment
-     if len(concatenated_segment) > 0:
-         temp_segment_path = os.path.join(segments_dir, f"temp_segment.wav")
-         concatenated_segment.export(temp_segment_path, format="wav")
-
-         output = generate_2nd_chuk(temp_segment_path)
-         transcription_audio = transcription_audio + output
-
-     return(transcription_audio)
-
  def format_audio(audio_path):
      input_audio, sample_rate = torchaudio.load(audio_path)

@@ -212,34 +91,66 @@ def format_audio(audio_path):
      input_audio = input_audio.squeeze().numpy()
      return(input_audio)


  def transcribe_pipeline(audio, task):
      text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
      return text

- def generate(audio_path, use_v5):
-     audio = AudioSegment.from_wav(audio_path)
-
-     temp_mono_path = None
-     if audio.channels != 1: #stereo2mono
-         audio = audio.set_channels(1)
-         temp_mono_path = "temp_mono.wav"
-         audio.export(temp_mono_path, format="wav")
-         audio_path = temp_mono_path

-     output_vad = pipeline_vad(audio_path)
-     concatenated_segment = AudioSegment.empty()
-     max_duration = 0
-
-     if use_v5:
-         output = processing_vad_threshold(audio, output_vad, threshold, max_duration, concatenated_segment)
-     else:
-         task = "transcribe"
          output = transcribe_pipeline(format_audio(audio_path), task)

-     clean_output = post_process_transcription(output)
-
      if temp_mono_path and os.path.exists(temp_mono_path):
-         os.remove(temp_mono_path)

      return clean_output
 
  from pydub import AudioSegment
  import os
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer
  import torchaudio
  import torch
  import re
  from transformers import pipeline
+ from peft import PeftModel, PeftConfig
  import spaces

  device = 0 if torch.cuda.is_available() else "cpu"
  torch_dtype = torch.float32

+ ### Configuration
+ MODEL_NAME_V2 = "./whisper-large-v3-catalan"
+ MODEL_NAME_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"
+ CHUNK_LENGTH = 30
  BATCH_SIZE = 1

  pipe = pipeline(
      task="automatic-speech-recognition",
+     model=MODEL_NAME_V1,
      chunk_length_s=30,
      device=device,
      token=os.getenv("HF_TOKEN")
  )

+
+ peft_config = PeftConfig.from_pretrained(MODEL_NAME_V2)
+ model = WhisperForConditionalGeneration.from_pretrained(
+     peft_config.base_model_name_or_path,
+     device_map="auto"
+ )
+
+ task = "transcribe"
+
+ model = PeftModel.from_pretrained(model, MODEL_NAME_V2)
+ model.config.use_cache = True
+
+ tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, task=task)
+ processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, task=task)
+ feature_extractor = processor.feature_extractor
+ forced_decoder_ids = processor.get_decoder_prompt_ids(task=task)
+
+ asr_pipe = pipeline(
+     task="automatic-speech-recognition",
+     model=model,
+     tokenizer=tokenizer,
+     feature_extractor=feature_extractor,
+     chunk_length_s=30)
+
+ def asr(audio_path, task):
+     asr_result = asr_pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task":task}, return_timestamps=True)
+     base_model = asr_pipe.model.base_model if hasattr(asr_pipe.model, "base_model") else asr_pipe.model
+     return asr_result
+
  def post_process_transcription(transcription, max_repeats=2):
      tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
 
 
      return cleaned_transcription


  def format_audio(audio_path):
      input_audio, sample_rate = torchaudio.load(audio_path)

      input_audio = input_audio.squeeze().numpy()
      return(input_audio)

+ def split_stereo_channels(audio_path):
+
+     audio = AudioSegment.from_wav(audio_path)
+
+     channels = audio.split_to_mono()
+     if len(channels) != 2:
+         raise ValueError(f"Audio {audio_path} does not have 2 channels.")
+
+     channels[0].export(f"temp_mono_speaker1.wav", format="wav") # Right
+     channels[1].export(f"temp_mono_speaker2.wav", format="wav") # Left

  def transcribe_pipeline(audio, task):
      text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
      return text
 
+ def generate(audio_path, use_v2):
+     temp_mono_path = None  # ensure the cleanup below is defined in both branches

+     if use_v2:
+         split_stereo_channels(audio_path)
+
+         audio_id = os.path.splitext(os.path.basename(audio_path))[0]
+
+         left_channel_path = "temp_mono_speaker2.wav"
+         right_channel_path = "temp_mono_speaker1.wav"
+
+         left_audio = format_audio(left_channel_path)
+         right_audio = format_audio(right_channel_path)
+
+         left_result = asr(left_audio, task)
+         right_result = asr(right_audio, task)
+
+         left_segs = [(seg["timestamp"][0], seg["timestamp"][1], "Speaker 1", post_process_transcription(seg["text"])) for seg in left_result["chunks"]]
+         right_segs = [(seg["timestamp"][0], seg["timestamp"][1], "Speaker 2", post_process_transcription(seg["text"])) for seg in right_result["chunks"]]
+
+         merged_transcript = sorted(left_segs + right_segs, key=lambda x: x[0])
+         merged_text = " ".join([seg[3] for seg in merged_transcript])

+         output = ""
+         for start, end, speaker, text in merged_transcript:
+             output += f"[{start:.2f}s - {end:.2f}s] {speaker}: {text}\n"
+
+     else:
+         audio = AudioSegment.from_wav(audio_path)
+         temp_mono_path = None
+
+         if audio.channels != 1: #stereo2mono
+             audio = audio.set_channels(1)
+             temp_mono_path = "temp_mono.wav"
+             audio.export(temp_mono_path, format="wav")
+             audio_path = temp_mono_path
+         task = "transcribe"
          output = transcribe_pipeline(format_audio(audio_path), task)

+     clean_output = post_process_transcription(output, max_repeats=1) #check
+
      if temp_mono_path and os.path.exists(temp_mono_path):
+         os.remove(temp_mono_path)
+
+     for temp_file in ["temp_mono_speaker1.wav", "temp_mono_speaker2.wav"]:
+         if os.path.exists(temp_file):
+             os.remove(temp_file)

      return clean_output
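
For context, a minimal usage sketch of the updated entry point. It assumes whisper.py is importable as a local module named whisper and that a two-channel WAV (one speaker per channel) is available; the driver script and file name are illustrative and not part of this commit.

# Hypothetical driver for the updated whisper.py (not part of the commit).
from whisper import generate  # importing whisper.py also loads both pipelines

if __name__ == "__main__":
    stereo_wav = "example_call_stereo.wav"  # hypothetical two-channel recording

    # use_v2=True: split the stereo file into one mono file per speaker,
    # transcribe each channel with the PEFT-adapted asr_pipe, merge the
    # per-speaker segments by start time, and clean up repetitions.
    print(generate(stereo_wav, use_v2=True))

    # use_v2=False: downmix to mono if needed and transcribe with the plain
    # MODEL_NAME_V1 pipeline via transcribe_pipeline().
    print(generate(stereo_wav, use_v2=False))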