from pydub import AudioSegment
import os
import re

import torch
import torchaudio
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    pipeline,
)
from peft import PeftModel, PeftConfig
import spaces

# transformers pipelines accept a CUDA device index or the string "cpu"
device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32
### Configuration
MODEL_NAME_V2 = "./whisper-large-v3-catalan"
MODEL_NAME_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"
CHUNK_LENGTH = 30  # seconds per chunk for long-form transcription
BATCH_SIZE = 1

# v1: a plain Whisper checkpoint served directly through the ASR pipeline
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME_V1,
    chunk_length_s=CHUNK_LENGTH,
    device=device,
    token=os.getenv("HF_TOKEN"),
)
# v2: a PEFT (LoRA) adapter applied on top of its Whisper base model
peft_config = PeftConfig.from_pretrained(MODEL_NAME_V2)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, MODEL_NAME_V2)
model.config.use_cache = True

task = "transcribe"
tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, task=task)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, task=task)
feature_extractor = processor.feature_extractor
# The task is passed per call via generate_kwargs, so no forced decoder ids are set here.
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    chunk_length_s=CHUNK_LENGTH,
)
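# With return_timestamps=True, both pipelines return a dict of the form
# {"text": "...", "chunks": [{"timestamp": (start, end), "text": "..."}, ...]};
# the diarized branch of generate() below consumes the "chunks" list.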
def asr(audio_path, task):
    # Timestamps are required so the two channels can be interleaved later.
    return asr_pipe(
        audio_path,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )
def post_process_transcription(transcription, max_repeats=2):
    tokens = re.findall(r"\b\w+'?\w*\b[.,!?]?", transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Collapse in-token stutters ("lalalala" -> "la"); keep one copy, not zero.
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r"\s+", " ", cleaned_transcription).strip()
    return cleaned_transcription
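# Illustrative behavior (with the default max_repeats=2):
#   post_process_transcription("si si si si val val") -> "si si val val"
#   the in-token regex also collapses stutters, e.g. "lalalala" -> "la"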
def format_audio(audio_path):
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] > 1:  # downmix stereo (or multi-channel) to mono
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    if sample_rate != 16000:  # Whisper expects 16 kHz input
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        input_audio = resampler(input_audio)
    return input_audio.squeeze().numpy()
def split_stereo_channels(audio_path):
    audio = AudioSegment.from_wav(audio_path)
    channels = audio.split_to_mono()  # pydub returns [left, right]
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")
    channels[0].export("temp_mono_speaker1.wav", format="wav")  # left channel
    channels[1].export("temp_mono_speaker2.wav", format="wav")  # right channel
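# Usage sketch ("stereo_call.wav" is a hypothetical 2-channel recording):
#   split_stereo_channels("stereo_call.wav")
# writes temp_mono_speaker1.wav / temp_mono_speaker2.wav into the working
# directory; generate() removes them once transcription is done.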
def transcribe_pipeline(audio, task):
    # v1 path: a single pass over the already mono, 16 kHz audio array
    return pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
@spaces.GPU  # allocate a ZeroGPU device for the duration of each call
def generate(audio_path, use_v2):
    task = "transcribe"
    temp_mono_path = None
    if use_v2:
        # Two-speaker mode: one speaker per stereo channel.
        split_stereo_channels(audio_path)
        left_channel_path = "temp_mono_speaker1.wav"   # left channel -> Speaker 1
        right_channel_path = "temp_mono_speaker2.wav"  # right channel -> Speaker 2
        left_audio = format_audio(left_channel_path)
        right_audio = format_audio(right_channel_path)
        left_result = asr(left_audio, task)
        right_result = asr(right_audio, task)
        left_segs = [
            (seg["timestamp"][0], seg["timestamp"][1], "Speaker 1", post_process_transcription(seg["text"]))
            for seg in left_result["chunks"]
        ]
        right_segs = [
            (seg["timestamp"][0], seg["timestamp"][1], "Speaker 2", post_process_transcription(seg["text"]))
            for seg in right_result["chunks"]
        ]
        # Interleave the two channels by start time; a missing timestamp sorts first.
        merged_transcript = sorted(
            left_segs + right_segs,
            key=lambda x: x[0] if x[0] is not None else 0.0,
        )
        output = ""
        for _start, _end, speaker, text in merged_transcript:
            output += f"[{speaker}]: {text}\n"
        clean_output = output.strip()
    else:
        # Single-speaker mode: downmix to mono if needed, then run the v1 pipeline.
        audio = AudioSegment.from_wav(audio_path)
        if audio.channels != 1:
            audio = audio.set_channels(1)
            temp_mono_path = "temp_mono.wav"
            audio.export(temp_mono_path, format="wav")
            audio_path = temp_mono_path
        output = transcribe_pipeline(format_audio(audio_path), task)
        clean_output = post_process_transcription(output)
    # Clean up any temporary channel/mono files.
    if temp_mono_path and os.path.exists(temp_mono_path):
        os.remove(temp_mono_path)
    for temp_file in ["temp_mono_speaker1.wav", "temp_mono_speaker2.wav"]:
        if os.path.exists(temp_file):
            os.remove(temp_file)
    return clean_output
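# Minimal usage sketch ("example.wav" is a hypothetical local file):
if __name__ == "__main__":
    print(generate("example.wav", use_v2=False))  # single speaker, v1 pipeline
    print(generate("example.wav", use_v2=True))   # two speakers from a stereo file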