import os
import tempfile
import spaces
import torch
import torchaudio
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
from opencc import OpenCC
import gradio as gr
from pyannote.audio import Pipeline as DiarizationPipeline
from pydub import AudioSegment, effects
# Converter from Simplified to Traditional Chinese
cc = OpenCC("s2t")
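# Illustrative conversion: cc.convert("简体中文") -> "簡體中文" (Simplified -> Traditional).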
# Define available model IDs
MODEL_IDS = {
"3B": "Qwen/Qwen2.5-Omni-3B",
"7B": "Qwen/Qwen2.5-Omni-7B"
}
# Caches for loaded models and processors
_models = {}
_processors = {}
def get_model_and_processor(size: str):
"""
Load and cache the model and processor for the given size ("3B" or "7B").
"""
if size not in _models:
model_id = MODEL_IDS[size]
# Load model with device_map="auto" for ZeroGPU compatibility
m = Qwen2_5OmniForConditionalGeneration.from_pretrained(
model_id,
torch_dtype="auto",
device_map="auto"
)
m.disable_talker()
m.eval()
p = Qwen2_5OmniProcessor.from_pretrained(model_id)
_models[size] = m
_processors[size] = p
return _models[size], _processors[size]
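# Illustrative usage (called lazily from run_asr so weights are not downloaded at import time):
#   model, processor = get_model_and_processor("3B")
# A second call with the same size returns the cached objects instead of reloading them.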
# Cache the diarization pipeline so we only load it once
_diar_pipe = None
def get_diarization_pipe():
global _diar_pipe
if _diar_pipe is None:
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
try:
_diar_pipe = DiarizationPipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=hf_token or True
)
except Exception:
_diar_pipe = DiarizationPipeline.from_pretrained(
"pyannote/speaker-diarization@2.1",
use_auth_token=hf_token or True
)
return _diar_pipe
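# Note: pyannote/speaker-diarization-3.1 is a gated model on the Hugging Face Hub, so the
# Space needs an HF_TOKEN (or HUGGINGFACE_TOKEN) secret from an account that has accepted
# the model's terms; if loading it fails, the older 2.1 pipeline is tried as a fallback above.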
# Format a list of "[SPEAKER_X] text" snippets into colored HTML
def format_diarization_html(snippets):
palette = ["#e74c3c", "#3498db", "#27ae60", "#e67e22", "#9b59b6", "#16a085", "#f1c40f"]
speaker_colors = {}
html_lines = []
last_spk = None
for s in snippets:
if s.startswith("[") and "]" in s:
spk, txt = s[1:].split("]", 1)
spk, txt = spk.strip(), txt.strip()
else:
spk, txt = "", s.strip()
if not txt:
continue
if spk not in speaker_colors:
speaker_colors[spk] = palette[len(speaker_colors) % len(palette)]
color = speaker_colors[spk]
if spk == last_spk:
display = txt
else:
display = f"<strong>{spk}:</strong> {txt}"
last_spk = spk
html_lines.append(
f"<p style='margin:4px 0; font-family:monospace; color:{color};'>{display}</p>"
)
return "<div>" + "".join(html_lines) + "</div>"
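# Illustrative input/output (hypothetical speaker labels):
#   format_diarization_html(["[SPEAKER_00] 你好", "[SPEAKER_01] 大家好"])
# produces one colored <p> per snippet; the bold speaker label is shown only when the
# speaker changes from the previous snippet.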
def _strip_prompts(full_text: str) -> str:
"""
Remove system/user/assistant prefixes so only the actual ASR transcript remains.
"""
marker = "assistant"
if marker in full_text:
return full_text.split(marker, 1)[1].strip()
else:
return full_text.strip()
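# Caveat: this assumes the decoded text contains the literal role marker "assistant"
# that the chat template renders before the generated reply; if the user prompt itself
# contained the word "assistant", the split would happen too early.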
@spaces.GPU
def run_asr(
audio_path: str,
user_prompt: str,
model_size: str
):
# Validate inputs
if not audio_path:
yield format_diarization_html(["⚠️ Please upload an audio file first."])
return
# Load diarization model onto GPU/CPU
diarizer = get_diarization_pipe()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
diarizer.to(device)
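# Note: on ZeroGPU the CUDA device is only attached while a @spaces.GPU-decorated call
# is running, which is why models are moved to `device` here rather than at import time.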
# Load waveform + sample rate and push to device
waveform, sample_rate = torchaudio.load(audio_path)
waveform = waveform.to(device)
# Get appropriate Qwen model & processor based on selection
model, processor = get_model_and_processor(model_size)
model.to(device)
# Run diarization to get speaker turns
diary = diarizer({"waveform": waveform, "sample_rate": sample_rate})
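# `diary` is a pyannote Annotation; itertracks(yield_label=True) yields
# (segment, track_id, speaker_label) tuples in chronological order.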
snippets = []
# For each speaker turn, slice audio, transcribe, convert, accumulate
for turn, _, speaker in diary.itertracks(yield_label=True):
start_ms = int(turn.start * 1000)
end_ms = int(turn.end * 1000)
# Extract the segment, normalize, export to temp file
segment = AudioSegment.from_file(audio_path)[start_ms:end_ms]
segment = effects.normalize(segment)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
segment.export(tmp.name, format="wav")
tmp_path = tmp.name
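# (pydub re-decodes the entire input file for every turn; loading the AudioSegment once
# before the loop would be cheaper for long recordings.)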
# Build messages for this segment
sys_prompt = "You are a speech recognition model."
messages = [
{"role": "system", "content": [{"type": "text", "text": sys_prompt}]},
{
"role": "user",
"content": [
{"type": "audio", "audio": tmp_path},
{"type": "text", "text": user_prompt}
],
},
]
# Apply chat template (no tokenization yet)
text_input = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
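# text_input is the chat-formatted prompt string; add_generation_prompt=True appends the
# assistant turn header that _strip_prompts later uses as its split marker.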
# Preprocess audio (and any images/videos, though here only audio)
audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
# Tokenize & move tensors
inputs = processor(
text=text_input,
audio=audios,
images=images,
videos=videos,
return_tensors="pt",
padding=True,
use_audio_in_video=True
)
inputs = inputs.to(model.device).to(model.dtype)
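# BatchFeature.to() with a dtype only casts floating-point tensors (e.g. audio features);
# integer token ids keep their dtype and are just moved to the device.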
# Generate for this snippet
output_tokens = model.generate(
**inputs,
use_audio_in_video=True,
return_audio=False,
thinker_max_new_tokens=512,
thinker_do_sample=False
)
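# With the talker disabled and return_audio=False, generate() returns only text token ids
# for the thinker, and the sequence still includes the prompt tokens (hence the
# prefix-stripping below).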
# Decode (system+user+assistant)
full_decoded = processor.batch_decode(
output_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0].strip()
# Strip prefixes to isolate ASR transcript
asr_text = _strip_prompts(full_decoded)
# Convert to Traditional Chinese
asr_text = cc.convert(asr_text)
# Append with speaker label
snippets.append(f"[{speaker}] {asr_text}")
# Yield updated HTML so Gradio can stream
yield format_diarization_html(snippets)
# Clean up temp file for this segment
os.unlink(tmp_path)
return
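# Because run_asr yields partial results, Gradio treats it as a generator and streams each
# HTML update to the output component (demo.queue() below enables this).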
# -----------------------------
# Gradio UI
# -----------------------------
DEMO_CSS = """
.diar {
padding: 0.5rem;
color: #f1f1f1;
font-family: monospace;
font-size: 0.9rem;
}
"""
with gr.Blocks(css=DEMO_CSS) as demo:
gr.Markdown("## Qwen2.5-Omni ASR with Speaker Diarization & S2T Conversion (ZeroGPU)")
with gr.Row():
audio_input = gr.Audio(
label="Upload Audio (WAV/MP3/…)",
type="filepath"
)
user_input = gr.Textbox(
label="User Prompt",
value="Transcribe the attached audio to text with punctuation."
)
model_selector = gr.Radio(
choices=["3B", "7B"],
value="7B",
label="Model Size"
)
# Example audio files
example_list = [
["audio/ads.mp3"],
["audio/meeting.mp3"],
["audio/news.mp3"]
]
gr.Examples(
examples=example_list,
inputs=[audio_input],
examples_per_page=3,
label="Try one of these audio files ⤵︎"
)
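# The example paths above assume an audio/ folder committed alongside this script.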
submit_btn = gr.Button("Transcribe")
diarized_output = gr.HTML(
label="Speaker-Diarized Transcript (Traditional Chinese)",
elem_classes=["diar"]
)
submit_btn.click(
fn=run_asr,
inputs=[audio_input, user_input, model_selector],
outputs=diarized_output
)
if __name__ == "__main__":
demo.queue()
demo.launch()