import os

from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile

import gradio as gr
import torch
from pydub import AudioSegment
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"
CHUNK_DURATION_MS = 10000  # 10-second chunks
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
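# KB-Whisper is KBLab's Swedish-trained Whisper model. Whisper processes at
# most 30 s of audio per window, so 10 s chunks fit within a single window
# while keeping the streaming updates frequent.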

# Initialize model and pipeline
def initialize_pipeline():
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True,
    ).to(DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    # model_kwargs passed to pipeline() are ignored once an instantiated model
    # is supplied; to enable FlashAttention 2, pass
    # attn_implementation="flash_attention_2" to from_pretrained() above
    # (requires the flash-attn package).
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=DEVICE,
        torch_dtype=TORCH_DTYPE,
    )
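
# Note: the ASR pipeline decodes file inputs with ffmpeg and resamples them to
# the feature extractor's 16 kHz rate, so no manual resampling is needed here.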

# Convert audio to WAV if needed
def convert_to_wav(audio_path: str) -> str:
    ext = Path(audio_path).suffix.lower()
    if ext != ".wav":
        audio = AudioSegment.from_file(audio_path)
        wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
        audio.export(wav_path, format="wav")
        return wav_path
    return audio_path
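
# pydub shells out to ffmpeg (or libav) to decode non-WAV formats such as .mp3
# and .m4a, so an ffmpeg binary must be available on the host.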

# Split audio into fixed-size chunks
def split_audio(audio_path: str) -> list:
    try:
        audio = AudioSegment.from_file(audio_path)
        if len(audio) == 0:
            raise ValueError("Audio file is empty or invalid.")
        return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]
    except Exception as e:
        raise ValueError(f"Failed to process audio: {e}") from e
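
# e.g. a 25 s file yields three chunks: 10 s, 10 s, and 5 s.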

# Helper to compute a chunk's start time as H:MM:SS
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    start_ms = index * chunk_duration_ms
    return str(timedelta(milliseconds=start_ms))
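
# Example: get_chunk_time(3, 10000) -> "0:00:30"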

# Transcribe audio with streaming progress and optional timestamps
def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Progress()):
    try:
        if not audio_path:
            yield "No audio file provided.", None
            return

        # Convert to WAV if needed
        wav_path = convert_to_wav(audio_path)

        # Split and process chunk by chunk
        chunks = split_audio(wav_path)
        total_chunks = len(chunks)
        transcript = []
        timestamped_transcript = []

        for i, chunk in enumerate(chunks):
            with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                chunk.export(temp_file.name, format="wav")
            try:
                result = PIPELINE(
                    temp_file.name,
                    generate_kwargs={"task": "transcribe", "language": "sv"},
                )
                text = result["text"].strip()
                transcript.append(text)
                if include_timestamps:
                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                    timestamped_transcript.append(f"[{timestamp}] {text}")
            finally:
                if os.path.exists(temp_file.name):
                    os.remove(temp_file.name)
            progress((i + 1) / total_chunks)
            yield " ".join(transcript), None

        # Clean up the converted file if one was created
        if wav_path != audio_path and os.path.exists(wav_path):
            os.remove(wav_path)

        # Prepare the final transcript and a downloadable file
        final_transcript = " ".join(transcript)
        download_content = "\n".join(timestamped_transcript) if include_timestamps else final_transcript
        with NamedTemporaryFile(suffix=".txt", delete=False, mode="w", encoding="utf-8") as temp_file:
            temp_file.write(download_content)
            download_path = temp_file.name
        # This function is a generator (it yields partial results above), so
        # the final result must also be yielded, not returned.
        yield final_transcript, download_path
    except Exception as e:
        yield f"Error during transcription: {e}", None

# Initialize pipeline globally
PIPELINE = initialize_pipeline()
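# Loading the model once at module level means it is downloaded and moved to
# the GPU a single time at startup, rather than on every request.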

# Gradio interface with Blocks
def create_interface():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Swedish Whisper Transcriber")
        gr.Markdown("Upload audio (.wav, .mp3, .m4a) for real-time Swedish speech transcription.")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(type="filepath", label="Upload Audio")
                timestamp_toggle = gr.Checkbox(label="Include Timestamps in Download", value=False)
                transcribe_btn = gr.Button("Transcribe")
            with gr.Column():
                transcript_output = gr.Textbox(label="Live Transcription", lines=10)
                download_output = gr.File(label="Download Transcript")
        transcribe_btn.click(
            fn=transcribe,
            inputs=[audio_input, timestamp_toggle],
            outputs=[transcript_output, download_output],
        )
    return demo

if __name__ == "__main__":
    # Streaming (generator) outputs require Gradio's queue; it is enabled by
    # default in Gradio 4+, but calling queue() keeps this explicit.
    create_interface().queue().launch()
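
# Assumed dependencies for this app (not pinned in the source): torch,
# transformers, gradio, pydub, plus an ffmpeg binary on the system path.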