# asr-inference / whisper.py
import os
import re

import torch
import torchaudio
from pydub import AudioSegment
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    pipeline,
)
from peft import PeftModel, PeftConfig
import spaces
device = 0 if torch.cuda.is_available() else "cpu"  # pipeline device: GPU index or "cpu"
torch_dtype = torch.float32

### Configuration
MODEL_NAME_V2 = "./whisper-large-v3-catalan"  # local PEFT adapter (v2)
MODEL_NAME_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"  # Hub model (v1)
CHUNK_LENGTH = 30  # seconds per chunk fed to the ASR pipelines
BATCH_SIZE = 1
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME_V1,
    chunk_length_s=CHUNK_LENGTH,
    device=device,
    token=os.getenv("HF_TOKEN"),
)
# v2: base Whisper model with a PEFT adapter loaded on top
peft_config = PeftConfig.from_pretrained(MODEL_NAME_V2)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    device_map="auto",
)
task = "transcribe"
model = PeftModel.from_pretrained(model, MODEL_NAME_V2)
model.config.use_cache = True

tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, task=task)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, task=task)
feature_extractor = processor.feature_extractor
forced_decoder_ids = processor.get_decoder_prompt_ids(task=task)  # computed but not used below
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    chunk_length_s=CHUNK_LENGTH,
)
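# Note: like any transformers ASR pipeline, `asr_pipe` accepts either a path to an
# audio file or a raw numpy waveform, which it assumes is already at the feature
# extractor's sampling rate (16 kHz for Whisper). A hypothetical direct call:
#   asr_pipe(format_audio("some_file.wav"), return_timestamps=True)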
def asr(audio_path, task):
    asr_result = asr_pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
    base_model = asr_pipe.model.base_model if hasattr(asr_pipe.model, "base_model") else asr_pipe.model  # currently unused
    return asr_result
def post_process_transcription(transcription, max_repeats=2):
    # Split into word-like tokens (keeping trailing punctuation), collapse character-level
    # repetitions, and allow at most `max_repeats` consecutive occurrences of the same token.
    tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Strip runs where a 1-3 character group is repeated three or more times (e.g. "lalala").
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", "", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
    return cleaned_transcription
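# Worked example (hypothetical input): with the default max_repeats=2,
#   post_process_transcription("si si si si molt bé bé")
# keeps at most two consecutive copies of each token and returns "si si molt bé bé".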
def format_audio(audio_path):
    # Load, downmix stereo to mono, and resample to 16 kHz (Whisper's expected rate).
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:  # stereo -> mono
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    input_audio = resampler(input_audio)
    input_audio = input_audio.squeeze().numpy()
    return input_audio
def split_stereo_channels(audio_path):
    # Split a stereo WAV into two temporary mono files, one per speaker channel.
    audio = AudioSegment.from_wav(audio_path)
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")
    channels[0].export("temp_mono_speaker1.wav", format="wav")  # Right
    channels[1].export("temp_mono_speaker2.wav", format="wav")  # Left
def transcribe_pipeline(audio, task):
    text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
    return text
def generate(audio_path, use_v2):
    # Assign these up front so both branches (and the cleanup below) can rely on them;
    # assigning `task` only inside the else branch would make it an unbound local
    # when use_v2 is True.
    task = "transcribe"
    temp_mono_path = None
    if use_v2:
        # v2: transcribe each stereo channel separately and merge the segments
        # into a single speaker-labelled transcript ordered by start time.
        split_stereo_channels(audio_path)
        audio_id = os.path.splitext(os.path.basename(audio_path))[0]  # currently unused

        left_channel_path = "temp_mono_speaker2.wav"
        right_channel_path = "temp_mono_speaker1.wav"

        left_audio = format_audio(left_channel_path)
        right_audio = format_audio(right_channel_path)

        left_result = asr(left_audio, task)
        right_result = asr(right_audio, task)

        left_segs = [(seg["timestamp"][0], seg["timestamp"][1], "Speaker 1", post_process_transcription(seg["text"])) for seg in left_result["chunks"]]
        right_segs = [(seg["timestamp"][0], seg["timestamp"][1], "Speaker 2", post_process_transcription(seg["text"])) for seg in right_result["chunks"]]

        merged_transcript = sorted(left_segs + right_segs, key=lambda x: x[0])
        merged_text = " ".join([seg[3] for seg in merged_transcript])

        output = ""
        for start, end, speaker, text in merged_transcript:
            output += f"[{start:.2f}s - {end:.2f}s] {speaker}: {text}\n"
        clean_output = output
    else:
        # v1: downmix to mono if needed, then run the plain Hub pipeline.
        audio = AudioSegment.from_wav(audio_path)
        if audio.channels != 1:  # stereo -> mono
            audio = audio.set_channels(1)
            temp_mono_path = "temp_mono.wav"
            audio.export(temp_mono_path, format="wav")
            audio_path = temp_mono_path

        output = transcribe_pipeline(format_audio(audio_path), task)
        clean_output = post_process_transcription(output, max_repeats=1)  # check
    # Clean up any temporary files created above.
    if temp_mono_path and os.path.exists(temp_mono_path):
        os.remove(temp_mono_path)
    for temp_file in ["temp_mono_speaker1.wav", "temp_mono_speaker2.wav"]:
        if os.path.exists(temp_file):
            os.remove(temp_file)

    return clean_output
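
# Minimal usage sketch (not part of the original Space code; "example_call.wav"
# is a hypothetical 2-channel WAV file used only for illustration).
if __name__ == "__main__":
    # v2 path: per-channel transcription with speaker labels
    print(generate("example_call.wav", use_v2=True))
    # v1 path: single mono transcription via the Hub pipeline
    print(generate("example_call.wav", use_v2=False))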