import gradio as gr
import torch
import transformers
import os
from PIL import Image
import spaces

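# Collect image and video entries from a chat-style message list so they can be
# passed to the Qwen2.5-VL processor alongside the templated text.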
def process_vision_info(messages):
    image_inputs = []
    video_inputs = []
    for message in messages:
        if message["role"] == "user":
            content = message["content"]
            for item in content:
                if item["type"] == "image":
                    image_inputs.append(item["image"])
                elif item["type"] == "video":
                    video_inputs.append(item["video"])
    return image_inputs, video_inputs

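# Load both models eagerly at startup. Each load is wrapped in try/except so a
# failure (e.g. missing weights, out-of-memory) is reported in the UI instead of
# crashing the app.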
print("Loading text model (Qwen/Qwen2.5-7B)...")
text_model_loaded = False
text_model_error = ""
try:
    text_model = transformers.AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    text_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
    text_model_loaded = True
    print("Text model loaded successfully.")
except Exception as e:
    text_model_error = str(e)
    print(f"Error loading text model: {text_model_error}")
    text_model, text_tokenizer = None, None

print("Loading Vision-Language model (Qwen/Qwen2.5-VL-7B-Instruct)...")
vl_model_loaded = False
vl_model_error = ""
try:
    vl_model = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    vl_processor = transformers.AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    vl_model_loaded = True
    print("Vision-Language model loaded successfully.")
except Exception as e:
    vl_model_error = str(e)
    print(f"Error loading Vision-Language model: {vl_model_error}")
    vl_model, vl_processor = None, None

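# Score an existing piece of text: run a single forward pass and, at every position,
# read off the probability the model assigned to the token that actually follows.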
@spaces.GPU
def visualize_text_token_probabilities(text: str):
    if not text_model_loaded:
        return [(f"Text Model failed to load: {text_model_error}", None)]
    if not text or not text.strip():
        return [("Please enter some text to analyze.", None)]

    try:
        inputs = text_tokenizer([text], return_tensors="pt").to(text_model.device)
        input_ids = inputs.input_ids
        if input_ids.shape[1] < 2:
            token = text_tokenizer.decode(input_ids[0])
            return [(token, None)]

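        # Teacher forcing: feed tokens [0..n-2] as input and treat tokens [1..n-1] as
        # targets, then gather each target token's probability from the softmaxed logits.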
        inp = input_ids[:, :-1]
        outp = input_ids[:, 1:].unsqueeze(-1)
        with torch.no_grad():
            logits = text_model(inp).logits.float()

        all_probs = torch.softmax(logits, dim=-1)
        chosen_probs = torch.gather(all_probs, dim=2, index=outp).squeeze(-1).cpu().numpy()[0]

        highlighted_data = []
        outp_tokens = input_ids[0, 1:].cpu().tolist()

        first_token_str = text_tokenizer.decode([input_ids[0, 0].item()])
        highlighted_data.append((first_token_str, None))

        for token_id, prob in zip(outp_tokens, chosen_probs):
            token_str = text_tokenizer.decode([token_id])
            highlighted_data.append((token_str, float(prob)))

        return highlighted_data
    except Exception as e:
        print(f"An error occurred during text processing: {e}")
        return [(f"An error occurred: {str(e)}", None)]

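# Generate a response for an image + prompt, then re-run a full forward pass over the
# prompt plus the generated tokens to recover the probability of each generated token.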
@spaces.GPU
def generate_and_visualize_vl_probabilities(image, prompt: str):
    if not vl_model_loaded:
        return [(f"Vision-Language Model failed to load: {vl_model_error}", None)]
    if image is None or not prompt or not prompt.strip():
        return [("Please upload an image and provide a text prompt.", None)]

    try:
        messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt.strip()}]}]
        text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = vl_processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(vl_model.device)

        with torch.no_grad():
            generated_ids = vl_model.generate(**inputs, max_new_tokens=512)

        input_token_len = inputs.input_ids.shape[1]
        if generated_ids.shape[1] <= input_token_len:
            return [("Model did not generate any new tokens.", None)]

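        # generate() only returns token ids, so rebuild an attention mask that covers
        # the original prompt plus every newly generated token for the scoring pass.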
        original_mask = inputs.attention_mask
        num_generated_tokens = generated_ids.shape[1] - input_token_len
        generated_mask = torch.ones(
            (1, num_generated_tokens),
            dtype=original_mask.dtype,
            device=original_mask.device
        )
        full_attention_mask = torch.cat([original_mask, generated_mask], dim=1)

        with torch.no_grad():
            outputs = vl_model(
                input_ids=generated_ids,
                pixel_values=inputs.get('pixel_values'),
                image_grid_thw=inputs.get('image_grid_thw'),
                attention_mask=full_attention_mask
            )
            logits = outputs.logits.float()

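        # Logits at position i predict the token at position i + 1, so the logits that
        # score the generated tokens start one step before the first new token.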
        logits_of_generated_part = logits[:, input_token_len - 1:-1, :]
        labels_of_generated_part = generated_ids[:, input_token_len:]

        all_probs = torch.softmax(logits_of_generated_part, dim=-1)
        chosen_probs = torch.gather(all_probs, 2, labels_of_generated_part.unsqueeze(-1)).squeeze(-1)

        generated_token_ids_only = generated_ids[0, input_token_len:]
        probs_list = chosen_probs[0].cpu().tolist()
        highlighted_data = []

        for token_id, prob in zip(generated_token_ids_only.tolist(), probs_list):
            token_str = vl_processor.decode([token_id])
            highlighted_data.append((token_str, float(prob)))

        if not highlighted_data:
            return [("Model did not generate any new tokens.", None)]
        return highlighted_data
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"An error occurred during VL processing: {e}")
        return [(f"An error occurred: {str(e)}", None)]

text_en_example = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <think> </think> and
<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>. User: What is 7 * 6? Assistant: <think> First, the user asked: "what is 7 * 6?" That's a multiplication problem. I need to calculate the product of 7 and 6.

I know my multiplication tables. 7 times 6 is 42. I can double-check: 7 Γ— 6 means adding 7 six times: 7 + 7 + 7 + 7 + 7 + 7. Let's add that up: 7+7=14, 14+7=21, 21+7=28, 28+7=35, 35+7=42. Yes, that's 42.

I think that's fine. </think> <answer> 7 multiplied by 6 equals **42**.

If you have any more math questions or need an explanation, feel free to ask! 😊 </answer>"""

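# Gradio UI: one tab per model, each rendering per-token probabilities as highlighted text.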
with gr.Blocks(theme=gr.themes.Soft(), title="Qwen2.5 Token Visualizer") as demo:
    gr.Markdown(
        """
        # Qwen2.5 Series Token Probability Visualizer
        This tool visualizes token probabilities for both text and vision-language models from the Qwen2.5 series.
        The color of each token represents its conditional probability.
        **<span style="color:red">Red</span> means high probability** (the model was confident), and **<span style="color:black">White</span> means low probability** (the model was surprised).
        """
    )
    with gr.Tabs():
        with gr.TabItem("Text Model (Qwen2.5-7B)"):
            gr.Markdown("### Analyze Probabilities of Given Text")
            with gr.Row():
                text_input = gr.Textbox(
                    label="Input Text", lines=15, value=text_en_example,
                    placeholder="Enter text here to analyze..."
                )
            with gr.Row():
                text_submit_btn = gr.Button("Visualize Probabilities", variant="primary")

            text_output_highlight = gr.HighlightedText(
                label="Token Probabilities (High: Red, Low: White)", show_legend=True,
                combine_adjacent=False,
            )
            gr.Examples(
                examples=[[text_en_example]], inputs=text_input, outputs=text_output_highlight,
                fn=visualize_text_token_probabilities, cache_examples=False
            )
            text_submit_btn.click(
                fn=visualize_text_token_probabilities, inputs=text_input, outputs=text_output_highlight,
                api_name="visualize_text"
            )

        with gr.TabItem("Vision-Language Model (Qwen2.5-VL-7B-Instruct)"):
            gr.Markdown("### Generate Text from Image and Visualize Probabilities")
            with gr.Row():
                with gr.Column():
                    vl_image_input = gr.Image(type="pil", label="Upload Image")
                    vl_text_input = gr.Textbox(label="Your Question", placeholder="e.g., Describe this image.")
                    vl_submit_btn = gr.Button("Generate and Visualize", variant="primary")
                with gr.Column():
                    vl_output_highlight = gr.HighlightedText(
                        label="Generated Token Probabilities (High: Red, Low: White)", show_legend=True,
                        combine_adjacent=False,
                    )

            gr.Examples(
                examples=[["demo.jpeg", "Describe this image in detail."]],
                inputs=[vl_image_input, vl_text_input],
                outputs=vl_output_highlight,
                fn=generate_and_visualize_vl_probabilities,
                cache_examples=False
            )
            vl_submit_btn.click(
                fn=generate_and_visualize_vl_probabilities, inputs=[vl_image_input, vl_text_input],
                outputs=vl_output_highlight, api_name="visualize_vl_generation"
            )

if __name__ == "__main__":
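    # Create a placeholder demo.jpeg for the VL example if one is not bundled with the Space.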
    if not os.path.exists("demo.jpeg"):
        try:
            from PIL import Image, ImageDraw, ImageFont
            img = Image.new('RGB', (400, 200), color=(73, 109, 137))
            d = ImageDraw.Draw(img)
            try:
                font = ImageFont.truetype("arial.ttf", 20)
            except IOError:
                font = ImageFont.load_default()
            d.text((10, 10), "This is a demo image for Gradio.", font=font, fill=(255, 255, 0))
            img.save("demo.jpeg")
            print("Created a dummy 'demo.jpeg' for the example.")
        except Exception as e:
            print(f"Could not create a dummy image: {e}")

    demo.queue().launch(share=True)