import random

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    pipeline,
)
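
# Note (assumptions about the Space's requirements.txt): device_map="auto"
# needs the `accelerate` package, and Qwen2.5-VL support needs a recent
# transformers release (v4.49+ at the time of writing).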

### Model Setup ###

# Define the two model names.
MODEL1_NAME = "Qwen/Qwen2.5-14B-Instruct-1M"  # We'll refer to this as Qwen14B.
MODEL2_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"   # We'll refer to this as QwenVL.

# Load Qwen14B, a text-only causal LM, behind a standard text-generation
# pipeline. torch_dtype="auto" loads the checkpoint in its native precision
# (bfloat16) rather than float32, halving memory use.
tokenizer1 = AutoTokenizer.from_pretrained(MODEL1_NAME)
model1 = AutoModelForCausalLM.from_pretrained(
    MODEL1_NAME, device_map="auto", torch_dtype="auto"
)
pipe1 = pipeline("text-generation", model=model1, tokenizer=tokenizer1)

# Load QwenVL. Qwen2.5-VL is a vision-language model and is not loadable via
# AutoModelForCausalLM or the "text-generation" pipeline, so we use its
# dedicated class plus an AutoProcessor and call generate() directly in
# generate_response below.
processor2 = AutoProcessor.from_pretrained(MODEL2_NAME)
model2 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL2_NAME, device_map="auto", torch_dtype="auto"
)

def generate_response(prompt: str, model_choice: str) -> str:
    """
    Generate a response to the conversation prompt using the chosen model.
    """
    if model_choice == "Qwen14B":
        result = pipe1(
            prompt,
            max_new_tokens=256,      # cap the completion, not prompt + completion
            do_sample=True,
            top_p=0.95,
            temperature=0.9,
            return_full_text=False,  # return only the completion, not the echoed prompt
        )
        return result[0]["generated_text"].strip()

    # model_choice == "QwenVL": there is no text-generation pipeline for this
    # model, so tokenize with the processor and call generate() directly. The
    # running transcript is passed as a single user turn of the chat template.
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    text = processor2.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor2(text=[text], return_tensors="pt").to(model2.device)
    output_ids = model2.generate(
        **inputs, max_new_tokens=256, do_sample=True, top_p=0.95, temperature=0.9
    )
    # Decode only the newly generated tokens, dropping the prompt tokens.
    completion_ids = output_ids[:, inputs.input_ids.shape[1]:]
    return processor2.batch_decode(completion_ids, skip_special_tokens=True)[0].strip()
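
# Illustrative call (output varies because sampling is enabled; the reply text
# here is made up):
#   generate_response("User: Hello\nQwen14B:", "Qwen14B")
#   -> "Hi there! What would you like to talk about?"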

def chat_logic(user_input: str, chat_history: list):
    """
    Build the conversation prompt from the history, choose which model responds,
    generate the response, and update the conversation.
    """
    if chat_history is None:
        chat_history = []

    # If the user provides input, add it to the history and randomly choose a model.
    if user_input.strip():
        chat_history.append(("User", user_input.strip()))
        selected_model = random.choice(["Qwen14B", "QwenVL"])
    else:
        # When no user input is provided, let the models alternate.
        if not chat_history:
            selected_model = random.choice(["Qwen14B", "QwenVL"])
        else:
            last_speaker = chat_history[-1][0]
            if last_speaker == "Qwen14B":
                selected_model = "QwenVL"
            elif last_speaker == "QwenVL":
                selected_model = "Qwen14B"
            else:
                selected_model = random.choice(["Qwen14B", "QwenVL"])

    # Build the prompt from the conversation history.
    prompt = ""
    for speaker, message in chat_history:
        prompt += f"{speaker}: {message}\n"
    prompt += f"{selected_model}:"

    # Generate the model's response.
    model_response = generate_response(prompt, selected_model)
    chat_history.append((selected_model, model_response))

    # Return the updated conversation for display and to maintain state.
    return chat_history, chat_history
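
# For reference, the prompt built above after a couple of turns looks like
# this (contents are made up for illustration):
#   User: Hello
#   Qwen14B: Hi there!
#   QwenVL:
# The trailing "QwenVL:" cues the selected model to speak next.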

### Gradio Interface ###

with gr.Blocks() as demo:
    gr.Markdown("# Group Chat: Qwen Models")
    gr.Markdown(
        "This demo features two Qwen models conversing with each other. "
        "Leave the textbox blank to let the models converse automatically, "
        "or type a message to interject (a random model will then reply)."
    )

    # Chat display component. Each (speaker, message) tuple in the history is
    # rendered as one left/right bubble pair.
    chatbot = gr.Chatbot()

    # Input row: textbox and send button.
    with gr.Row():
        user_message = gr.Textbox(
            placeholder="Type your message here or leave blank...",
            show_label=False
        )
        send_btn = gr.Button("Send")

    # Maintain the conversation history.
    state = gr.State([])
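    # gr.State holds a per-session copy of the history, so concurrent visitors
    # each get an independent conversation.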

    # On clicking the button, update the conversation.
    send_btn.click(
        fn=chat_logic,
        inputs=[user_message, state],
        outputs=[chatbot, state]
    )
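
    # Convenience binding (an addition, not in the original app): pressing
    # Enter in the textbox triggers the same handler as the Send button.
    user_message.submit(
        fn=chat_logic,
        inputs=[user_message, state],
        outputs=[chatbot, state]
    )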

# Launch the Space.
demo.launch()