MaxiiMin committed
Commit 2142b71 · verified · 1 Parent(s): 10e57d7

Update app.py

Files changed (1)
  1. app.py +232 -37
app.py CHANGED
@@ -1,44 +1,239 @@
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- model_name = "Qwen/Qwen2.5-0.5B"
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- def highlight_probabilities(text):
-     inputs = tokenizer([text], return_tensors="pt").input_ids.to(model.device)
-     inp, outp = inputs[:, :-1], inputs[:, 1:].unsqueeze(-1)
-
-     with torch.no_grad():
-         logits = model(inp).logits
-
-     probs = torch.softmax(logits, dim=-1)
-     chosen = torch.gather(probs, dim=2, index=outp).squeeze(-1).cpu().numpy()[0]
-
-     tokens = tokenizer.convert_ids_to_tokens(inp[0].cpu().tolist())
-     highlights = [
-         (tok.replace("Ġ", ""), float(p)) for tok, p in zip(tokens, chosen)
-     ]
-     return highlights
-
- with gr.Blocks() as demo:
-     gr.Markdown("## Token-by-Token Probability Highlighter")
-     txt = gr.Textbox(
-         label="Input Text",
-         placeholder="Type or paste any text here…" ,
-         lines=4
+ import transformers
+ import os
+ from PIL import Image
+
+ os.environ['HF_HOME'] = '/data/hf_home'
+ os.environ['TRANSFORMERS_CACHE'] = '/data/transformers_cache'
+
+ def process_vision_info(messages):
+     image_inputs = []
+     video_inputs = []
+     for message in messages:
+         if message["role"] == "user":
+             content = message["content"]
+             for item in content:
+                 if item["type"] == "image":
+                     image_inputs.append(item["image"])
+                 elif item["type"] == "video":
+                     video_inputs.append(item["video"])
+     return image_inputs, video_inputs
+
+ print("Loading text model (Qwen/Qwen2.5-7B)...")
+ text_model_loaded = False
+ text_model_error = ""
+ try:
+     text_model = transformers.AutoModelForCausalLM.from_pretrained(
+         "Qwen/Qwen2.5-7B",
+         torch_dtype=torch.bfloat16,
+         device_map="auto"
      )
-     highlighted = gr.HighlightedText(
-         label="Token Probabilities",
-         combine_adjacent=True,
-         show_legend=True,
+     text_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
+     text_model_loaded = True
+     print("Text model loaded successfully.")
+ except Exception as e:
+     text_model_error = str(e)
+     print(f"Error loading text model: {text_model_error}")
+     text_model, text_tokenizer = None, None
+
+ print("Loading Vision-Language model (Qwen/Qwen2.5-VL-7B-Instruct)...")
+ vl_model_loaded = False
+ vl_model_error = ""
+ try:
+     vl_model = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         "Qwen/Qwen2.5-VL-7B-Instruct",
+         torch_dtype=torch.bfloat16,
+         device_map="auto"
      )
-     txt.change(
-         fn=highlight_probabilities,
-         inputs=txt,
-         outputs=highlighted
+     vl_processor = transformers.AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+     vl_model_loaded = True
+     print("Vision-Language model loaded successfully.")
+ except Exception as e:
+     vl_model_error = str(e)
+     print(f"Error loading Vision-Language model: {vl_model_error}")
+     vl_model, vl_processor = None, None
+
+ def visualize_text_token_probabilities(text: str):
+     if not text_model_loaded:
+         return [(f"Text Model failed to load: {text_model_error}", None)]
+     if not text or not text.strip():
+         return [("Please enter some text to analyze.", None)]
+
+     try:
+         inputs = text_tokenizer([text], return_tensors="pt").to(text_model.device)
+         input_ids = inputs.input_ids
+         if input_ids.shape[1] < 2:
+             token = text_tokenizer.decode(input_ids[0])
+             return [(token, None)]
+
+         inp = input_ids[:, :-1]
+         outp = input_ids[:, 1:].unsqueeze(-1)
+         with torch.no_grad():
+             logits = text_model(inp).logits.float()
+
+         all_probs = torch.softmax(logits, dim=-1)
+         chosen_probs = torch.gather(all_probs, dim=2, index=outp).squeeze(-1).cpu().numpy()[0]
+
+         highlighted_data = []
+         outp_tokens = input_ids[0, 1:].cpu().tolist()
+
+         first_token_str = text_tokenizer.decode([input_ids[0, 0].item()])
+         highlighted_data.append((first_token_str, None))
+
+         for token_id, prob in zip(outp_tokens, chosen_probs):
+             token_str = text_tokenizer.decode([token_id])
+             highlighted_data.append((token_str, float(prob)))
+
+         return highlighted_data
+     except Exception as e:
+         print(f"An error occurred during text processing: {e}")
+         return [(f"An error occurred: {str(e)}", None)]
+
+ def generate_and_visualize_vl_probabilities(image, prompt: str):
+     if not vl_model_loaded:
+         return [(f"Vision-Language Model failed to load: {vl_model_error}", None)]
+     if image is None or not prompt or not prompt.strip():
+         return [("Please upload an image and provide a text prompt.", None)]
+
+     try:
+         messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt.strip()}]}]
+         text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         image_inputs, _ = process_vision_info(messages)
+         inputs = vl_processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(vl_model.device)
+
+         with torch.no_grad():
+             generated_ids = vl_model.generate(**inputs, max_new_tokens=512)
+
+         input_token_len = inputs.input_ids.shape[1]
+         if generated_ids.shape[1] <= input_token_len:
+             return [("Model did not generate any new tokens.", None)]
+
+         original_mask = inputs.attention_mask
+         num_generated_tokens = generated_ids.shape[1] - input_token_len
+         generated_mask = torch.ones(
+             (1, num_generated_tokens),
+             dtype=original_mask.dtype,
+             device=original_mask.device
+         )
+         full_attention_mask = torch.cat([original_mask, generated_mask], dim=1)
+
+         with torch.no_grad():
+             outputs = vl_model(
+                 input_ids=generated_ids,
+                 pixel_values=inputs.get('pixel_values'),
+                 image_grid_thw=inputs.get('image_grid_thw'),
+                 attention_mask=full_attention_mask
+             )
+         logits = outputs.logits.float()
+
+         logits_of_generated_part = logits[:, input_token_len - 1:-1, :]
+         labels_of_generated_part = generated_ids[:, input_token_len:]
+
+         all_probs = torch.softmax(logits_of_generated_part, dim=-1)
+         chosen_probs = torch.gather(all_probs, 2, labels_of_generated_part.unsqueeze(-1)).squeeze(-1)
+
+         generated_token_ids_only = generated_ids[0, input_token_len:]
+         probs_list = chosen_probs[0].cpu().tolist()
+         highlighted_data = []
+
+         for token_id, prob in zip(generated_token_ids_only.tolist(), probs_list):
+             token_str = vl_processor.decode([token_id])
+             highlighted_data.append((token_str, float(prob)))
+
+         if not highlighted_data:
+             return [("Model did not generate any new tokens.", None)]
+         return highlighted_data
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         print(f"An error occurred during VL processing: {e}")
+         return [(f"An error occurred: {str(e)}", None)]
+
+ text_en_example = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+ The assistant first thinks about the reasoning process in the mind and then provides the user
+ with the answer. The reasoning process and answer are enclosed within <think> </think> and
+ <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
+ <answer> answer here </answer>. User: What is 7 * 6? Assistant: <think> First, the user asked: "what is 7 * 6?" That's a multiplication problem. I need to calculate the product of 7 and 6.
+
+ I know my multiplication tables. 7 times 6 is 42. I can double-check: 7 × 6 means adding 7 six times: 7 + 7 + 7 + 7 + 7 + 7. Let's add that up: 7+7=14, 14+7=21, 21+7=28, 28+7=35, 35+7=42. Yes, that's 42.
+
+ I think that's fine. </think> <answer> 7 multiplied by 6 equals **42**.
+
+ If you have any more math questions or need an explanation, feel free to ask! 😊 </answer>"""
+
+ with gr.Blocks(theme=gr.themes.Soft(), title="Qwen2.5 Token Visualizer") as demo:
+     gr.Markdown(
+         """
+         # Qwen2.5 Series Token Probability Visualizer
+         This tool visualizes token probabilities for both text and vision-language models from the Qwen2.5 series.
+         The color of each token represents its conditional probability.
+         **<span style="color:red">Red</span> means high probability** (the model was confident), and **<span style="color:black">White</span> means low probability** (the model was surprised).
+         """
      )
+     with gr.Tabs():
+         with gr.TabItem("Text Model (Qwen2.5-7B)"):
+             gr.Markdown("### Analyze Probabilities of Given Text")
+             with gr.Row():
+                 text_input = gr.Textbox(
+                     label="Input Text", lines=15, value=text_en_example,
+                     placeholder="Enter text here to analyze..."
+                 )
+             with gr.Row():
+                 text_submit_btn = gr.Button("Visualize Probabilities", variant="primary")
+
+             text_output_highlight = gr.HighlightedText(
+                 label="Token Probabilities (High: Red, Low: White)", show_legend=True,
+                 combine_adjacent=False,
+             )
+             gr.Examples(
+                 examples=[[text_en_example]], inputs=text_input, outputs=text_output_highlight,
+                 fn=visualize_text_token_probabilities, cache_examples=False
+             )
+             text_submit_btn.click(
+                 fn=visualize_text_token_probabilities, inputs=text_input, outputs=text_output_highlight,
+                 api_name="visualize_text"
+             )
+
+         with gr.TabItem("Vision-Language Model (Qwen2.5-VL-7B-Instruct)"):
+             gr.Markdown("### Generate Text from Image and Visualize Probabilities")
+             with gr.Row():
+                 with gr.Column():
+                     vl_image_input = gr.Image(type="pil", label="Upload Image")
+                     vl_text_input = gr.Textbox(label="Your Question", placeholder="e.g., Describe this image.")
+                     vl_submit_btn = gr.Button("Generate and Visualize", variant="primary")
+                 with gr.Column():
+                     vl_output_highlight = gr.HighlightedText(
+                         label="Generated Token Probabilities (High: Red, Low: White)", show_legend=True,
+                         combine_adjacent=False,
+                     )
+
+             gr.Examples(
+                 examples=[["demo.jpeg", "Describe this image in detail."]],
+                 inputs=[vl_image_input, vl_text_input],
+                 outputs=vl_output_highlight,
+                 fn=generate_and_visualize_vl_probabilities,
+                 cache_examples=False
+             )
+             vl_submit_btn.click(
+                 fn=generate_and_visualize_vl_probabilities, inputs=[vl_image_input, vl_text_input],
+                 outputs=vl_output_highlight, api_name="visualize_vl_generation"
+             )

  if __name__ == "__main__":
-     demo.launch()
+     if not os.path.exists("demo.jpeg"):
+         try:
+             from PIL import Image, ImageDraw, ImageFont
+             img = Image.new('RGB', (400, 200), color = (73, 109, 137))
+             d = ImageDraw.Draw(img)
+             try:
+                 font = ImageFont.truetype("arial.ttf", 20)
+             except IOError:
+                 font = ImageFont.load_default()
+             d.text((10,10), "This is a demo image for Gradio.", font=font, fill=(255,255,0))
+             img.save("demo.jpeg")
+             print("Created a dummy 'demo.jpeg' for the example.")
+         except Exception as e:
+             print(f"Could not create a dummy image: {e}")
+
+     demo.queue().launch(share=True)
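
Both new functions rely on the same scoring trick: run the model over a token sequence with teacher forcing, softmax the logits, and gather the probability assigned to the token that actually comes next; those per-token probabilities are what gr.HighlightedText colors. Below is a minimal standalone sketch of that step, not the full app — it uses the smaller Qwen/Qwen2.5-0.5B from the previous version of app.py only to keep the example light, and the input text is illustrative.

# Sketch of the per-token probability extraction used in app.py (not the full app).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B"  # small model for illustration; the app loads Qwen/Qwen2.5-7B
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

text = "7 multiplied by 6 equals 42."  # any text to score
input_ids = tokenizer(text, return_tensors="pt").input_ids

with torch.no_grad():
    # Logits at position i are the model's prediction for token i+1,
    # so feed tokens 0..n-2 and score them against tokens 1..n-1.
    logits = model(input_ids[:, :-1]).logits.float()

probs = torch.softmax(logits, dim=-1)
targets = input_ids[:, 1:].unsqueeze(-1)          # the tokens that actually followed
chosen = torch.gather(probs, dim=2, index=targets).squeeze(-1)[0]

# One (token, probability) pair per position, mirroring the HighlightedText payload.
for token_id, p in zip(input_ids[0, 1:].tolist(), chosen.tolist()):
    print(f"{tokenizer.decode([token_id])!r}: {p:.3f}")

The vision-language tab follows the same pattern, except the answer is produced first with generate() and the full sequence (prompt plus generated ids, with a correspondingly extended attention mask) is then passed back through the model, so only the newly generated tokens are scored and highlighted.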