# asr-inference / whisper.py
import os
import re

import torch
import torchaudio
from pydub import AudioSegment
from peft import PeftModel, PeftConfig
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    pipeline,
)

import spaces
device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32
### Configuration
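# MODEL_NAME_V2 points to a local PEFT adapter checkpoint (loaded below with
# PeftConfig/PeftModel), presumably fine-tuned for Catalan; MODEL_NAME_V1 is
# fetched from the Hugging Face Hub.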
MODEL_NAME_V2 = "./whisper-large-v3-catalan"
MODEL_NAME_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"
CHUNK_LENGTH = 30
BATCH_SIZE = 1
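# General-purpose v1 pipeline (MODEL_NAME_V1); generate() falls back to it via
# transcribe_pipeline() when use_v2 is False.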
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME_V1,
    chunk_length_s=CHUNK_LENGTH,
    device=device,
    token=os.getenv("HF_TOKEN"),
)
peft_config = PeftConfig.from_pretrained(MODEL_NAME_V2)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    device_map="auto",
)
task = "transcribe"
model = PeftModel.from_pretrained(model, MODEL_NAME_V2)
model.config.use_cache = True
tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, task=task)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, task=task)
feature_extractor = processor.feature_extractor
forced_decoder_ids = processor.get_decoder_prompt_ids(task=task)
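# ASR pipeline around the PEFT-adapted v2 model; asr() uses it to transcribe
# each speaker channel separately when use_v2 is True.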
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    chunk_length_s=CHUNK_LENGTH,
)
def asr(audio_path, task):
    asr_result = asr_pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
    return asr_result
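# asr() returns the raw pipeline output; generate() reads result["chunks"]
# for the per-segment timestamps and text.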
def post_process_transcription(transcription, max_repeats=2):
    tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Collapse in-token stutters (e.g. a short fragment repeated 3+ times) to a single occurrence.
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
    return cleaned_transcription
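# Illustrative example (with the repetition clamp above and max_repeats=2):
#   post_process_transcription("si si si si que vull")  ->  "si si que vull"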
def format_audio(audio_path):
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:  # stereo2mono
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    input_audio = resampler(input_audio)
    input_audio = input_audio.squeeze().numpy()
    return input_audio
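# format_audio() returns a mono float waveform resampled to 16 kHz, the sample
# rate the Whisper feature extractor expects.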
def split_stereo_channels(audio_path):
    audio = AudioSegment.from_wav(audio_path)
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")
    channels[0].export("temp_mono_speaker1.wav", format="wav")  # Right
    channels[1].export("temp_mono_speaker2.wav", format="wav")  # Left
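# Writes two temporary mono files (temp_mono_speaker1.wav / temp_mono_speaker2.wav)
# that generate() reads and then deletes once transcription is done.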
def transcribe_pipeline(audio, task):
    text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
    return text
def generate(audio_path, use_v2):
    task = "transcribe"
    temp_mono_path = None
    if use_v2:
        split_stereo_channels(audio_path)
        left_channel_path = "temp_mono_speaker2.wav"
        right_channel_path = "temp_mono_speaker1.wav"
        left_audio = format_audio(left_channel_path)
        right_audio = format_audio(right_channel_path)
        left_result = asr(left_audio, task)
        right_result = asr(right_audio, task)
        left_segs = [
            (seg["timestamp"][0], seg["timestamp"][1], "Speaker 1", post_process_transcription(seg["text"]))
            for seg in left_result["chunks"]
        ]
        right_segs = [
            (seg["timestamp"][0], seg["timestamp"][1], "Speaker 2", post_process_transcription(seg["text"]))
            for seg in right_result["chunks"]
        ]
        merged_transcript = sorted(
            left_segs + right_segs,
            key=lambda x: x[0] if x[0] is not None else 0.0,
        )
        output = ""
        for start, end, speaker, text in merged_transcript:
            output += f"[{speaker}]: {text}\n"
        clean_output = output.strip()
    else:
        audio = AudioSegment.from_wav(audio_path)
        if audio.channels != 1:  # stereo2mono
            audio = audio.set_channels(1)
            temp_mono_path = "temp_mono.wav"
            audio.export(temp_mono_path, format="wav")
            audio_path = temp_mono_path
        output = transcribe_pipeline(format_audio(audio_path), task)
        clean_output = post_process_transcription(output)
    if temp_mono_path and os.path.exists(temp_mono_path):
        os.remove(temp_mono_path)
    for temp_file in ["temp_mono_speaker1.wav", "temp_mono_speaker2.wav"]:
        if os.path.exists(temp_file):
            os.remove(temp_file)
    return clean_output
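# Minimal usage sketch; the command-line handling and example path below are
# illustrative assumptions. generate() itself only needs a WAV path and a flag
# selecting the v2 (per-speaker) or v1 (single-channel) flow.
if __name__ == "__main__":
    import sys

    wav_path = sys.argv[1] if len(sys.argv) > 1 else "example.wav"  # hypothetical input file
    use_v2 = "--v2" in sys.argv  # hypothetical flag: per-speaker stereo transcription
    print(generate(wav_path, use_v2))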