File size: 5,328 Bytes
b7fa1b5
5bcf187
95d83f6
b7fa1b5
 
5bcf187
 
95d83f6
c01ffa1
5bcf187
b7fa1b5
 
 
95d83f6
 
 
 
5bcf187
 
 
 
95d83f6
5bcf187
bd82412
b0cf7f8
5bcf187
 
95d83f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0053d2c
 
5bcf187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7fa1b5
5bcf187
 
b7fa1b5
5bcf187
 
0053d2c
621a46f
 
5669eef
4bd685e
 
5669eef
5bcf187
 
 
 
 
95d83f6
 
 
 
 
 
 
 
 
 
5669eef
5bcf187
 
 
 
95d83f6
11d760e
bfe3b33
 
95d83f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a41c75a
 
 
5669eef
95d83f6
 
24a2822
790d7cc
 
95d83f6
 
 
 
 
 
 
 
 
621a46f
790d7cc
95d83f6
4bd685e
95d83f6
 
 
 
 
5669eef
0053d2c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from pydub import AudioSegment
import os
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer
import torchaudio
import torch
import re
from transformers import pipeline
from peft import PeftModel, PeftConfig
import spaces

# Pipeline device index for transformers: GPU 0 when CUDA is available, else CPU.
device = 0 if torch.cuda.is_available() else "cpu"
# NOTE(review): torch_dtype is not passed to either pipeline below.
torch_dtype = torch.float32

### Configuration
MODEL_NAME_V2 = "./whisper-large-v3-catalan"
MODEL_NAME_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"
CHUNK_LENGTH = 30  # seconds per chunk for long-form audio, shared by both pipelines
BATCH_SIZE = 1

# V1: plain ASR pipeline loaded directly from MODEL_NAME_V1
# (used by transcribe_pipeline / the single-speaker path of generate).
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME_V1,
    chunk_length_s=CHUNK_LENGTH,
    device=device,
    token=os.getenv("HF_TOKEN"),
)


# V2: base Whisper model named in the PEFT config, with the adapter stored at
# MODEL_NAME_V2 applied on top.
peft_config = PeftConfig.from_pretrained(MODEL_NAME_V2)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    device_map="auto",
)

task = "transcribe"

model = PeftModel.from_pretrained(model, MODEL_NAME_V2)
model.config.use_cache = True

tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, task=task)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, task=task)
feature_extractor = processor.feature_extractor
# NOTE(review): forced_decoder_ids is not passed to asr_pipe; the task is
# supplied per call through generate_kwargs instead.
forced_decoder_ids = processor.get_decoder_prompt_ids(task=task)

# V2 pipeline wrapping the adapted model (used by asr / the two-speaker path).
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    chunk_length_s=CHUNK_LENGTH,
)

def asr(audio_path, task):
    """Transcribe audio with the PEFT-adapted V2 pipeline, keeping timestamps.

    Args:
        audio_path: Input handed straight to ``asr_pipe``. Despite the name,
            ``generate`` passes the array returned by ``format_audio`` here.
        task: Whisper generation task, forwarded via ``generate_kwargs``.

    Returns:
        The raw pipeline output. Because ``return_timestamps=True``, callers
        read its ``"chunks"`` list, whose entries carry ``"timestamp"`` and
        ``"text"``.
    """
    return asr_pipe(
        audio_path,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )

def post_process_transcription(transcription, max_repeats=2):
    """Clean up a raw transcript by removing stutters and runaway repetitions.

    The text is split into word tokens; each token may keep one trailing
    ``.``, ``,``, ``!`` or ``?``. Inside each token, a 1-3 character chunk
    repeated three or more times in a row is deleted entirely. A token equal
    to the previous token is kept only while its run length is at most
    ``max_repeats``.

    Args:
        transcription: Raw transcript text.
        max_repeats: Maximum number of consecutive identical tokens to keep.

    Returns:
        The cleaned tokens joined by single spaces, stripped at both ends.
    """
    kept = []
    last = None
    run_length = 0

    for word in re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription):
        word = re.sub(r"(\w{1,3})(\1{2,})", "", word)

        if word != last:
            # A new token starts a fresh run and is always kept.
            run_length = 1
            kept.append(word)
        else:
            run_length += 1
            if run_length <= max_repeats:
                kept.append(word)

        last = word

    # Tokens reduced to "" leave extra spaces behind; collapse them here.
    return re.sub(r'\s+', ' ', " ".join(kept)).strip()


def format_audio(audio_path, target_sample_rate=16000):
    """Load an audio file as a mono 1-D numpy array at ``target_sample_rate``.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        target_sample_rate: Output sample rate in Hz (default 16000).

    Returns:
        A 1-D numpy array of samples at ``target_sample_rate``.
    """
    # torchaudio.load returns a (channels, frames) tensor plus its sample rate.
    input_audio, sample_rate = torchaudio.load(audio_path)

    # Downmix any multi-channel input (not only stereo) by averaging channels.
    # Otherwise a 3+ channel file would reach squeeze() still multi-channel.
    if input_audio.shape[0] > 1:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)

    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        input_audio = resampler(input_audio)

    # Drop only the channel axis so even a 1-frame clip stays 1-D.
    return input_audio.squeeze(0).numpy()

def split_stereo_channels(
    audio_path,
    speaker1_path="temp_mono_speaker1.wav",
    speaker2_path="temp_mono_speaker2.wav",
):
    """Split a two-channel WAV file into two mono WAV files.

    Args:
        audio_path: Path to a WAV file with exactly two channels.
        speaker1_path: Output path for the first channel (``channels[0]``).
        speaker2_path: Output path for the second channel (``channels[1]``).

    Returns:
        The tuple ``(speaker1_path, speaker2_path)`` of files written.

    Raises:
        ValueError: If the file does not have exactly two channels.
    """
    audio = AudioSegment.from_wav(audio_path)

    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")

    # NOTE(review): in standard WAV channel order index 0 is normally the left
    # channel. Confirm which physical side each speaker is recorded on.
    channels[0].export(speaker1_path, format="wav")
    channels[1].export(speaker2_path, format="wav")
    return speaker1_path, speaker2_path

def transcribe_pipeline(audio, task):
    """Transcribe ``audio`` with the V1 pipeline and return only its text.

    Args:
        audio: Audio input accepted by ``pipe``.
        task: Whisper generation task, forwarded via ``generate_kwargs``.

    Returns:
        The ``"text"`` field of the pipeline result.
    """
    result = pipe(
        audio,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )
    return result["text"]

def _label_segments(asr_result, speaker):
    """Turn timestamped ASR chunks into ``(start, end, speaker, text)`` tuples.

    Each chunk's text is cleaned with ``post_process_transcription``.
    """
    return [
        (
            seg["timestamp"][0],
            seg["timestamp"][1],
            speaker,
            post_process_transcription(seg["text"]),
        )
        for seg in asr_result["chunks"]
    ]


def generate(audio_path, use_v2):
    """Transcribe a WAV file and return the cleaned transcript.

    Args:
        audio_path: Path to a WAV file.
        use_v2: If true, treat the file as a two-speaker stereo recording.
            Each channel is transcribed separately with ``asr`` (PEFT-adapted
            model), and the timestamped segments are merged in start-time
            order as ``[Speaker N]: text`` lines. Otherwise the audio is
            downmixed to mono and transcribed with ``transcribe_pipeline``.

    Returns:
        The transcript as a single string.

    Raises:
        ValueError: If ``use_v2`` is true and the file does not have exactly
            two channels (raised by ``split_stereo_channels``).
    """
    task = "transcribe"
    temp_mono_path = None

    try:
        if use_v2:
            split_stereo_channels(audio_path)

            # split_stereo_channels writes channels[0] to speaker1 and
            # channels[1] to speaker2. This mapping and the speaker labels
            # below are kept as-is because they are visible in the output.
            left_channel_path = "temp_mono_speaker2.wav"
            right_channel_path = "temp_mono_speaker1.wav"

            left_audio = format_audio(left_channel_path)
            right_audio = format_audio(right_channel_path)

            left_result = asr(left_audio, task)
            right_result = asr(right_audio, task)

            merged_transcript = _label_segments(left_result, "Speaker 1") + _label_segments(
                right_result, "Speaker 2"
            )
            # Stable sort by start time; a missing start timestamp counts as 0.
            merged_transcript.sort(key=lambda seg: seg[0] if seg[0] is not None else 0.0)

            lines = (f"[{speaker}]: {text}" for _, _, speaker, text in merged_transcript)
            return "\n".join(lines).strip()

        audio = AudioSegment.from_wav(audio_path)
        if audio.channels != 1:  # stereo2mono
            audio = audio.set_channels(1)
            temp_mono_path = "temp_mono.wav"
            audio.export(temp_mono_path, format="wav")
            audio_path = temp_mono_path
        output = transcribe_pipeline(format_audio(audio_path), task)
        return post_process_transcription(output)
    finally:
        # Runs even when loading or transcription raises, so a failed request
        # does not leave temporary WAV files behind in the working directory.
        if temp_mono_path and os.path.exists(temp_mono_path):
            os.remove(temp_mono_path)
        for temp_file in ("temp_mono_speaker1.wav", "temp_mono_speaker2.wav"):
            if os.path.exists(temp_file):
                os.remove(temp_file)