# Spaces: Running
import gradio as gr | |
import torchaudio | |
import torch | |
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline | |
import numpy as np | |
import tempfile | |
import os | |
# Module-level slots for the loaded model objects; load_model() fills them in.
processor = model = asr_pipeline = None
def load_model():
    """Load the Breeze ASR 25 checkpoint and build the shared ASR pipeline.

    Assigns the module-level ``processor``, ``model`` and ``asr_pipeline``
    globals. The model is moved to CUDA when PyTorch reports it as available
    and to the CPU otherwise, then switched to eval mode.

    Returns:
        A user-facing status string: a success message naming the device, or
        a failure message containing the exception text. Exceptions raised
        while loading are caught and reported, never propagated.
    """
    global processor, model, asr_pipeline

    checkpoint = "MediaTek-Research/Breeze-ASR-25"
    try:
        processor = WhisperProcessor.from_pretrained(checkpoint)
        model = WhisperForConditionalGeneration.from_pretrained(checkpoint)

        # Pick the GPU whenever one is visible to PyTorch.
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        model = model.to(device).eval()

        # Assemble the pipeline from the processor's two halves.
        pipeline_kwargs = {
            "model": model,
            "tokenizer": processor.tokenizer,
            "feature_extractor": processor.feature_extractor,
            "chunk_length_s": 0,
            "device": device,
        }
        asr_pipeline = AutomaticSpeechRecognitionPipeline(**pipeline_kwargs)
    except Exception as exc:
        return f"❌ 模型載入失敗: {exc!s}"
    return f"✅ 模型載入成功!使用設備: {device}"
def preprocess_audio(audio_path):
    """Load an audio file and return it as a mono 16 kHz NumPy waveform.

    Args:
        audio_path: Path handed to ``torchaudio.load``.

    Returns:
        The squeezed waveform as a NumPy array, resampled to 16000 Hz when
        the file uses a different sample rate.
    """
    target_rate = 16000
    samples, source_rate = torchaudio.load(audio_path)

    # Collapse multi-channel audio into one track by averaging the channels.
    has_several_channels = samples.shape[0] > 1
    if has_several_channels:
        samples = samples.mean(dim=0)
    mono = samples.squeeze().numpy()

    # Already at the target rate: nothing left to do.
    if source_rate == target_rate:
        return mono

    to_target_rate = torchaudio.transforms.Resample(source_rate, target_rate)
    return to_target_rate(torch.tensor(mono)).numpy()
def _format_srt_timestamp(seconds):
    """Render an offset in seconds as an SRT ``HH:MM:SS,mmm`` timestamp."""
    # Work in whole milliseconds so float artefacts (2.3 % 1 == 0.2999...)
    # cannot truncate the millisecond field.
    total_ms = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def _write_recording_to_wav(sample_rate, audio_data):
    """Save a ``(sample_rate, samples)`` recording to a temp WAV; return its path."""
    samples = np.asarray(audio_data)
    if np.issubdtype(samples.dtype, np.signedinteger):
        # Signed integer PCM: scale by the dtype's range (32768 for int16)
        # instead of guessing from the peak value.
        scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = samples.astype(np.float32) / scale
    else:
        samples = samples.astype(np.float32)
        # Values outside [-1, 1] are treated as 16-bit-scaled samples.
        if samples.size and np.abs(samples).max() > 1.0:
            samples = samples / 32768.0

    # torchaudio.save expects (channels, frames); a 2-D recording is assumed
    # to arrive as (frames, channels) and is transposed accordingly.
    if samples.ndim == 1:
        samples = samples[np.newaxis, :]
    else:
        samples = samples.T
    tensor = torch.from_numpy(np.ascontiguousarray(samples))

    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    # Close our handle first so the path can be reopened for writing.
    tmp_file.close()
    try:
        torchaudio.save(tmp_file.name, tensor, sample_rate)
    except Exception:
        os.unlink(tmp_file.name)
        raise
    return tmp_file.name


def _format_transcript(result, fallback_end):
    """Turn a pipeline result into ``(plain_text, srt_text)``.

    ``fallback_end`` (seconds) is used as the end time wherever the result
    carries no end timestamp.
    """
    transcription = result["text"].strip()
    chunks = result.get("chunks")
    if not chunks:
        # No timestamps at all: emit one cue spanning the whole audio.
        cue_end = _format_srt_timestamp(fallback_end)
        srt_text = f"1\n00:00:00,000 --> {cue_end}\n{transcription}".strip()
        return transcription, srt_text

    text_lines = []
    srt_cues = []
    for chunk in chunks:
        text = chunk["text"].strip()
        if not text:
            continue
        start_time, end_time = chunk["timestamp"]
        if start_time is None:
            start_time = 0.0
        if end_time is None:
            # A missing end must never precede the cue's own start.
            end_time = max(start_time, fallback_end)
        text_lines.append(text)
        # Number cues by what is actually emitted so skipped empty chunks
        # leave no gaps in the SRT sequence.
        cue_number = len(srt_cues) + 1
        srt_cues.append(
            f"{cue_number}\n"
            f"{_format_srt_timestamp(start_time)} --> "
            f"{_format_srt_timestamp(end_time)}\n"
            f"{text}"
        )
    return "\n".join(text_lines), "\n\n".join(srt_cues)


def transcribe_audio(audio_input):
    """Transcribe an uploaded file or a recording.

    Args:
        audio_input: Either a file path (``str``) or a
            ``(sample_rate, audio_data)`` tuple holding a NumPy array.

    Returns:
        Always a 3-tuple ``(status, plain_text, srt_text)``, matching the
        three output components bound to the transcribe button. On any
        failure the status carries the error message and both text fields
        are empty strings.
    """
    global asr_pipeline

    # Cheap input validation first, before the model is loaded.
    if audio_input is None:
        return "❌ 請先上傳音訊檔案或進行錄音", "", ""
    if not isinstance(audio_input, (str, tuple)):
        return "❌ 不支援的音訊格式", "", ""

    temp_path = None
    try:
        # Load the model lazily on first use.
        if asr_pipeline is None:
            status = load_model()
            if "失敗" in status:
                return status, "", ""

        if isinstance(audio_input, str):
            audio_path = audio_input
        else:
            sample_rate, audio_data = audio_input
            temp_path = _write_recording_to_wav(sample_rate, audio_data)
            audio_path = temp_path

        waveform = preprocess_audio(audio_path)
        result = asr_pipeline(waveform, return_timestamps=True)

        # preprocess_audio returns 16 kHz samples, so this is the duration
        # in seconds.
        audio_duration = np.size(waveform) / 16000
        pure_text, srt_text = _format_transcript(result, audio_duration)
        return "✅ 辨識完成", pure_text, srt_text
    except Exception as e:
        return f"❌ 辨識過程發生錯誤: {str(e)}", "", ""
    finally:
        # Remove the temporary recording on success and on failure alike.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
def clear_all():
    """Reset the audio input and every output field.

    Returns:
        A 4-tuple, one value per component bound to the clear button:
        ``(audio_file, status_output, pure_text_output, srt_output)``.
    """
    return None, "🔄 已清除所有內容", "", ""
# Build the Gradio interface.
with gr.Blocks(title="語音辨識系統", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎤 語音辨識系統 - Breeze ASR 25
    ### 功能特色:
    - 🔧 使用 Breeze ASR 25 模型,專為繁體中文優化
    - ⏰ 顯示時間戳記
    - 🌐 強化中英混用辨識能力
    - 感謝[MediaTek-Research/Breeze-ASR-25](https://huggingface.co/MediaTek-Research/Breeze-ASR-25)
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Audio input area.
            gr.Markdown("### 📂 音訊輸入")
            with gr.Tab("檔案上傳"):
                audio_file = gr.Audio(
                    sources=["upload"],
                    label="上傳音訊檔案",
                    type="filepath"
                )
            # Control buttons.
            with gr.Row():
                transcribe_btn = gr.Button("🚀 開始辨識", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ 清除", variant="secondary")
        with gr.Column(scale=1):
            # Status display.
            status_output = gr.Textbox(
                label="📊 狀態",
                placeholder="等待操作...",
                interactive=False,
                lines=2
            )
            # Plain-text result.
            pure_text_output = gr.Textbox(
                label="📄 純文字結果",
                placeholder="純文字結果...",
                lines=4,
                max_lines=10,
                show_copy_button=True
            )
            # SRT subtitle result.
            srt_output = gr.Textbox(
                label="🎬 SRT 字幕格式",
                placeholder="SRT 格式字幕...",
                lines=6,
                max_lines=15,
                show_copy_button=True
            )

    # Event bindings.
    def transcribe_wrapper(audio_file_val, audio_mic_val=None):
        """Forward the uploaded file (or a recording, if given) to transcribe_audio.

        ``audio_mic_val`` defaults to ``None`` because the click binding below
        supplies a single input; without the default the call would fail for
        lack of a second positional argument.
        """
        audio_input = audio_file_val if audio_file_val else audio_mic_val
        return transcribe_audio(audio_input)

    transcribe_btn.click(
        fn=transcribe_wrapper,
        inputs=[audio_file],
        outputs=[status_output, pure_text_output, srt_output]
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[audio_file, status_output, pure_text_output, srt_output]
    )

# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()