import gradio as gr
import torchaudio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline
import numpy as np
import tempfile
import os
# Global variables holding the loaded model
processor = None
model = None
asr_pipeline = None
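# Note: the model is not loaded at import time; transcribe_audio() calls
# load_model() lazily on the first request, so the app starts quickly and
# the download/initialization cost is paid only when needed.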
def load_model():
    """Load the Breeze ASR 25 model."""
    global processor, model, asr_pipeline
    try:
        processor = WhisperProcessor.from_pretrained("MediaTek-Research/Breeze-ASR-25")
        model = WhisperForConditionalGeneration.from_pretrained("MediaTek-Research/Breeze-ASR-25")
        # Use CUDA if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device).eval()
        # Build the ASR pipeline
        asr_pipeline = AutomaticSpeechRecognitionPipeline(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            chunk_length_s=0,  # 0 disables chunking; the whole waveform is fed to the model at once
            device=device
        )
        return f"✅ Model loaded successfully! Device: {device}"
    except Exception as e:
        return f"❌ Model loading failed: {str(e)}"
def preprocess_audio(audio_path):
    """Load an audio file, downmix to mono, and resample to 16 kHz."""
    # Load the audio
    waveform, sample_rate = torchaudio.load(audio_path)
    # Downmix to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0)
    waveform = waveform.squeeze().numpy()
    # Resample to 16 kHz, the rate Whisper-family models expect
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(torch.tensor(waveform)).numpy()
    return waveform
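# Small helper factored out of transcribe_audio() below: the SRT timestamp
# arithmetic was originally duplicated inline for the start and end times.
# The HH:MM:SS,mmm layout itself is fixed by the SRT subtitle format.
def format_srt_time(seconds):
    """Convert a time in seconds to an SRT timestamp (HH:MM:SS,mmm)."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millis = int((seconds % 1) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"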
def transcribe_audio(audio_input):
    """Main speech-recognition entry point."""
    global asr_pipeline
    try:
        # Load the model lazily on the first request
        if asr_pipeline is None:
            status = load_model()
            if status.startswith("❌"):
                return status, "", ""
        # Validate the audio input
        if audio_input is None:
            return "❌ Please upload an audio file or record audio first", "", ""
        # Handle the different input formats Gradio may pass in
        if isinstance(audio_input, str):
            # A file path
            audio_path = audio_input
        elif isinstance(audio_input, tuple):
            # Gradio microphone format: (sample_rate, audio_data)
            sample_rate, audio_data = audio_input
            # Write the raw samples to a temporary wav file
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                # Integer PCM (int16 from the browser) must be scaled into [-1.0, 1.0]
                if np.issubdtype(audio_data.dtype, np.integer):
                    audio_data = audio_data.astype(np.float32) / 32768.0
                elif audio_data.dtype != np.float32:
                    audio_data = audio_data.astype(np.float32)
                # Downmix stereo recordings (shape (N, 2)) to mono
                if audio_data.ndim > 1:
                    audio_data = audio_data.mean(axis=1)
                # Save as a wav file
                torchaudio.save(tmp_file.name, torch.tensor(audio_data).unsqueeze(0), sample_rate)
                audio_path = tmp_file.name
        else:
            return "❌ Unsupported audio format", "", ""
        # Preprocess the audio (mono, 16 kHz)
        waveform = preprocess_audio(audio_path)
        # Run speech recognition
        result = asr_pipeline(waveform, return_timestamps=True)
        # Clean up the temporary file
        if isinstance(audio_input, tuple) and os.path.exists(audio_path):
            os.unlink(audio_path)
        # Format the results
        transcription = result["text"].strip()
        pure_text = ""
        srt_text = ""
        if "chunks" in result and result["chunks"]:
            for i, chunk in enumerate(result["chunks"], 1):
                start_time = chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0
                end_time = chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0
                text = chunk["text"].strip()
                if text:  # skip empty segments
                    # Plain text (no timestamps)
                    pure_text += f"{text}\n"
                    # SRT format
                    srt_text += f"{i}\n{format_srt_time(start_time)} --> {format_srt_time(end_time)}\n{text}\n\n"
        else:
            # No timestamps available: emit the full text as a single cue
            # (the 10-second end time is an arbitrary placeholder)
            pure_text = transcription
            srt_text = f"1\n00:00:00,000 --> 00:00:10,000\n{transcription}\n\n"
        return "✅ Transcription complete", pure_text.strip(), srt_text.strip()
    except Exception as e:
        # Three values, matching the three output components bound below
        return f"❌ Error during transcription: {str(e)}", "", ""
def clear_all():
    """Clear all inputs and outputs."""
    # Four values, matching [audio_file, status_output, pure_text_output, srt_output]
    return None, "🔄 All content cleared", "", ""
# Build the Gradio interface
with gr.Blocks(title="Speech Recognition System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎤 Speech Recognition System - Breeze ASR 25
    ### Features:
    - 🔧 Powered by the Breeze ASR 25 model, optimized for Traditional Chinese
    - ⏰ Timestamped output
    - 🌐 Improved recognition of mixed Chinese-English speech
    - Thanks to [MediaTek-Research/Breeze-ASR-25](https://huggingface.co/MediaTek-Research/Breeze-ASR-25)
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Audio input area
            gr.Markdown("### 📂 Audio Input")
            with gr.Tab("File Upload"):
                audio_file = gr.Audio(
                    sources=["upload"],
                    label="Upload an audio file",
                    type="filepath"
                )
            # Control buttons
            with gr.Row():
                transcribe_btn = gr.Button("🚀 Transcribe", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        with gr.Column(scale=1):
            # Status display
            status_output = gr.Textbox(
                label="📊 Status",
                placeholder="Waiting for input...",
                interactive=False,
                lines=2
            )
            # Plain-text result
            pure_text_output = gr.Textbox(
                label="📄 Plain Text",
                placeholder="Plain-text transcription...",
                lines=4,
                max_lines=10,
                show_copy_button=True
            )
            # SRT subtitle format
            srt_output = gr.Textbox(
                label="🎬 SRT Subtitles",
                placeholder="SRT-format subtitles...",
                lines=6,
                max_lines=15,
                show_copy_button=True
            )
    # Event bindings. There is no microphone component in this layout, so
    # transcribe_audio can be bound to the file input directly (the old
    # two-argument wrapper expected a second, nonexistent mic input).
    transcribe_btn.click(
        fn=transcribe_audio,
        inputs=[audio_file],
        outputs=[status_output, pure_text_output, srt_output]
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[audio_file, status_output, pure_text_output, srt_output]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()
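# Usage note: running `python app.py` starts a local Gradio server (by
# default at http://127.0.0.1:7860). On Hugging Face Spaces the demo is
# launched automatically by the Space runtime.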