# NOTE: the original top-of-file lines were web file-viewer residue, not Python
# ("Spaces: Running", "File size: 7,770 Bytes", a line-number gutter 1-224).
# Commit hashes shown in the viewer's blame column: 8854b79, 28da944.
import gradio as gr
import torchaudio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline
import numpy as np
import tempfile
import os
# Module-level handles for the model components; all start unset and are
# filled in by load_model().
processor = model = asr_pipeline = None
def load_model():
    """Load the Breeze ASR 25 model and build the global ASR pipeline.

    Populates the module-level ``processor``, ``model`` and ``asr_pipeline``
    names. The model is moved to CUDA when ``torch.cuda.is_available()``
    reports a GPU, otherwise it stays on the CPU.

    Returns:
        A status string: a success message naming the device in use, or a
        failure message carrying the text of the exception that was raised.
    """
    global processor, model, asr_pipeline
    repo_id = "MediaTek-Research/Breeze-ASR-25"
    try:
        processor = WhisperProcessor.from_pretrained(repo_id)
        model = WhisperForConditionalGeneration.from_pretrained(repo_id)

        # Pick the GPU when one is available.
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        model = model.to(device).eval()

        # Assemble the pipeline from the processor's two halves.
        pipeline_options = {
            "model": model,
            "tokenizer": processor.tokenizer,
            "feature_extractor": processor.feature_extractor,
            "chunk_length_s": 0,
            "device": device,
        }
        asr_pipeline = AutomaticSpeechRecognitionPipeline(**pipeline_options)
    except Exception as e:
        return f"❌ 模型載入失敗: {str(e)}"
    return f"✅ 模型載入成功!使用設備: {device}"
def preprocess_audio(audio_path):
    """Load an audio file and return it as a mono 16 kHz NumPy array.

    Args:
        audio_path: Path handed to ``torchaudio.load``.

    Returns:
        The samples as a NumPy array, averaged across channels when the file
        has more than one, and resampled to 16000 Hz when the file uses a
        different rate.
    """
    signal, source_rate = torchaudio.load(audio_path)

    # Collapse multi-channel audio by averaging over the channel axis.
    has_multiple_channels = signal.shape[0] > 1
    if has_multiple_channels:
        signal = signal.mean(dim=0)
    samples = signal.squeeze().numpy()

    # Already at the target rate: nothing more to do.
    if source_rate == 16000:
        return samples

    to_16k = torchaudio.transforms.Resample(source_rate, 16000)
    return to_16k(torch.tensor(samples)).numpy()
def _format_srt_timestamp(seconds):
    """Render a time in seconds as an SRT ``HH:MM:SS,mmm`` timestamp."""
    # Round once to whole milliseconds; slicing the float piece by piece can
    # lose a millisecond to representation error (2.3 s would give ",299").
    total_ms = max(0, int(round(float(seconds) * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1_000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def transcribe_audio(audio_input):
    """Transcribe an audio clip into plain text and SRT subtitles.

    Args:
        audio_input: A path to an audio file (``str``), a
            ``(sample_rate, samples)`` tuple whose samples are a NumPy
            array, or ``None`` when nothing was provided.

    Returns:
        Always a 3-tuple of strings ``(status, plain_text, srt_text)``, one
        value per output component wired to this handler. ``plain_text``
        holds one recognised segment per line. On any failure ``status``
        carries the error message and the other two values are empty.
    """
    global asr_pipeline
    temp_path = None  # set only when a recording is written to a temp file
    try:
        # Validate the input first so a bad request never triggers a model load.
        if audio_input is None:
            return "❌ 請先上傳音訊檔案或進行錄音", "", ""
        if not isinstance(audio_input, (str, tuple)):
            return "❌ 不支援的音訊格式", "", ""

        # Load the model lazily on first use.
        if asr_pipeline is None:
            status = load_model()
            if "失敗" in status:
                return status, "", ""

        if isinstance(audio_input, str):
            audio_path = audio_input
        else:
            # Recording format: (sample_rate, audio_data)
            sample_rate, audio_data = audio_input
            audio_data = np.asarray(audio_data)
            if np.issubdtype(audio_data.dtype, np.integer):
                # Integer PCM: scale by the dtype's full range into [-1, 1).
                full_scale = float(np.iinfo(audio_data.dtype).max) + 1.0
                audio_data = audio_data.astype(np.float32) / full_scale
            else:
                audio_data = audio_data.astype(np.float32)
                # Float samples beyond [-1, 1] are treated as 16-bit PCM
                # magnitudes; check the absolute peak, not just the positive one.
                if audio_data.size and np.abs(audio_data).max() > 1.0:
                    audio_data = audio_data / 32768.0

            samples = torch.tensor(audio_data)
            if samples.ndim == 1:
                samples = samples.unsqueeze(0)
            else:
                # torchaudio.save wants (channels, frames); this assumes
                # multi-channel recordings arrive as (frames, channels).
                samples = samples.t().contiguous()

            # Reserve a file name, close the handle, then let torchaudio write.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                temp_path = tmp_file.name
            torchaudio.save(temp_path, samples, sample_rate)
            audio_path = temp_path

        waveform = preprocess_audio(audio_path)
        # preprocess_audio returns mono 16 kHz samples, so this is the clip
        # length in seconds; used when the model gives no end timestamp.
        duration = waveform.size / 16000

        result = asr_pipeline(waveform, return_timestamps=True)
        transcription = result["text"].strip()

        text_lines = []
        srt_cues = []
        chunks = result.get("chunks")
        if chunks:
            for chunk in chunks:
                text = chunk["text"].strip()
                if not text:  # skip segments with no text
                    continue
                start_time = chunk["timestamp"][0]
                end_time = chunk["timestamp"][1]
                if start_time is None:
                    start_time = 0
                if end_time is None:
                    # An open-ended segment runs to the end of the clip,
                    # never to a point before its own start.
                    end_time = max(duration, start_time)
                text_lines.append(text)
                # Number cues by position in the output so skipped segments
                # do not leave gaps in the SRT sequence.
                srt_cues.append(
                    f"{len(srt_cues) + 1}\n"
                    f"{_format_srt_timestamp(start_time)} --> "
                    f"{_format_srt_timestamp(end_time)}\n{text}"
                )
        elif transcription:
            # No timestamps available: emit one cue spanning the whole clip.
            text_lines.append(transcription)
            srt_cues.append(
                f"1\n00:00:00,000 --> {_format_srt_timestamp(duration)}\n{transcription}"
            )

        pure_text = "\n".join(text_lines).strip()
        srt_text = "\n\n".join(srt_cues).strip()
        return "✅ 辨識完成", pure_text, srt_text
    except Exception as e:
        return f"❌ 辨識過程發生錯誤: {str(e)}", "", ""
    finally:
        # Remove the temp file on every path, including failures.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask
                # the result being returned.
                pass
def clear_all():
    """Reset the audio input and every result box.

    Returns:
        A 4-tuple, one value per component wired to the clear button:
        ``None`` for the audio input, the status message, then empty strings
        for the plain-text and SRT boxes.
    """
    # Four values for four outputs; the fifth value previously returned had
    # no output component to receive it.
    return None, "🔄 已清除所有內容", "", ""
# Build the Gradio interface
with gr.Blocks(title="語音辨識系統", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎤 語音辨識系統 - Breeze ASR 25
### 功能特色:
- 🔧 使用 Breeze ASR 25 模型,專為繁體中文優化
- ⏰ 顯示時間戳記
- 🌐 強化中英混用辨識能力
- 感謝[MediaTek-Research/Breeze-ASR-25](https://huggingface.co/MediaTek-Research/Breeze-ASR-25)
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Audio input area
            gr.Markdown("### 📂 音訊輸入")
            with gr.Tab("檔案上傳"):
                audio_file = gr.Audio(
                    sources=["upload"],
                    label="上傳音訊檔案",
                    type="filepath"
                )
            # Control buttons
            with gr.Row():
                transcribe_btn = gr.Button("🚀 開始辨識", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ 清除", variant="secondary")
        with gr.Column(scale=1):
            # Status display
            status_output = gr.Textbox(
                label="📊 狀態",
                placeholder="等待操作...",
                interactive=False,
                lines=2
            )
            # Plain-text result
            pure_text_output = gr.Textbox(
                label="📄 純文字結果",
                placeholder="純文字結果...",
                lines=4,
                max_lines=10,
                show_copy_button=True
            )
            # SRT subtitle format
            srt_output = gr.Textbox(
                label="🎬 SRT 字幕格式",
                placeholder="SRT 格式字幕...",
                lines=6,
                max_lines=15,
                show_copy_button=True
            )

    # Event wiring
    def transcribe_wrapper(audio_file_val, audio_mic_val=None):
        """Forward whichever audio source is set to transcribe_audio.

        ``audio_mic_val`` defaults to ``None`` because the button below
        passes only the uploaded file; a required second positional
        parameter would make every click fail with a missing argument.
        """
        audio_input = audio_file_val if audio_file_val else audio_mic_val
        return transcribe_audio(audio_input)

    transcribe_btn.click(
        fn=transcribe_wrapper,
        inputs=[audio_file],
        outputs=[status_output, pure_text_output, srt_output]
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[audio_file, status_output, pure_text_output, srt_output]
    )
# Launch the app when run as a script.
# (The trailing "|" after demo.launch() was stray text and a syntax error.)
if __name__ == "__main__":
    demo.launch()