import os
import tempfile
import spaces
import torch
import torchaudio
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
from opencc import OpenCC
import gradio as gr
from pyannote.audio import Pipeline as DiarizationPipeline
from pydub import AudioSegment, effects
# Converter from Simplified to Traditional Chinese
cc = OpenCC("s2t")
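# e.g. cc.convert("简体中文") -> "簡體中文"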
# Define available model IDs
MODEL_IDS = {
    "3B": "Qwen/Qwen2.5-Omni-3B",
    "7B": "Qwen/Qwen2.5-Omni-7B"
}

# Caches for loaded models and processors
_models = {}
_processors = {}

def get_model_and_processor(size: str):
    """
    Load and cache the model and processor for the given size ("3B" or "7B").
    """
    if size not in _models:
        model_id = MODEL_IDS[size]
        # Load model with device_map="auto" for ZeroGPU compatibility
        m = Qwen2_5OmniForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto"
        )
        # Text-only use: drop the speech-synthesis (talker) module to save memory
        m.disable_talker()
        m.eval()
        p = Qwen2_5OmniProcessor.from_pretrained(model_id)
        _models[size] = m
        _processors[size] = p
    return _models[size], _processors[size]
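
# Optional (not wired in): warm the cache at import time so the first request
# does not pay the model download/load cost inside the GPU window, e.g.
#   get_model_and_processor("7B")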

# Cache the diarization pipeline so we only load it once
_diar_pipe = None

def get_diarization_pipe():
    global _diar_pipe
    if _diar_pipe is None:
        hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
        try:
            _diar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token or True
            )
        except Exception:
            # Fall back to the older 2.1 pipeline if 3.1 cannot be loaded
            _diar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization@2.1",
                use_auth_token=hf_token or True
            )
    return _diar_pipe
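
# Note: the pyannote diarization models are gated on the Hugging Face Hub;
# accept their terms and set HF_TOKEN (or HUGGINGFACE_TOKEN) as a Space secret,
# otherwise the download above will fail.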

# Format a list of "[SPEAKER_X] text" snippets into colored HTML
def format_diarization_html(snippets):
    palette = ["#e74c3c", "#3498db", "#27ae60", "#e67e22", "#9b59b6", "#16a085", "#f1c40f"]
    speaker_colors = {}
    html_lines = []
    last_spk = None
    for s in snippets:
        if s.startswith("[") and "]" in s:
            spk, txt = s[1:].split("]", 1)
            spk, txt = spk.strip(), txt.strip()
        else:
            spk, txt = "", s.strip()
        if not txt:
            continue
        if spk not in speaker_colors:
            speaker_colors[spk] = palette[len(speaker_colors) % len(palette)]
        color = speaker_colors[spk]
        # Repeat the bold speaker label only when the speaker changes
        if spk == last_spk:
            display = txt
        else:
            display = f"<strong>{spk}:</strong> {txt}"
        last_spk = spk
        html_lines.append(
            f"<p style='margin:4px 0; font-family:monospace; color:{color};'>{display}</p>"
        )
    return "<div>" + "".join(html_lines) + "</div>"

def _strip_prompts(full_text: str) -> str:
    """
    Remove system/user/assistant prefixes so only the actual ASR transcript remains.
    """
    marker = "assistant"
    if marker in full_text:
        return full_text.split(marker, 1)[1].strip()
    else:
        return full_text.strip()
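
# Assumption: the decoded chat-template output contains a single "assistant"
# role marker before the transcript, and neither prompt contains that word;
# see the token-slicing alternative sketched after run_asr for a safer option.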

# ZeroGPU: allocate a GPU for the duration of each call (this is why `spaces` is imported)
@spaces.GPU
def run_asr(
    audio_path: str,
    user_prompt: str,
    model_size: str
):
    # Validate inputs
    if not audio_path:
        yield format_diarization_html(["⚠️ Please upload an audio file first."])
        return

    # Load diarization model onto GPU/CPU
    diarizer = get_diarization_pipe()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    diarizer.to(device)

    # Load waveform + sample rate and push to device
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = waveform.to(device)

    # Get the selected Qwen model & processor
    model, processor = get_model_and_processor(model_size)
    model.to(device)

    # Run diarization to get speaker turns
    diary = diarizer({"waveform": waveform, "sample_rate": sample_rate})
    snippets = []

    # For each speaker turn: slice audio, transcribe, convert, accumulate
    for turn, _, speaker in diary.itertracks(yield_label=True):
        start_ms = int(turn.start * 1000)
        end_ms = int(turn.end * 1000)

        # Extract the segment, normalize loudness, export to a temp WAV file
        segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
        segment = effects.normalize(segment)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            segment.export(tmp.name, format="wav")
            tmp_path = tmp.name

        # Build the chat messages for this segment
        sys_prompt = "You are a speech recognition model."
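        # Note (assumption): Qwen2.5-Omni's default system prompt is only needed
        # when audio output is requested; the talker is disabled above, so a plain
        # ASR-style system prompt is fine here.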
        messages = [
            {"role": "system", "content": [{"type": "text", "text": sys_prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": tmp_path},
                    {"type": "text", "text": user_prompt}
                ],
            },
        ]

        # Apply chat template (no tokenization yet)
        text_input = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Preprocess audio (and any images/videos, though here only audio)
        audios, images, videos = process_mm_info(messages, use_audio_in_video=True)

        # Tokenize & move tensors to the model's device/dtype
        inputs = processor(
            text=text_input,
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=True
        )
        inputs = inputs.to(model.device).to(model.dtype)

        # Generate for this snippet
        output_tokens = model.generate(
            **inputs,
            use_audio_in_video=True,
            return_audio=False,
            thinker_max_new_tokens=512,
            thinker_do_sample=False
        )

        # Decode the full conversation (system + user + assistant)
        full_decoded = processor.batch_decode(
            output_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0].strip()

        # Strip prefixes to isolate the ASR transcript
        asr_text = _strip_prompts(full_decoded)

        # Convert to Traditional Chinese
        asr_text = cc.convert(asr_text)

        # Append with speaker label
        snippets.append(f"[{speaker}] {asr_text}")

        # Yield updated HTML so Gradio can stream partial results
        yield format_diarization_html(snippets)

        # Clean up the temp file for this segment
        os.unlink(tmp_path)
    return
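
# A more robust alternative to _strip_prompts (sketch, not wired in): decode only
# the newly generated tokens instead of splitting on the "assistant" marker, e.g.
#   prompt_len = inputs["input_ids"].shape[1]
#   asr_text = processor.batch_decode(
#       output_tokens[:, prompt_len:], skip_special_tokens=True
#   )[0].strip()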

# -----------------------------
# Gradio UI
# -----------------------------
DEMO_CSS = """
.diar {
    padding: 0.5rem;
    color: #f1f1f1;
    font-family: monospace;
    font-size: 0.9rem;
}
"""

with gr.Blocks(css=DEMO_CSS) as demo:
    gr.Markdown("## Qwen2.5-Omni ASR with Speaker Diarization & S2T Conversion (ZeroGPU)")

    with gr.Row():
        audio_input = gr.Audio(
            label="Upload Audio (WAV/MP3/…)",
            type="filepath"
        )
        user_input = gr.Textbox(
            label="User Prompt",
            value="Transcribe the attached audio to text with punctuation."
        )
        model_selector = gr.Radio(
            choices=["3B", "7B"],
            value="7B",
            label="Model Size"
        )

    # Example audio files
    example_list = [
        ["audio/ads.mp3"],
        ["audio/meeting.mp3"],
        ["audio/news.mp3"]
    ]
    gr.Examples(
        examples=example_list,
        inputs=[audio_input],
        examples_per_page=3,
        label="Try one of these audio files ⤵︎"
    )

    submit_btn = gr.Button("Transcribe")
    diarized_output = gr.HTML(
        label="Speaker-Diarized Transcript (Traditional Chinese)",
        elem_classes=["diar"]
    )

    submit_btn.click(
        fn=run_asr,
        inputs=[audio_input, user_input, model_selector],
        outputs=diarized_output
    )
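    # run_asr is a generator, so each yielded HTML update streams to the output
    # component while transcription is still in progress.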

if __name__ == "__main__":
    demo.queue()
    demo.launch()
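
# Assumed Space dependencies (requirements.txt, package names illustrative):
#   torch, torchaudio, transformers, gradio, spaces, pyannote.audio, pydub,
#   opencc (or opencc-python-reimplemented), qwen-omni-utils
# pydub also needs the ffmpeg system package to decode MP3 input.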