#!/usr/bin/env python3
"""
Qwen2.5-Omni Complete Multimodal Demo
A comprehensive Gradio interface for the Qwen2.5-Omni-3B multimodal AI model
Optimized for Apple Silicon (MPS) with efficient memory management
"""
import os
import gc
import sys
import time
import signal
import warnings
from typing import List, Dict, Any, Optional, Tuple, Union
import tempfile
import soundfile as sf

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

import torch
import numpy as np
import gradio as gr
from PIL import Image

# Global variables for model and processor
model = None
processor = None
device = None
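# Note: keeping one model/processor in module-level globals keeps the Gradio callbacks
# simple, but it assumes a single local user; a multi-user deployment would more likely
# hold per-session state (e.g. gr.State) or queue requests.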
def cleanup_resources():
    """Clean up model and free memory"""
    global model, processor
    try:
        if model is not None:
            del model
            model = None
        if processor is not None:
            del processor
            processor = None
        # Force garbage collection
        gc.collect()
        # Clear CUDA/MPS cache if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            torch.mps.empty_cache()
        print("✅ Resources cleaned up successfully")
    except Exception as e:
        print(f"⚠️ Warning during cleanup: {e}")
def signal_handler(signum, frame):
    """Handle interrupt signals gracefully"""
    print("\n🛑 Interrupt received, cleaning up...")
    cleanup_resources()
    sys.exit(0)

# Register signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
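# SIGINT covers Ctrl+C during local runs, while SIGTERM is what container platforms
# (e.g. a stopped Hugging Face Space) typically send on shutdown; both paths release
# the memory held by the model via cleanup_resources().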
def load_model():
    """Load the Qwen2.5-Omni model and processor"""
    global model, processor, device
    if model is not None:
        return "✅ Model already loaded!"
    try:
        # Check device
        if torch.backends.mps.is_available():
            device = torch.device("mps")
            device_info = "🚀 Using Apple Silicon MPS acceleration"
        else:
            device = torch.device("cpu")
            device_info = "⚠️ Using CPU (MPS not available)"
        # Import the specific Qwen2.5-Omni classes
        from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
        # Load processor with optimizations
        processor = Qwen2_5OmniProcessor.from_pretrained(
            "Qwen/Qwen2.5-Omni-3B",
            trust_remote_code=True,
            use_fast=True  # Use the fast tokenizer if available
        )
        # Load model with memory-efficient settings; keep bfloat16 for all functionality
        model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-Omni-3B",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto" if device.type != "mps" else None,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="sdpa"
        )
        # Disable the audio-generation (Talker) head right away: this demo only
        # produces text responses, and dropping the talker frees its weights.
        model.disable_talker()
        print("🤖 Talker module disabled immediately after loading to save memory")
        # Explicitly move to device for MPS while keeping bfloat16
        if device.type == "mps":
            model = model.to(device=device, dtype=torch.bfloat16)
        print("🧠 Model loaded with dtype: bfloat16 (memory efficient)")
        # Clear any cached memory after loading
        gc.collect()
        gc.collect()  # Run twice for good measure
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            torch.mps.empty_cache()
        return f"✅ Model loaded successfully!\n{device_info}\nDevice: {device}"
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"
def text_chat(message, history, system_prompt, temperature, max_tokens):
    """Handle text-only conversations."""
    if model is None or processor is None:
        history.append((message, "❌ Error: Model is not loaded. Please load the model first."))
        return history, ""
    if not message or not message.strip():
        return history, ""
    try:
        conversation = []
        if system_prompt and system_prompt.strip():
            conversation.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
        # Replay the chat history in the format the model expects
        for user_msg, assistant_msg in history:
            if user_msg:
                conversation.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
            if assistant_msg:
                # Avoid adding error messages to the model's context
                if not assistant_msg.startswith("❌ Error"):
                    conversation.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
        conversation.append({"role": "user", "content": [{"type": "text", "text": message}]})
        text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        inputs = processor(text=text, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=processor.tokenizer.eos_token_id
            )
        input_token_len = inputs["input_ids"].shape[1]
        response_ids = generated_ids[:, input_token_len:]
        response = processor.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        history.append((message, response))
        return history, ""
    except Exception as e:
        import traceback
        traceback.print_exc()
        error_message = f"❌ Error in text chat: {str(e)}"
        history.append((message, error_message))
        return history, ""
def multimodal_chat(message, image, audio, history, system_prompt, temperature, max_tokens):
    """
    Handle multimodal conversations (text, image, and audio) using the
    processor.apply_chat_template method described in the official documentation.
    """
    global model, processor, device
    if model is None or processor is None:
        history.append((message, "❌ Error: Model is not loaded. Please load the model first."))
        return history, ""
    if not (message and message.strip()) and image is None and audio is None:
        history.append(("", "Please provide an input (text, image, or audio)."))
        return history, ""
    # --- Create a temporary directory for media files ---
    temp_dir = tempfile.mkdtemp()
    try:
        # --- Build the conversation history in the required format ---
        conversation = []
        if system_prompt and system_prompt.strip():
            conversation.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
        # Process Gradio history into the conversation format.
        # For simplicity, only the text part of the history is replayed and the
        # [Image]/[Audio] placeholder tags are stripped; a more robust solution
        # would reconstruct the full multimodal history (see the sketch below).
        for user_turn, bot_turn in history:
            if user_turn:
                conversation.append({"role": "user", "content": [{"type": "text", "text": user_turn.replace("[Image]", "").replace("[Audio]", "").strip()}]})
            if bot_turn and not bot_turn.startswith("❌ Error"):
                conversation.append({"role": "assistant", "content": [{"type": "text", "text": bot_turn}]})
        # --- Prepare the current user's turn ---
        current_content = []
        user_message_for_history = ""
        # Process text
        if message and message.strip():
            current_content.append({"type": "text", "text": message})
            user_message_for_history += message
        # Process image
        if image is not None:
            # --- FIX: Resize large images to prevent OOM errors ---
            MAX_PIXELS = 1024 * 1024  # 1 megapixel
            if image.width * image.height > MAX_PIXELS:
                image.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
            temp_image_path = os.path.join(temp_dir, "temp_image.png")
            image.save(temp_image_path)
            current_content.append({"type": "image", "image": temp_image_path})
            user_message_for_history += " [Image]"
        # Process audio
        if audio is not None:
            sample_rate, audio_data = audio
            temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
            sf.write(temp_audio_path, audio_data, sample_rate)
            current_content.append({"type": "audio", "audio": temp_audio_path})
            user_message_for_history += " [Audio]"
        if not current_content:
            history.append(("", "Please provide some input."))
            return history, ""
        conversation.append({"role": "user", "content": current_content})
        # --- Use `apply_chat_template` as per the documentation ---
        # This single call tokenizes the text and prepares image/audio features together.
        inputs = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
        ).to(device)
        # --- Generation ---
        with torch.no_grad():
            # With the talker disabled, generate() should return text token IDs only.
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=processor.tokenizer.eos_token_id,
                # return_audio=False  # May be needed if audio output is enabled by default
            )
        # The generate call for the full Omni model may return a tuple
        # (text_ids, audio_wav); handle both cases to be safe.
        if isinstance(generated_ids, tuple):
            response_ids = generated_ids[0]
        else:
            response_ids = generated_ids
        input_token_len = inputs["input_ids"].shape[1]
        response_ids_decoded = response_ids[:, input_token_len:]
        response = processor.batch_decode(response_ids_decoded, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        history.append((user_message_for_history.strip(), response))
        return history, ""
    except Exception as e:
        import traceback
        error_message = f"❌ Multimodal chat error: {traceback.format_exc()}"
        print(error_message)  # Print the full traceback to the console for debugging
        history.append((message, f"❌ Error: {e}"))
        return history, ""
    finally:
        # --- Clean up temporary files ---
        if os.path.exists(temp_dir):
            import shutil
            shutil.rmtree(temp_dir)
def clear_history():
    """Clear chat history"""
    return []

def clear_model_cache():
    """Clear model cache and free memory"""
    global model, processor
    try:
        cleanup_resources()
        # Clear additional caches
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            torch.mps.empty_cache()
        return "✅ Cache cleared successfully! Click 'Load Model' to reload."
    except Exception as e:
        return f"❌ Error clearing cache: {str(e)}"
def create_interface():
    """Create the complete Gradio interface."""
    with gr.Blocks(title="Qwen2.5-Omni Multimodal Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🤖 Qwen2.5-Omni Complete Multimodal Demo
        A comprehensive Gradio interface for the Qwen2.5-Omni-3B model.
        """)
        with gr.Row():
            with gr.Column(scale=2):
                load_btn = gr.Button("🚀 Load Model", variant="primary")
            with gr.Column(scale=2):
                cache_clear_btn = gr.Button("🧹 Clear Cache", variant="secondary")
            with gr.Column(scale=3):
                model_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
        load_btn.click(load_model, outputs=model_status)
        cache_clear_btn.click(clear_model_cache, outputs=model_status)
        with gr.Tabs():
            with gr.Tab("💬 Text Chat"):
                text_chatbot = gr.Chatbot(label="Conversation", height=450)
                with gr.Row():
                    text_msg = gr.Textbox(label="Your message", placeholder="Type your message...", scale=4, container=False)
                    text_send = gr.Button("Send", variant="primary", scale=1)
                with gr.Row():
                    text_clear = gr.Button("Clear History")
                with gr.Accordion("Settings", open=False):
                    text_system = gr.Textbox(label="System Prompt", value="You are a helpful AI assistant.")
                    text_temp = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
                    text_max_tokens = gr.Slider(50, 1000, value=500, label="Max New Tokens", step=50)
                text_send.click(text_chat, inputs=[text_msg, text_chatbot, text_system, text_temp, text_max_tokens], outputs=[text_chatbot, text_msg])
                text_msg.submit(text_chat, inputs=[text_msg, text_chatbot, text_system, text_temp, text_max_tokens], outputs=[text_chatbot, text_msg])
                text_clear.click(clear_history, outputs=text_chatbot)
            with gr.Tab("🌐 Multimodal Chat"):
                multi_chatbot = gr.Chatbot(label="Multimodal Conversation", height=450)
                multi_text = gr.Textbox(label="Text Message (optional)", placeholder="Describe what you want to know...", scale=4, container=False)
                with gr.Row():
                    multi_image = gr.Image(label="Upload Image (optional)", type="pil")
                    multi_audio = gr.Audio(label="Upload Audio (optional)", type="numpy")
                with gr.Row():
                    multi_send = gr.Button("Send Multimodal Input", variant="primary")
                    multi_clear = gr.Button("Clear History")
                with gr.Accordion("Settings", open=False):
                    multi_system = gr.Textbox(label="System Prompt", value="You are Qwen, capable of understanding images, audio, and text.")
                    multi_temp = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
                    multi_max_tokens = gr.Slider(50, 1000, value=500, label="Max New Tokens", step=50)
                multi_send.click(multimodal_chat, inputs=[multi_text, multi_image, multi_audio, multi_chatbot, multi_system, multi_temp, multi_max_tokens], outputs=[multi_chatbot, multi_text])
                multi_clear.click(clear_history, outputs=multi_chatbot)
            with gr.Tab("ℹ️ Model Info"):
                # Placeholder for model info content
                gr.Markdown("Model information will be displayed here.")
    return demo
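# Design note (assumption: single local user): the shared global model is not safe for
# concurrent generate() calls, so a multi-user deployment would likely want to call
# demo.queue() before launch() to serialize requests.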
if __name__ == "__main__":
    try:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["OMP_NUM_THREADS"] = "1"
        demo = create_interface()
        print("🚀 Starting Qwen2.5-Omni Gradio Demo...")
        print("📝 Memory management optimizations enabled")
        print("🌐 Access the interface at: http://localhost:7860")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            quiet=False
        )
    except KeyboardInterrupt:
        print("\n🛑 Shutting down gracefully...")
        cleanup_resources()
    except Exception as e:
        print(f"❌ Error starting demo: {e}")
        cleanup_resources()