import importlib.util
import os
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile

import gradio as gr
import torch
from pydub import AudioSegment
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"
CHUNK_DURATION_MS = 10_000  # transcribe in 10-second chunks
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
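
# NOTE: pydub decodes compressed formats (.mp3, .m4a) via ffmpeg, which must
# be installed and on PATH; plain .wav input works without it.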

# Initialize model and pipeline
def initialize_pipeline():
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True,
        # Use Flash Attention 2 only when a CUDA GPU and the optional
        # flash-attn package are both present; otherwise fall back to the
        # standard attention implementation.
        attn_implementation=(
            "flash_attention_2"
            if torch.cuda.is_available() and importlib.util.find_spec("flash_attn")
            else "eager"
        ),
    ).to(DEVICE)

    processor = AutoProcessor.from_pretrained(MODEL_ID)

    # model_kwargs is only honored by pipeline() when it loads the model from
    # a string id; with an instantiated model it is silently ignored, so the
    # attention setting lives in from_pretrained() above instead.
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=DEVICE,
        torch_dtype=TORCH_DTYPE,
    )
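
# An alternative worth noting: the transformers ASR pipeline can also chunk
# long audio internally (via its chunk_length_s argument), but the manual
# splitting below is kept so partial results can be streamed to the UI with
# per-chunk timestamps.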

# Convert audio if needed
def convert_to_wav(audio_path: str) -> str:
    ext = Path(audio_path).suffix.lower()
    if ext != ".wav":
        audio = AudioSegment.from_file(audio_path)
        wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
        audio.export(wav_path, format="wav")
        return wav_path
    return audio_path
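
# Example with a hypothetical file: convert_to_wav("talk.m4a") exports and
# returns "talk.converted.wav", while a .wav input path is returned as-is.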

# Split audio into chunks
def split_audio(audio_path: str) -> list:
    try:
        audio = AudioSegment.from_file(audio_path)
        if len(audio) == 0:
            raise ValueError("Audio file is empty or invalid.")
        return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]
    except Exception as e:
        raise ValueError(f"Failed to process audio: {e}") from e

# Helper to compute chunk start time
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    start_ms = index * chunk_duration_ms
    return str(timedelta(milliseconds=start_ms))
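
# Example: get_chunk_time(3, 10_000) returns "0:00:30" (chunk 3 starts 30 s in).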

# Transcribe audio with progress and timestamps
def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Progress()):
    try:
        if not audio_path:
            # A plain `return value` inside a generator is never shown by
            # Gradio, so yield the message first.
            yield "No audio file provided.", None
            return
            
        # Convert to WAV if needed
        wav_path = convert_to_wav(audio_path)
        
        # Split and process
        chunks = split_audio(wav_path)
        total_chunks = len(chunks)
        transcript = []
        timestamped_transcript = []
        
        for i, chunk in enumerate(chunks):
            temp_path = None
            try:
                # Export the chunk to a closed temp .wav before transcribing so
                # the pipeline can read the file on all platforms.
                with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                    temp_path = temp_file.name
                    chunk.export(temp_path, format="wav")
                result = PIPELINE(
                    temp_path,
                    generate_kwargs={"task": "transcribe", "language": "sv"}
                )
                text = result["text"].strip()
                transcript.append(text)
                if include_timestamps:
                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                    timestamped_transcript.append(f"[{timestamp}] {text}")
            finally:
                # temp_path stays None if NamedTemporaryFile itself failed.
                if temp_path and os.path.exists(temp_path):
                    os.remove(temp_path)

            progress((i + 1) / total_chunks)
            # Stream the partial transcript to the UI after each chunk.
            yield " ".join(transcript), None
            
        # Clean up converted file if created
        if wav_path != audio_path and os.path.exists(wav_path):
            os.remove(wav_path)
            
        # Prepare final transcript and downloadable file
        final_transcript = " ".join(transcript)
        download_content = "\n".join(timestamped_transcript) if include_timestamps else final_transcript
        with NamedTemporaryFile(suffix=".txt", delete=False, mode='w', encoding='utf-8') as temp_file:
            temp_file.write(download_content)
            download_path = temp_file.name
            
        yield final_transcript, download_path
        
    except Exception as e:
        yield f"Error during transcription: {e}", None

# Initialize pipeline globally
PIPELINE = initialize_pipeline()
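
# The large checkpoint is several GB, so the first run downloads and caches it
# (the Hugging Face cache location can be overridden via the HF_HOME env var).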

# Gradio Interface with Blocks
def create_interface():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Swedish Whisper Transcriber")
        gr.Markdown("Upload audio (.wav, .mp3, .m4a) for real-time Swedish speech transcription.")
        
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(type="filepath", label="Upload Audio")
                timestamp_toggle = gr.Checkbox(label="Include Timestamps in Download", value=False)
                transcribe_btn = gr.Button("Transcribe")
            
            with gr.Column():
                transcript_output = gr.Textbox(label="Live Transcription", lines=10)
                download_output = gr.File(label="Download Transcript")
        
        transcribe_btn.click(
            fn=transcribe,
            inputs=[audio_input, timestamp_toggle],
            outputs=[transcript_output, download_output]
        )
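
        # Gradio 4+ queues events by default, which is what lets the generator
        # in `transcribe` stream partial text; on Gradio 3.x, call demo.queue()
        # before launch() to get the same behavior.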
    
    return demo

if __name__ == "__main__":
    create_interface().launch()