import os

# Point the Hugging Face cache at persistent storage. These variables must be set
# before transformers (and gradio, which pulls in huggingface_hub) is imported,
# otherwise the default cache location is used. TRANSFORMERS_CACHE is deprecated
# in recent transformers releases; HF_HOME covers the same purpose.
os.environ['HF_HOME'] = '/data/hf_home'
os.environ['TRANSFORMERS_CACHE'] = '/data/transformers_cache'

import gradio as gr
import torch
import transformers
from PIL import Image
def process_vision_info(messages):
    """Collect image and video entries from user messages.

    A lightweight stand-in for the helper of the same name in the qwen_vl_utils
    package: it only extracts already-loaded images / video references from the
    chat messages, without any resizing or URL handling.
    """
    image_inputs = []
    video_inputs = []
    for message in messages:
        if message["role"] == "user":
            content = message["content"]
            for item in content:
                if item["type"] == "image":
                    image_inputs.append(item["image"])
                elif item["type"] == "video":
                    video_inputs.append(item["video"])
    return image_inputs, video_inputs
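
# Illustrative message structure this helper expects (it mirrors the payload built in
# generate_and_visualize_vl_probabilities below; the values here are examples only):
# messages = [{
#     "role": "user",
#     "content": [
#         {"type": "image", "image": pil_image},
#         {"type": "text", "text": "Describe this image."},
#     ],
# }]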
print("Loading text model (Qwen/Qwen2.5-7B)...") | |
text_model_loaded = False | |
text_model_error = "" | |
try: | |
text_model = transformers.AutoModelForCausalLM.from_pretrained( | |
"Qwen/Qwen2.5-7B", | |
torch_dtype=torch.bfloat16, | |
device_map="auto" | |
) | |
text_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B") | |
text_model_loaded = True | |
print("Text model loaded successfully.") | |
except Exception as e: | |
text_model_error = str(e) | |
print(f"Error loading text model: {text_model_error}") | |
text_model, text_tokenizer = None, None | |
print("Loading Vision-Language model (Qwen/Qwen2.5-VL-7B-Instruct)...") | |
vl_model_loaded = False | |
vl_model_error = "" | |
try: | |
vl_model = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained( | |
"Qwen/Qwen2.5-VL-7B-Instruct", | |
torch_dtype=torch.bfloat16, | |
device_map="auto" | |
) | |
vl_processor = transformers.AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") | |
vl_model_loaded = True | |
print("Vision-Language model loaded successfully.") | |
except Exception as e: | |
vl_model_error = str(e) | |
print(f"Error loading Vision-Language model: {vl_model_error}") | |
vl_model, vl_processor = None, None | |
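
# Rough sizing note: each 7B-parameter model in bfloat16 needs on the order of 15 GB
# just for weights, so device_map="auto" may shard layers across GPU and CPU when
# memory is tight (an estimate, not a measurement from this Space).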
def visualize_text_token_probabilities(text: str):
    if not text_model_loaded:
        return [(f"Text Model failed to load: {text_model_error}", None)]
    if not text or not text.strip():
        return [("Please enter some text to analyze.", None)]
    try:
        inputs = text_tokenizer([text], return_tensors="pt").to(text_model.device)
        input_ids = inputs.input_ids
        if input_ids.shape[1] < 2:
            token = text_tokenizer.decode(input_ids[0])
            return [(token, None)]
        # Shift by one: the logits at position i predict the token at position i + 1,
        # so tokens [0 .. n-2] are the inputs and tokens [1 .. n-1] are the targets.
        inp = input_ids[:, :-1]
        outp = input_ids[:, 1:].unsqueeze(-1)
        with torch.no_grad():
            logits = text_model(inp).logits.float()
        all_probs = torch.softmax(logits, dim=-1)
        # Pick out the probability the model assigned to each actual next token.
        chosen_probs = torch.gather(all_probs, dim=2, index=outp).squeeze(-1).cpu().numpy()[0]
        highlighted_data = []
        outp_tokens = input_ids[0, 1:].cpu().tolist()
        # The first token has no preceding context, so it gets no probability.
        first_token_str = text_tokenizer.decode([input_ids[0, 0].item()])
        highlighted_data.append((first_token_str, None))
        for token_id, prob in zip(outp_tokens, chosen_probs):
            token_str = text_tokenizer.decode([token_id])
            highlighted_data.append((token_str, float(prob)))
        return highlighted_data
    except Exception as e:
        print(f"An error occurred during text processing: {e}")
        return [(f"An error occurred: {str(e)}", None)]
def generate_and_visualize_vl_probabilities(image, prompt: str):
    if not vl_model_loaded:
        return [(f"Vision-Language Model failed to load: {vl_model_error}", None)]
    if image is None or not prompt or not prompt.strip():
        return [("Please upload an image and provide a text prompt.", None)]
    try:
        messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt.strip()}]}]
        text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = vl_processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(vl_model.device)
        # First pass: sample a completion for the image + prompt.
        with torch.no_grad():
            generated_ids = vl_model.generate(**inputs, max_new_tokens=512)
        input_token_len = inputs.input_ids.shape[1]
        if generated_ids.shape[1] <= input_token_len:
            return [("Model did not generate any new tokens.", None)]
        # Extend the attention mask to cover the newly generated tokens.
        original_mask = inputs.attention_mask
        num_generated_tokens = generated_ids.shape[1] - input_token_len
        generated_mask = torch.ones(
            (1, num_generated_tokens),
            dtype=original_mask.dtype,
            device=original_mask.device
        )
        full_attention_mask = torch.cat([original_mask, generated_mask], dim=1)
        # Second pass: re-run the full sequence (prompt + generation) to recover the
        # logits that produced each generated token.
        with torch.no_grad():
            outputs = vl_model(
                input_ids=generated_ids,
                pixel_values=inputs.get('pixel_values'),
                image_grid_thw=inputs.get('image_grid_thw'),
                attention_mask=full_attention_mask
            )
        logits = outputs.logits.float()
        # Logits at position i predict the token at position i + 1, so this slice lines
        # up the prompt's last position with the first generated token.
        logits_of_generated_part = logits[:, input_token_len - 1:-1, :]
        labels_of_generated_part = generated_ids[:, input_token_len:]
        all_probs = torch.softmax(logits_of_generated_part, dim=-1)
        chosen_probs = torch.gather(all_probs, 2, labels_of_generated_part.unsqueeze(-1)).squeeze(-1)
        generated_token_ids_only = generated_ids[0, input_token_len:]
        probs_list = chosen_probs[0].cpu().tolist()
        highlighted_data = []
        for token_id, prob in zip(generated_token_ids_only.tolist(), probs_list):
            token_str = vl_processor.decode([token_id])
            highlighted_data.append((token_str, float(prob)))
        if not highlighted_data:
            return [("Model did not generate any new tokens.", None)]
        return highlighted_data
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"An error occurred during VL processing: {e}")
        return [(f"An error occurred: {str(e)}", None)]
text_en_example = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <think> </think> and
<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>. User: What is 7 * 6? Assistant: <think> First, the user asked: "what is 7 * 6?" That's a multiplication problem. I need to calculate the product of 7 and 6.
I know my multiplication tables. 7 times 6 is 42. I can double-check: 7 × 6 means adding 7 six times: 7 + 7 + 7 + 7 + 7 + 7. Let's add that up: 7+7=14, 14+7=21, 21+7=28, 28+7=35, 35+7=42. Yes, that's 42.
I think that's fine. </think> <answer> 7 multiplied by 6 equals **42**.
If you have any more math questions or need an explanation, feel free to ask! 😊 </answer>"""
with gr.Blocks(theme=gr.themes.Soft(), title="Qwen2.5 Token Visualizer") as demo:
    gr.Markdown(
        """
# Qwen2.5 Series Token Probability Visualizer
This tool visualizes token probabilities for both text and vision-language models from the Qwen2.5 series.
The color of each token represents its conditional probability.
**<span style="color:red">Red</span> means high probability** (the model was confident), and **<span style="color:black">White</span> means low probability** (the model was surprised).
"""
    )
    with gr.Tabs():
        with gr.TabItem("Text Model (Qwen2.5-7B)"):
            gr.Markdown("### Analyze Probabilities of Given Text")
            with gr.Row():
                text_input = gr.Textbox(
                    label="Input Text", lines=15, value=text_en_example,
                    placeholder="Enter text here to analyze..."
                )
            with gr.Row():
                text_submit_btn = gr.Button("Visualize Probabilities", variant="primary")
            text_output_highlight = gr.HighlightedText(
                label="Token Probabilities (High: Red, Low: White)", show_legend=True,
                combine_adjacent=False,
            )
            gr.Examples(
                examples=[[text_en_example]], inputs=text_input, outputs=text_output_highlight,
                fn=visualize_text_token_probabilities, cache_examples=False
            )
            text_submit_btn.click(
                fn=visualize_text_token_probabilities, inputs=text_input, outputs=text_output_highlight,
                api_name="visualize_text"
            )
with gr.TabItem("Vision-Language Model (Qwen2.5-VL-7B-Instruct)"): | |
gr.Markdown("### Generate Text from Image and Visualize Probabilities") | |
with gr.Row(): | |
with gr.Column(): | |
vl_image_input = gr.Image(type="pil", label="Upload Image") | |
vl_text_input = gr.Textbox(label="Your Question", placeholder="e.g., Describe this image.") | |
vl_submit_btn = gr.Button("Generate and Visualize", variant="primary") | |
with gr.Column(): | |
vl_output_highlight = gr.HighlightedText( | |
label="Generated Token Probabilities (High: Red, Low: White)", show_legend=True, | |
combine_adjacent=False, | |
) | |
gr.Examples( | |
examples=[["demo.jpeg", "Describe this image in detail."]], | |
inputs=[vl_image_input, vl_text_input], | |
outputs=vl_output_highlight, | |
fn=generate_and_visualize_vl_probabilities, | |
cache_examples=False | |
) | |
vl_submit_btn.click( | |
fn=generate_and_visualize_vl_probabilities, inputs=[vl_image_input, vl_text_input], | |
outputs=vl_output_highlight, api_name="visualize_vl_generation" | |
) | |
if __name__ == "__main__":
    # Create a placeholder example image if one is not already present.
    if not os.path.exists("demo.jpeg"):
        try:
            from PIL import Image, ImageDraw, ImageFont
            img = Image.new('RGB', (400, 200), color=(73, 109, 137))
            d = ImageDraw.Draw(img)
            try:
                font = ImageFont.truetype("arial.ttf", 20)
            except IOError:
                font = ImageFont.load_default()
            d.text((10, 10), "This is a demo image for Gradio.", font=font, fill=(255, 255, 0))
            img.save("demo.jpeg")
            print("Created a dummy 'demo.jpeg' for the example.")
        except Exception as e:
            print(f"Could not create a dummy image: {e}")
    demo.queue().launch(share=True)
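
# Minimal sketch of a local run (assumptions: this file is saved as app.py and the
# environment has gradio, torch, Pillow, accelerate, and a transformers release that
# includes Qwen2_5_VLForConditionalGeneration installed):
#   python app.py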