Spaces:
Running
Running
import os | |
import torch | |
import gradio as gr | |
from pydub import AudioSegment | |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor | |
import tempfile | |
import math | |
from datasets import load_dataset, Audio | |
import numpy as np | |
import torchaudio | |
# Set up model
# Use a GPU when one is available, otherwise fall back to CPU. Weights are kept
# in float32 on either device so they match the float32 input features the
# processor produces (no dtype cast is needed at inference time).
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32
model_id = "KBLab/kb-whisper-large"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
# Helper: Split audio into chunks
def split_audio(audio_path, chunk_duration_ms=10000):
    """Decode an audio file and cut it into consecutive fixed-length pieces.

    Args:
        audio_path: Path to an audio file readable by ``AudioSegment.from_file``.
        chunk_duration_ms: Length of each piece in milliseconds. The last piece
            holds whatever remains and may be shorter.

    Returns:
        A list of ``AudioSegment`` slices in playback order.
    """
    segment = AudioSegment.from_file(audio_path)
    starts = range(0, len(segment), chunk_duration_ms)
    return [segment[start:start + chunk_duration_ms] for start in starts]
# Helper: Transcribe a single chunk
def transcribe_chunk(chunk):
    """Transcribe one pydub ``AudioSegment`` with the module-level Whisper model.

    The chunk is written to a temporary WAV file and decoded with torchaudio.
    It is then down-mixed to mono and resampled to the feature extractor's
    sampling rate before feature extraction and generation.

    Args:
        chunk: A pydub ``AudioSegment`` of any channel count and frame rate.

    Returns:
        The decoded text for the chunk (Swedish, ``task="transcribe"``).
    """
    target_rate = processor.feature_extractor.sampling_rate

    # mkstemp + close lets the exporter and loader reopen the path by name on
    # every OS (an still-open NamedTemporaryFile cannot be reopened on Windows).
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        chunk.export(tmp_path, format="wav")
        waveform, source_rate = torchaudio.load(tmp_path)
    finally:
        # Remove the temp file even if export or decoding fails.
        os.remove(tmp_path)

    # torchaudio returns (channels, frames); average the channels so the
    # processor receives a single 1-D mono signal.
    waveform = waveform.mean(dim=0)
    # The chunk keeps the upload's native frame rate (often 44.1/48 kHz);
    # resample so the rate we pass to the processor is actually true.
    if source_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, source_rate, target_rate)

    input_features = processor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    ).input_features.to(device)
    predicted_ids = model.generate(input_features, task="transcribe", language="sv")
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
# Full transcription function with progress
def transcribe_with_progress(audio_path, progress=gr.Progress()):
    """Transcribe an audio file chunk by chunk, streaming the partial transcript.

    Args:
        audio_path: Filesystem path of the uploaded audio (.wav, .mp3, .m4a, ...),
            or a falsy value when no audio is present.
        progress: Gradio progress tracker, injected by Gradio.

    Yields:
        The transcript accumulated so far (whitespace-stripped), once per
        processed 8-second chunk. Yields "" when there is nothing to transcribe.
    """
    # The input may be None, e.g. when the audio is cleared under live=True.
    if not audio_path:
        yield ""
        return

    # split_audio decodes with AudioSegment.from_file, which reads non-WAV
    # formats directly. A separate WAV conversion only added an extra decode
    # and a leftover file, and could overwrite the upload.
    chunks = split_audio(audio_path, chunk_duration_ms=8000)
    total_chunks = len(chunks)
    if total_chunks == 0:
        yield ""
        return

    pieces = []
    for done, chunk in enumerate(chunks, start=1):
        pieces.append(transcribe_chunk(chunk))
        # gr.Progress takes (index, total) as ONE tuple; a second positional
        # argument would be treated as the description text.
        progress((done, total_chunks), unit="chunks")
        yield " ".join(pieces).strip()  # stream updated text to the UI
# UI
# Bound to a name so tooling (e.g. Gradio reload mode) can find the app, and
# launched only when run as a script rather than on import.
demo = gr.Interface(
    fn=transcribe_with_progress,
    inputs=gr.Audio(type="filepath", label="Upload Swedish Audio"),
    outputs=gr.Textbox(label="Live Transcript (Swedish)"),
    title="Live Swedish Transcriber (KB-Whisper)",
    # Output is streamed once per audio chunk, not per word.
    description="Streams the transcript chunk by chunk with visual progress. Supports .m4a, .mp3, .wav. May be slow on CPU.",
    live=True,
)

if __name__ == "__main__":
    demo.launch()