import gradio as gr
import torch
import transformers
import os
from PIL import Image
import spaces
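# Local, simplified stand-in for the qwen_vl_utils.process_vision_info helper:
# it walks the chat messages and collects the raw image/video inputs so they
# can be handed to the processor alongside the templated text.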
def process_vision_info(messages):
    """Collect the image and video entries from a list of chat messages.

    Returns two lists, (image_inputs, video_inputs). Only user turns are
    scanned, matching the message structure built in the VL handler below.
    """
    image_inputs = []
    video_inputs = []
    for message in messages:
        if message["role"] == "user":
            content = message["content"]
            for item in content:
                if item["type"] == "image":
                    image_inputs.append(item["image"])
                elif item["type"] == "video":
                    video_inputs.append(item["video"])
    return image_inputs, video_inputs
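# Both models are loaded eagerly at startup. Failures are recorded instead of
# raised so the UI can still launch and surface the error message to the user.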
print("Loading text model (Qwen/Qwen2.5-7B)...")
text_model_loaded = False
text_model_error = ""
try:
text_model = transformers.AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2.5-7B",
torch_dtype=torch.bfloat16,
device_map="auto"
)
text_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
text_model_loaded = True
print("Text model loaded successfully.")
except Exception as e:
text_model_error = str(e)
print(f"Error loading text model: {text_model_error}")
text_model, text_tokenizer = None, None
print("Loading Vision-Language model (Qwen/Qwen2.5-VL-7B-Instruct)...")
vl_model_loaded = False
vl_model_error = ""
try:
vl_model = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-7B-Instruct",
torch_dtype=torch.bfloat16,
device_map="auto"
)
vl_processor = transformers.AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
vl_model_loaded = True
print("Vision-Language model loaded successfully.")
except Exception as e:
vl_model_error = str(e)
print(f"Error loading Vision-Language model: {vl_model_error}")
vl_model, vl_processor = None, None
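# On Hugging Face Spaces, @spaces.GPU requests a ZeroGPU slot for the duration
# of each call; outside of Spaces the decorator should effectively be a no-op.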
@spaces.GPU
def visualize_text_token_probabilities(text: str):
    if not text_model_loaded:
        return [(f"Text Model failed to load: {text_model_error}", None)]
    if not text or not text.strip():
        return [("Please enter some text to analyze.", None)]
    try:
        inputs = text_tokenizer([text], return_tensors="pt").to(text_model.device)
        input_ids = inputs.input_ids
        if input_ids.shape[1] < 2:
            # A single token has no preceding context, so there is nothing to score.
            token = text_tokenizer.decode(input_ids[0])
            return [(token, None)]
        # Shift by one position: the logits at position i predict token i+1.
        inp = input_ids[:, :-1]
        outp = input_ids[:, 1:].unsqueeze(-1)
        with torch.no_grad():
            logits = text_model(inp).logits.float()
        all_probs = torch.softmax(logits, dim=-1)
        # Gather the probability the model assigned to each actual next token.
        chosen_probs = torch.gather(all_probs, dim=2, index=outp).squeeze(-1).cpu().numpy()[0]
        highlighted_data = []
        outp_tokens = input_ids[0, 1:].cpu().tolist()
        # The first token has no conditional probability; render it uncolored.
        first_token_str = text_tokenizer.decode([input_ids[0, 0].item()])
        highlighted_data.append((first_token_str, None))
        for token_id, prob in zip(outp_tokens, chosen_probs):
            token_str = text_tokenizer.decode([token_id])
            highlighted_data.append((token_str, float(prob)))
        return highlighted_data
    except Exception as e:
        print(f"An error occurred during text processing: {e}")
        return [(f"An error occurred: {str(e)}", None)]
@spaces.GPU
def generate_and_visualize_vl_probabilities(image, prompt: str):
    if not vl_model_loaded:
        return [(f"Vision-Language Model failed to load: {vl_model_error}", None)]
    if image is None or not prompt or not prompt.strip():
        return [("Please upload an image and provide a text prompt.", None)]
    try:
        messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt.strip()}]}]
        text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = vl_processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(vl_model.device)
        # Pass 1: sample a completion.
        with torch.no_grad():
            generated_ids = vl_model.generate(**inputs, max_new_tokens=512)
        input_token_len = inputs.input_ids.shape[1]
        if generated_ids.shape[1] <= input_token_len:
            return [("Model did not generate any new tokens.", None)]
        # Extend the attention mask to cover the newly generated tokens.
        original_mask = inputs.attention_mask
        num_generated_tokens = generated_ids.shape[1] - input_token_len
        generated_mask = torch.ones(
            (1, num_generated_tokens),
            dtype=original_mask.dtype,
            device=original_mask.device
        )
        full_attention_mask = torch.cat([original_mask, generated_mask], dim=1)
        # Pass 2: re-run the full sequence to recover per-token logits.
        with torch.no_grad():
            outputs = vl_model(
                input_ids=generated_ids,
                pixel_values=inputs.get('pixel_values'),
                image_grid_thw=inputs.get('image_grid_thw'),
                attention_mask=full_attention_mask
            )
        logits = outputs.logits.float()
        # Logits at position i predict token i+1, so the slice starts one step early.
        logits_of_generated_part = logits[:, input_token_len - 1:-1, :]
        labels_of_generated_part = generated_ids[:, input_token_len:]
        all_probs = torch.softmax(logits_of_generated_part, dim=-1)
        chosen_probs = torch.gather(all_probs, 2, labels_of_generated_part.unsqueeze(-1)).squeeze(-1)
        generated_token_ids_only = generated_ids[0, input_token_len:]
        probs_list = chosen_probs[0].cpu().tolist()
        highlighted_data = []
        for token_id, prob in zip(generated_token_ids_only.tolist(), probs_list):
            token_str = vl_processor.decode([token_id])
            highlighted_data.append((token_str, float(prob)))
        if not highlighted_data:
            return [("Model did not generate any new tokens.", None)]
        return highlighted_data
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"An error occurred during VL processing: {e}")
        return [(f"An error occurred: {str(e)}", None)]
text_en_example = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer>
tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>. User: What is 7 * 6? Assistant: First, the user asked: "what is 7 * 6?" That's a multiplication problem. I need to calculate the product of 7 and 6.
I know my multiplication tables. 7 times 6 is 42. I can double-check: 7 × 6 means adding 7 six times: 7 + 7 + 7 + 7 + 7 + 7. Let's add that up: 7+7=14, 14+7=21, 21+7=28, 28+7=35, 35+7=42. Yes, that's 42.
I think that's fine. 7 multiplied by 6 equals **42**.
If you have any more math questions or need an explanation, feel free to ask! 😊 """
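# Build the Gradio interface: one tab per model.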
with gr.Blocks(theme=gr.themes.Soft(), title="Qwen2.5 Token Visualizer") as demo:
    gr.Markdown(
        """
        # Qwen2.5 Series Token Probability Visualizer
        This tool visualizes token probabilities for both text and vision-language models from the Qwen2.5 series.
        The color of each token represents its conditional probability.
        **Red means high probability** (the model was confident), and **White means low probability** (the model was surprised).
        """
    )
    with gr.Tabs():
        with gr.TabItem("Text Model (Qwen2.5-7B)"):
            gr.Markdown("### Analyze Probabilities of Given Text")
            with gr.Row():
                text_input = gr.Textbox(
                    label="Input Text", lines=15, value=text_en_example,
                    placeholder="Enter text here to analyze..."
                )
            with gr.Row():
                text_submit_btn = gr.Button("Visualize Probabilities", variant="primary")
            text_output_highlight = gr.HighlightedText(
                label="Token Probabilities (High: Red, Low: White)", show_legend=True,
                combine_adjacent=False,
            )
            gr.Examples(
                examples=[[text_en_example]], inputs=text_input, outputs=text_output_highlight,
                fn=visualize_text_token_probabilities, cache_examples=False
            )
            text_submit_btn.click(
                fn=visualize_text_token_probabilities, inputs=text_input, outputs=text_output_highlight,
                api_name="visualize_text"
            )
        with gr.TabItem("Vision-Language Model (Qwen2.5-VL-7B-Instruct)"):
            gr.Markdown("### Generate Text from Image and Visualize Probabilities")
            with gr.Row():
                with gr.Column():
                    vl_image_input = gr.Image(type="pil", label="Upload Image")
                    vl_text_input = gr.Textbox(label="Your Question", placeholder="e.g., Describe this image.")
                    vl_submit_btn = gr.Button("Generate and Visualize", variant="primary")
                with gr.Column():
                    vl_output_highlight = gr.HighlightedText(
                        label="Generated Token Probabilities (High: Red, Low: White)", show_legend=True,
                        combine_adjacent=False,
                    )
            gr.Examples(
                examples=[["demo.jpeg", "Describe this image in detail."]],
                inputs=[vl_image_input, vl_text_input],
                outputs=vl_output_highlight,
                fn=generate_and_visualize_vl_probabilities,
                cache_examples=False
            )
            vl_submit_btn.click(
                fn=generate_and_visualize_vl_probabilities, inputs=[vl_image_input, vl_text_input],
                outputs=vl_output_highlight, api_name="visualize_vl_generation"
            )
if __name__ == "__main__":
    # Create a placeholder image for the VL example if none is present.
    if not os.path.exists("demo.jpeg"):
        try:
            from PIL import ImageDraw, ImageFont  # Image is already imported above
            img = Image.new('RGB', (400, 200), color=(73, 109, 137))
            d = ImageDraw.Draw(img)
            try:
                font = ImageFont.truetype("arial.ttf", 20)
            except IOError:
                font = ImageFont.load_default()
            d.text((10, 10), "This is a demo image for Gradio.", font=font, fill=(255, 255, 0))
            img.save("demo.jpeg")
            print("Created a dummy 'demo.jpeg' for the example.")
        except Exception as e:
            print(f"Could not create a dummy image: {e}")
    demo.queue().launch(share=True)