from pydub import AudioSegment
import os
import re

import torch
import torchaudio
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    pipeline,
)
from peft import PeftModel, PeftConfig
import spaces

device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32

### Configuration
MODEL_NAME_V2 = "./whisper-large-v3-catalan"
MODEL_NAME_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"
CHUNK_LENGTH = 30
BATCH_SIZE = 1

# Baseline (V1) pipeline: plain Whisper checkpoint loaded from the Hub.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME_V1,
    chunk_length_s=CHUNK_LENGTH,
    device=device,
    token=os.getenv("HF_TOKEN"),
)

# Fine-tuned (V2) model: a PEFT adapter applied on top of its base Whisper checkpoint.
peft_config = PeftConfig.from_pretrained(MODEL_NAME_V2)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, device_map="auto"
)

task = "transcribe"
model = PeftModel.from_pretrained(model, MODEL_NAME_V2)
model.config.use_cache = True

tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, task=task)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, task=task)
feature_extractor = processor.feature_extractor
forced_decoder_ids = processor.get_decoder_prompt_ids(task=task)

asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    chunk_length_s=CHUNK_LENGTH,
)


def asr(audio_path, task):
    """Run the fine-tuned (V2) pipeline on one channel and return chunks with timestamps."""
    asr_result = asr_pipe(
        audio_path,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )
    base_model = asr_pipe.model.base_model if hasattr(asr_pipe.model, "base_model") else asr_pipe.model
    return asr_result


def post_process_transcription(transcription, max_repeats=2):
    """Collapse runaway repetitions that Whisper sometimes produces."""
    tokens = re.findall(r"\b\w+\'?\w*\b[.,!?]?", transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Collapse a short fragment repeated 3+ times inside a token (e.g. "nononono" -> "no").
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
        if reduced_token == previous_token:
            repetition_count += 1
            # Keep at most `max_repeats` consecutive copies of the same token.
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r"\s+", " ", cleaned_transcription).strip()
    return cleaned_transcription


def format_audio(audio_path):
    """Load an audio file, downmix stereo to mono, resample to 16 kHz, return a numpy waveform."""
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:  # stereo -> mono
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    input_audio = resampler(input_audio)
    input_audio = input_audio.squeeze().numpy()
    return input_audio


def split_stereo_channels(audio_path):
    """Split a stereo WAV into two temporary mono files, one per channel/speaker."""
    audio = AudioSegment.from_wav(audio_path)
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")
    channels[0].export("temp_mono_speaker1.wav", format="wav")  # left channel
    channels[1].export("temp_mono_speaker2.wav", format="wav")  # right channel


def transcribe_pipeline(audio, task):
    """Run the baseline (V1) pipeline on a mono waveform and return the plain text."""
    text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
    return text


def generate(audio_path, use_v2):
    task = "transcribe"
    temp_mono_path = None
    if use_v2:
        # V2: transcribe each stereo channel separately so each speaker gets its own label.
        split_stereo_channels(audio_path)
        audio_id = os.path.splitext(os.path.basename(audio_path))[0]
        left_channel_path = "temp_mono_speaker1.wav"
        right_channel_path = "temp_mono_speaker2.wav"
        left_audio = format_audio(left_channel_path)
        right_audio = format_audio(right_channel_path)
        left_result = asr(left_audio, task)
        right_result = asr(right_audio, task)
[(seg["timestamp"][0], seg["timestamp"][1], "Speaker 1", post_process_transcription(seg["text"])) for seg in left_result["chunks"]] right_segs = [(seg["timestamp"][0], seg["timestamp"][1], "Speaker 2", post_process_transcription(seg["text"])) for seg in right_result["chunks"]] #merged_transcript = sorted(left_segs + right_segs, key=lambda x: x[0]) merged_transcript = sorted(left_segs + right_segs, key=lambda x: x[0] if x[0] is not None else 0.0) output = "" for start, end, speaker, text in merged_transcript: output += f"[{speaker}]: {text}\n" clean_output = output.strip() else: audio = AudioSegment.from_wav(audio_path) if audio.channels != 1: #stereo2mono audio = audio.set_channels(1) temp_mono_path = "temp_mono.wav" audio.export(temp_mono_path, format="wav") audio_path = temp_mono_path output = transcribe_pipeline(format_audio(audio_path), task) clean_output = post_process_transcription(output) if temp_mono_path and os.path.exists(temp_mono_path): os.remove(temp_mono_path) for temp_file in ["temp_mono_speaker1.wav", "temp_mono_speaker2.wav"]: if os.path.exists(temp_file): os.remove(temp_file) return clean_output