Nealeon committed
Commit 79cf446 · 0 Parent(s)

chore: rebase commits

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
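
The rules above route large binary artifacts (model weights, archives, serialized tensors) through Git LFS rather than plain Git. Below is a minimal sketch of how a few of these glob patterns behave, using Python's `fnmatch` as a stand-in for git's own matcher; the `saved_model/**/*` rule is omitted because `fnmatch` does not handle `**`.

```python
# Illustrative only: approximate a handful of the LFS globs above with fnmatch.
# Real attribute matching is done by git, not by this sketch.
from fnmatch import fnmatch

LFS_PATTERNS = ["*.bin", "*.safetensors", "*.ckpt", "*.onnx", "*.zip", "*tfevents*"]

def goes_through_lfs(path: str) -> bool:
    """Return True if the filename matches one of the sampled LFS patterns."""
    return any(fnmatch(path, pattern) for pattern in LFS_PATTERNS)

print(goes_through_lfs("pytorch_model.bin"))   # True
print(goes_through_lfs("app.py"))              # False
```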
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: Chat with DeepSeek VL 7B
3
+ emoji: 🐬
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.21.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
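
The YAML front matter above is the Hugging Face Spaces configuration: it selects the Gradio SDK at version 4.21.0 and points the Space at app.py. A hedged sketch, assuming PyYAML is installed, of reading those fields back out of README.md locally:

```python
# Sketch: parse the Space config from the README front matter (assumes PyYAML).
import yaml

with open("README.md", "r", encoding="utf-8") as f:
    text = f.read()

# The front matter sits between the first two '---' markers.
_, front_matter, _ = text.split("---", 2)
config = yaml.safe_load(front_matter)
print(config["sdk"], config["sdk_version"], config["app_file"])  # gradio 4.21.0 app.py
```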
app.py ADDED
@@ -0,0 +1,511 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ # -*- coding:utf-8 -*-
21
+
22
+ import base64
23
+ from io import BytesIO
24
+
25
+ import gradio as gr
26
+ import torch
27
+ from app_modules.gradio_utils import (
28
+ cancel_outputing,
29
+ delete_last_conversation,
30
+ reset_state,
31
+ reset_textbox,
32
+ transfer_input,
33
+ wrap_gen_fn,
34
+ )
35
+ from app_modules.overwrites import reload_javascript
36
+ from app_modules.presets import CONCURRENT_COUNT, description, description_top, title
37
+ from app_modules.utils import configure_logger, is_variable_assigned, strip_stop_words
38
+
39
+ from inference import (
40
+ convert_conversation_to_prompts,
41
+ deepseek_generate,
42
+ load_model,
43
+ )
44
+ from app_modules.conversation import SeparatorStyle
45
+
46
+
47
+ def load_models():
48
+ models = {
49
+ "DeepSeek-VL 7B": "deepseek-ai/deepseek-vl-7b-chat",
50
+ }
51
+
52
+ for model_name in models:
53
+ models[model_name] = load_model(models[model_name])
54
+
55
+ return models
56
+
57
+
58
+ logger = configure_logger()
59
+ models = load_models()
60
+ MODELS = sorted(list(models.keys()))
61
+
62
+
63
+ def generate_prompt_with_history(
64
+ text, image, history, vl_chat_processor, tokenizer, max_length=2048
65
+ ):
66
+ """
67
+ Generate a prompt with history for the deepseek application.
68
+
69
+ Args:
70
+ text (str): The text prompt.
71
+ image (PIL.Image): The image prompt.
72
+ history (list): List of previous conversation messages.
73
+ tokenizer: The tokenizer used for encoding the prompt.
74
+ max_length (int): The maximum length of the prompt.
75
+
76
+ Returns:
77
+ Conversation: A copy of the conversation that fits within the max_length limit, or None if such a prompt could not be generated.
78
+ """
79
+
80
+ sft_format = "deepseek"
81
+ user_role_ind = 0
82
+ bot_role_ind = 1
83
+
84
+ # Initialize conversation
85
+ conversation = vl_chat_processor.new_chat_template()
86
+
87
+ if history:
88
+ conversation.messages = history
89
+
90
+ if image is not None:
91
+ if "<image_placeholder>" not in text:
92
+ text = (
93
+ "<image_placeholder>" + "\n" + text
94
+ ) # append the <image_placeholder> in a new line after the text prompt
95
+ text = (text, image)
96
+
97
+ conversation.append_message(conversation.roles[user_role_ind], text)
98
+ conversation.append_message(conversation.roles[bot_role_ind], "")
99
+
100
+ # Create a copy of the conversation to avoid history truncation in the UI
101
+ conversation_copy = conversation.copy()
102
+ logger.info("=" * 80)
103
+ logger.info(get_prompt(conversation))
104
+
105
+ rounds = len(conversation.messages) // 2
106
+
107
+ for _ in range(rounds):
108
+ current_prompt = get_prompt(conversation)
109
+ current_prompt = (
110
+ current_prompt.replace("</s>", "")
111
+ if sft_format == "deepseek"
112
+ else current_prompt
113
+ )
114
+
115
+ if torch.tensor(tokenizer.encode(current_prompt)).size(-1) <= max_length:
116
+ return conversation_copy
117
+
118
+ if len(conversation.messages) % 2 != 0:
119
+ gr.Error("The messages between user and assistant are not paired.")
120
+ return
121
+
122
+ try:
123
+ for _ in range(2): # pop out two messages in a row
124
+ conversation.messages.pop(0)
125
+ except IndexError:
126
+ gr.Error("Input text processing failed, unable to respond in this round.")
127
+ return None
128
+
129
+ gr.Error("Prompt could not be generated within max_length limit.")
130
+ return None
131
+
132
+
133
+ def to_gradio_chatbot(conv):
134
+ """Convert the conversation to gradio chatbot format."""
135
+ ret = []
136
+ for i, (role, msg) in enumerate(conv.messages[conv.offset :]):
137
+ if i % 2 == 0:
138
+ if type(msg) is tuple:
139
+ msg, image = msg
140
+ if isinstance(image, str):
141
+ with open(image, "rb") as f:
142
+ data = f.read()
143
+ img_b64_str = base64.b64encode(data).decode()
144
+ image_str = f'<video src="data:video/mp4;base64,{img_b64_str}" controls width="426" height="240"></video>'
145
+ msg = msg.replace("\n".join(["<image_placeholder>"] * 4), image_str)
146
+ else:
147
+ max_hw, min_hw = max(image.size), min(image.size)
148
+ aspect_ratio = max_hw / min_hw
149
+ max_len, min_len = 800, 400
150
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
151
+ longest_edge = int(shortest_edge * aspect_ratio)
152
+ W, H = image.size
153
+ if H > W:
154
+ H, W = longest_edge, shortest_edge
155
+ else:
156
+ H, W = shortest_edge, longest_edge
157
+ image = image.resize((W, H))
158
+ buffered = BytesIO()
159
+ image.save(buffered, format="JPEG")
160
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
161
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
162
+ msg = msg.replace("<image_placeholder>", img_str)
163
+ ret.append([msg, None])
164
+ else:
165
+ ret[-1][-1] = msg
166
+ return ret
167
+
168
+
169
+ def to_gradio_history(conv):
170
+ """Convert the conversation to gradio history state."""
171
+ return conv.messages[conv.offset :]
172
+
173
+
174
+ def get_prompt(conv) -> str:
175
+ """Get the prompt for generation."""
176
+ system_prompt = conv.system_template.format(system_message=conv.system_message)
177
+ if conv.sep_style == SeparatorStyle.DeepSeek:
178
+ seps = [conv.sep, conv.sep2]
179
+ if system_prompt == "" or system_prompt is None:
180
+ ret = ""
181
+ else:
182
+ ret = system_prompt + seps[0]
183
+ for i, (role, message) in enumerate(conv.messages):
184
+ if message:
185
+ if type(message) is tuple: # multimodal message
186
+ message, _ = message
187
+ ret += role + ": " + message + seps[i % 2]
188
+ else:
189
+ ret += role + ":"
190
+ return ret
191
+ else:
192
+ return conv.get_prompt()
193
+
194
+
195
+ @wrap_gen_fn
196
+ def predict(
197
+ text,
198
+ image,
199
+ chatbot,
200
+ history,
201
+ top_p,
202
+ temperature,
203
+ repetition_penalty,
204
+ max_length_tokens,
205
+ max_context_length_tokens,
206
+ model_select_dropdown,
207
+ ):
208
+ """
209
+ Function to predict the response based on the user's input and selected model.
210
+
211
+ Parameters:
212
+ text (str): The input text from the user.
213
+ image (PIL.Image): The input image from the user.
214
+ chatbot (list): The current chatbot display messages.
215
+ history (list): The conversation history state.
216
+ top_p (float): The top-p parameter for the model.
217
+ temperature (float): The temperature parameter for the model.
218
+ max_length_tokens (int): The maximum length of tokens for the model.
219
+ max_context_length_tokens (int): The maximum length of context tokens for the model.
220
+ model_select_dropdown (str): The selected model from the dropdown.
221
+
222
+ Returns:
223
+ generator: A generator that yields the chatbot outputs, history, and status.
224
+ """
225
+ print("running the prediction function")
226
+ try:
227
+ tokenizer, vl_gpt, vl_chat_processor = models[model_select_dropdown]
228
+
229
+ if text == "":
230
+ yield chatbot, history, "Empty context."
231
+ return
232
+ except KeyError:
233
+ yield [[text, "No Model Found"]], [], "No Model Found"
234
+ return
235
+
236
+ conversation = generate_prompt_with_history(
237
+ text,
238
+ image,
239
+ history,
240
+ vl_chat_processor,
241
+ tokenizer,
242
+ max_length=max_context_length_tokens,
243
+ )
244
+ prompts = convert_conversation_to_prompts(conversation)
245
+
246
+ stop_words = conversation.stop_str
247
+ gradio_chatbot_output = to_gradio_chatbot(conversation)
248
+
249
+ full_response = ""
250
+ with torch.no_grad():
251
+ for x in deepseek_generate(
252
+ prompts=prompts,
253
+ vl_gpt=vl_gpt,
254
+ vl_chat_processor=vl_chat_processor,
255
+ tokenizer=tokenizer,
256
+ stop_words=stop_words,
257
+ max_length=max_length_tokens,
258
+ temperature=temperature,
259
+ repetition_penalty=repetition_penalty,
260
+ top_p=top_p,
261
+ ):
262
+ full_response += x
263
+ response = strip_stop_words(full_response, stop_words)
264
+ conversation.update_last_message(response)
265
+ gradio_chatbot_output[-1][1] = response
266
+ yield gradio_chatbot_output, to_gradio_history(
267
+ conversation
268
+ ), "Generating..."
269
+
270
+ print("flushed result to gradio")
271
+ torch.cuda.empty_cache()
272
+
273
+ if is_variable_assigned("x"):
274
+ print(f"{model_select_dropdown}:\n{text}\n{'-' * 80}\n{x}\n{'=' * 80}")
275
+ print(
276
+ f"temperature: {temperature}, top_p: {top_p}, repetition_penalty: {repetition_penalty}, max_length_tokens: {max_length_tokens}"
277
+ )
278
+
279
+ yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"
280
+
281
+
282
+ def retry(
283
+ text,
284
+ image,
285
+ chatbot,
286
+ history,
287
+ top_p,
288
+ temperature,
289
+ repetition_penalty,
290
+ max_length_tokens,
291
+ max_context_length_tokens,
292
+ model_select_dropdown,
293
+ ):
294
+ if len(history) == 0:
295
+ yield (chatbot, history, "Empty context")
296
+ return
297
+
298
+ chatbot.pop()
299
+ history.pop()
300
+ text = history.pop()[-1]
301
+ if type(text) is tuple:
302
+ text, image = text
303
+
304
+ yield from predict(
305
+ text,
306
+ image,
307
+ chatbot,
308
+ history,
309
+ top_p,
310
+ temperature,
311
+ repetition_penalty,
312
+ max_length_tokens,
313
+ max_context_length_tokens,
314
+ model_select_dropdown,
315
+ )
316
+
317
+
318
+ def build_demo(MODELS):
319
+ with open("assets/custom.css", "r", encoding="utf-8") as f:
320
+ customCSS = f.read()
321
+
322
+ with gr.Blocks(theme=gr.themes.Soft(spacing_size="md")) as demo:
323
+ history = gr.State([])
324
+ input_text = gr.State()
325
+ input_image = gr.State()
326
+
327
+ with gr.Row():
328
+ gr.HTML(title)
329
+ status_display = gr.Markdown("Success", elem_id="status_display")
330
+ gr.Markdown(description_top)
331
+
332
+ with gr.Row(equal_height=True):
333
+ with gr.Column(scale=4):
334
+ with gr.Row():
335
+ chatbot = gr.Chatbot(
336
+ elem_id="deepseek_chatbot",
337
+ show_share_button=True,
338
+ likeable=True,
339
+ bubble_full_width=False,
340
+ height=600,
341
+ )
342
+ with gr.Row():
343
+ with gr.Column(scale=4):
344
+ text_box = gr.Textbox(
345
+ show_label=False, placeholder="Enter text", container=False
346
+ )
347
+ with gr.Column(
348
+ min_width=70,
349
+ ):
350
+ submitBtn = gr.Button("Send")
351
+ with gr.Column(
352
+ min_width=70,
353
+ ):
354
+ cancelBtn = gr.Button("Stop")
355
+ with gr.Row():
356
+ emptyBtn = gr.Button(
357
+ "🧹 New Conversation",
358
+ )
359
+ retryBtn = gr.Button("🔄 Regenerate")
360
+ delLastBtn = gr.Button("🗑️ Remove Last Turn")
361
+
362
+ with gr.Column():
363
+ image_box = gr.Image(type="pil")
364
+
365
+ with gr.Tab(label="Parameter Setting") as parameter_row:
366
+ top_p = gr.Slider(
367
+ minimum=-0,
368
+ maximum=1.0,
369
+ value=0.95,
370
+ step=0.05,
371
+ interactive=True,
372
+ label="Top-p",
373
+ )
374
+ temperature = gr.Slider(
375
+ minimum=0,
376
+ maximum=1.0,
377
+ value=0.1,
378
+ step=0.1,
379
+ interactive=True,
380
+ label="Temperature",
381
+ )
382
+ repetition_penalty = gr.Slider(
383
+ minimum=0.0,
384
+ maximum=2.0,
385
+ value=1.1,
386
+ step=0.1,
387
+ interactive=True,
388
+ label="Repetition penalty",
389
+ )
390
+ max_length_tokens = gr.Slider(
391
+ minimum=0,
392
+ maximum=4096,
393
+ value=2048,
394
+ step=8,
395
+ interactive=True,
396
+ label="Max Generation Tokens",
397
+ )
398
+ max_context_length_tokens = gr.Slider(
399
+ minimum=0,
400
+ maximum=4096,
401
+ value=4096,
402
+ step=128,
403
+ interactive=True,
404
+ label="Max History Tokens",
405
+ )
406
+ model_select_dropdown = gr.Dropdown(
407
+ label="Select Models",
408
+ choices=MODELS,
409
+ multiselect=False,
410
+ value=MODELS[0],
411
+ interactive=True,
412
+ )
413
+
414
+ examples_list = [
415
+ [
416
+ "examples/rap.jpeg",
417
+ "Can you write me a master rap song that rhymes very well based on this image?",
418
+ ],
419
+ [
420
+ "examples/app.png",
421
+ "What is this app about?",
422
+ ],
423
+ [
424
+ "examples/pipeline.png",
425
+ "Help me write a python code based on the image.",
426
+ ],
427
+ [
428
+ "examples/chart.png",
429
+ "Could you help me to re-draw this picture with python codes?",
430
+ ],
431
+ [
432
+ "examples/mirror.png",
433
+ "How many people are there in the image. Why?",
434
+ ],
435
+ [
436
+ "examples/puzzle.png",
437
+ "Can this 2 pieces combine together?",
438
+ ],
439
+ ]
440
+ gr.Examples(examples=examples_list, inputs=[image_box, text_box])
441
+ gr.Markdown(description)
442
+
443
+ input_widgets = [
444
+ input_text,
445
+ input_image,
446
+ chatbot,
447
+ history,
448
+ top_p,
449
+ temperature,
450
+ repetition_penalty,
451
+ max_length_tokens,
452
+ max_context_length_tokens,
453
+ model_select_dropdown,
454
+ ]
455
+ output_widgets = [chatbot, history, status_display]
456
+
457
+ transfer_input_args = dict(
458
+ fn=transfer_input,
459
+ inputs=[text_box, image_box],
460
+ outputs=[input_text, input_image, text_box, image_box, submitBtn],
461
+ show_progress=True,
462
+ )
463
+
464
+ predict_args = dict(
465
+ fn=predict,
466
+ inputs=input_widgets,
467
+ outputs=output_widgets,
468
+ show_progress=True,
469
+ )
470
+
471
+ retry_args = dict(
472
+ fn=retry,
473
+ inputs=input_widgets,
474
+ outputs=output_widgets,
475
+ show_progress=True,
476
+ )
477
+
478
+ reset_args = dict(
479
+ fn=reset_textbox, inputs=[], outputs=[text_box, status_display]
480
+ )
481
+
482
+ predict_events = [
483
+ text_box.submit(**transfer_input_args).then(**predict_args),
484
+ submitBtn.click(**transfer_input_args).then(**predict_args),
485
+ ]
486
+
487
+ emptyBtn.click(reset_state, outputs=output_widgets, show_progress=True)
488
+ emptyBtn.click(**reset_args)
489
+ retryBtn.click(**retry_args)
490
+
491
+ delLastBtn.click(
492
+ delete_last_conversation,
493
+ [chatbot, history],
494
+ output_widgets,
495
+ show_progress=True,
496
+ )
497
+
498
+ cancelBtn.click(cancel_outputing, [], [status_display], cancels=predict_events)
499
+
500
+ return demo
501
+
502
+
503
+ if __name__ == "__main__":
504
+ demo = build_demo(MODELS)
505
+ demo.title = "DeepSeek-VL Chatbot"
506
+
507
+ reload_javascript()
508
+ demo.queue(max_size=20).launch(
509
+ share=False,
510
+ favicon_path="assets/favicon.ico",
511
+ )
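
In to_gradio_chatbot() above, an uploaded PIL image is shrunk so its shortest edge fits an 800/400 budget and then inlined into the chat HTML as a base64 data URI. A standalone sketch of that resize-and-embed step, assuming Pillow; the RGB conversion is a small safeguard not present in the original:

```python
# Sketch of the thumbnail-and-embed step used by to_gradio_chatbot (assumes Pillow).
import base64
from io import BytesIO
from PIL import Image

def embed_image_as_html(image: Image.Image, max_len: int = 800, min_len: int = 400) -> str:
    """Resize so the shortest edge respects the budget, then inline as a base64 <img> tag."""
    max_hw, min_hw = max(image.size), min(image.size)
    aspect_ratio = max_hw / min_hw
    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
    longest_edge = int(shortest_edge * aspect_ratio)
    w, h = image.size
    new_size = (shortest_edge, longest_edge) if h > w else (longest_edge, shortest_edge)
    image = image.resize(new_size).convert("RGB")  # convert() guards against RGBA inputs
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    b64 = base64.b64encode(buffered.getvalue()).decode()
    return f'<img src="data:image/jpeg;base64,{b64}" alt="user upload image" />'
```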
app_modules/conversation.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ """
21
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
22
+ """
23
+
24
+ import dataclasses
25
+ from enum import IntEnum, auto
26
+ from typing import Dict, List
27
+
28
+
29
+ class SeparatorStyle(IntEnum):
30
+ """Separator styles."""
31
+
32
+ ADD_COLON_SINGLE = auto()
33
+ ADD_COLON_TWO = auto()
34
+ ADD_COLON_SPACE_SINGLE = auto()
35
+ NO_COLON_SINGLE = auto()
36
+ NO_COLON_TWO = auto()
37
+ ADD_NEW_LINE_SINGLE = auto()
38
+ LLAMA2 = auto()
39
+ CHATGLM = auto()
40
+ CHATML = auto()
41
+ CHATINTERN = auto()
42
+ DOLLY = auto()
43
+ RWKV = auto()
44
+ PHOENIX = auto()
45
+ ROBIN = auto()
46
+ DeepSeek = auto()
47
+ PLAIN = auto()
48
+ ALIGNMENT = auto()
49
+
50
+
51
+ @dataclasses.dataclass
52
+ class Conversation:
53
+ """A class that manages prompt templates and keeps all conversation history."""
54
+
55
+ # The name of this template
56
+ name: str
57
+ # The template of the system prompt
58
+ system_template: str = "{system_message}"
59
+ # The system message
60
+ system_message: str = ""
61
+ # The names of two roles
62
+ roles: List[str] = (("USER", "ASSISTANT"),)
63
+ # All messages. Each item is (role, message).
64
+ messages: List[List[str]] = ()
65
+ # The number of few shot examples
66
+ offset: int = 0
67
+ # The separator style and configurations
68
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
69
+ sep: str = "\n"
70
+ sep2: str = None
71
+ # Stop criteria (the default one is EOS token)
72
+ stop_str: str = None
73
+ # Stops generation if meeting any token in this list
74
+ stop_token_ids: List[int] = None
75
+
76
+ def get_prompt(self) -> str:
77
+ """Get the prompt for generation."""
78
+ system_prompt = self.system_template.format(system_message=self.system_message)
79
+
80
+ if self.sep_style == SeparatorStyle.DeepSeek:
81
+ seps = [self.sep, self.sep2]
82
+ if system_prompt == "" or system_prompt is None:
83
+ ret = ""
84
+ else:
85
+ ret = system_prompt + seps[0]
86
+ for i, (role, message) in enumerate(self.messages):
87
+ if message:
88
+ ret += role + ": " + message + seps[i % 2]
89
+ else:
90
+ ret += role + ":"
91
+ return ret
92
+ elif self.sep_style == SeparatorStyle.LLAMA2:
93
+ seps = [self.sep, self.sep2]
94
+ if self.system_message:
95
+ ret = system_prompt
96
+ else:
97
+ ret = "[INST] "
98
+ for i, (role, message) in enumerate(self.messages):
99
+ tag = self.roles[i % 2]
100
+ if message:
101
+ if type(message) is tuple: # multimodal message
102
+ message, _ = message
103
+ if i == 0:
104
+ ret += message + " "
105
+ else:
106
+ ret += tag + " " + message + seps[i % 2]
107
+ else:
108
+ ret += tag
109
+ return ret
110
+ elif self.sep_style == SeparatorStyle.PLAIN:
111
+ seps = [self.sep, self.sep2]
112
+ ret = ""
113
+ for i, (role, message) in enumerate(self.messages):
114
+ if message:
115
+ if type(message) is tuple:
116
+ message, _, _ = message
117
+ if i % 2 == 0:
118
+ ret += message + seps[i % 2]
119
+ else:
120
+ ret += message + seps[i % 2]
121
+ else:
122
+ ret += ""
123
+ return ret
124
+ elif self.sep_style == SeparatorStyle.ALIGNMENT:
125
+ seps = [self.sep, self.sep2]
126
+ ret = ""
127
+ for i, (role, message) in enumerate(self.messages):
128
+ if message:
129
+ if type(message) is tuple:
130
+ message, _, _ = message
131
+ if i % 2 == 0:
132
+ ret += "<image>\n" + seps[i % 2]
133
+ else:
134
+ ret += message + seps[i % 2]
135
+ else:
136
+ ret += ""
137
+ return ret
138
+ else:
139
+ raise ValueError(f"Invalid style: {self.sep_style}")
140
+
141
+ def get_prompt_for_current_round(self, content=None):
142
+ """Get current round formatted question prompt during sft training"""
143
+ if self.sep_style == SeparatorStyle.PLAIN:
144
+ formatted_question = "<image>\n"
145
+ elif self.sep_style == SeparatorStyle.DeepSeek:
146
+ formatted_question = (
147
+ f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
148
+ )
149
+ else:
150
+ raise ValueError(f"Unsupported sep_style: {self.sep_style}")
151
+ return formatted_question
152
+
153
+ def set_system_message(self, system_message: str):
154
+ """Set the system message."""
155
+ self.system_message = system_message
156
+
157
+ def append_message(self, role: str, message: str):
158
+ """Append a new message."""
159
+ self.messages.append([role, message])
160
+
161
+ def reset_message(self):
162
+ """Reset a new message."""
163
+ self.messages = []
164
+
165
+ def update_last_message(self, message: str):
166
+ """Update the last output.
167
+
168
+ The last message is typically set to be None when constructing the prompt,
169
+ so we need to update it in-place after getting the response from a model.
170
+ """
171
+ self.messages[-1][1] = message
172
+
173
+ def to_gradio_chatbot(self):
174
+ """Convert the conversation to gradio chatbot format."""
175
+ ret = []
176
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
177
+ if i % 2 == 0:
178
+ ret.append([msg, None])
179
+ else:
180
+ ret[-1][-1] = msg
181
+ return ret
182
+
183
+ def to_openai_api_messages(self):
184
+ """Convert the conversation to OpenAI chat completion format."""
185
+ system_prompt = self.system_template.format(system_message=self.system_message)
186
+ ret = [{"role": "system", "content": system_prompt}]
187
+
188
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
189
+ if i % 2 == 0:
190
+ ret.append({"role": "user", "content": msg})
191
+ else:
192
+ if msg is not None:
193
+ ret.append({"role": "assistant", "content": msg})
194
+ return ret
195
+
196
+ def copy(self):
197
+ return Conversation(
198
+ name=self.name,
199
+ system_template=self.system_template,
200
+ system_message=self.system_message,
201
+ roles=self.roles,
202
+ messages=[[x, y] for x, y in self.messages],
203
+ offset=self.offset,
204
+ sep_style=self.sep_style,
205
+ sep=self.sep,
206
+ sep2=self.sep2,
207
+ stop_str=self.stop_str,
208
+ stop_token_ids=self.stop_token_ids,
209
+ )
210
+
211
+ def dict(self):
212
+ return {
213
+ "template_name": self.name,
214
+ "system_message": self.system_message,
215
+ "roles": self.roles,
216
+ "messages": self.messages,
217
+ "offset": self.offset,
218
+ }
219
+
220
+
221
+ # A global registry for all conversation templates
222
+ conv_templates: Dict[str, Conversation] = {}
223
+
224
+
225
+ def register_conv_template(template: Conversation, override: bool = False):
226
+ """Register a new conversation template."""
227
+ if not override:
228
+ assert (
229
+ template.name not in conv_templates
230
+ ), f"{template.name} has been registered."
231
+
232
+ conv_templates[template.name] = template
233
+
234
+
235
+ def get_conv_template(name: str) -> Conversation:
236
+ """Get a conversation template."""
237
+ return conv_templates[name].copy()
238
+
239
+
240
+ # llava_llama2 template
241
+ register_conv_template(
242
+ Conversation(
243
+ name="llava_llama2",
244
+ system_message="You are a helpful language and vision assistant. "
245
+ "You are able to understand the visual content that the user provides, "
246
+ "and assist the user with a variety of tasks using natural language.",
247
+ system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
248
+ roles=("[INST]", "[/INST]"),
249
+ messages=(),
250
+ offset=0,
251
+ sep_style=SeparatorStyle.LLAMA2,
252
+ sep=" ",
253
+ sep2=" </s><s>",
254
+ stop_token_ids=[2],
255
+ )
256
+ )
257
+
258
+ # llama2 template
259
+ # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
260
+ register_conv_template(
261
+ Conversation(
262
+ name="llama-2",
263
+ system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
264
+ roles=("[INST]", "[/INST]"),
265
+ messages=(),
266
+ offset=0,
267
+ sep_style=SeparatorStyle.LLAMA2,
268
+ sep=" ",
269
+ sep2=" </s><s>",
270
+ stop_token_ids=[2],
271
+ )
272
+ )
273
+
274
+
275
+ # deepseek template
276
+ register_conv_template(
277
+ Conversation(
278
+ name="deepseek",
279
+ system_template="{system_message}",
280
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
281
+ # "thinking step by step to be sure you get the right answer.",
282
+ system_message="",
283
+ roles=("User", "Assistant"),
284
+ messages=(),
285
+ offset=0,
286
+ sep_style=SeparatorStyle.DeepSeek,
287
+ sep="\n\n",
288
+ sep2="<|end▁of▁sentence|>",
289
+ stop_token_ids=[100001],
290
+ stop_str=["User:", "<|end▁of▁sentence|>"],
291
+ )
292
+ )
293
+
294
+ register_conv_template(
295
+ Conversation(
296
+ name="plain",
297
+ system_template="",
298
+ system_message="",
299
+ roles=("", ""),
300
+ messages=(),
301
+ offset=0,
302
+ sep_style=SeparatorStyle.PLAIN,
303
+ sep="",
304
+ sep2="",
305
+ stop_token_ids=[2],
306
+ stop_str=["</s>"],
307
+ )
308
+ )
309
+
310
+
311
+ register_conv_template(
312
+ Conversation(
313
+ name="alignment",
314
+ system_template="",
315
+ system_message="",
316
+ roles=("", ""),
317
+ messages=(),
318
+ offset=0,
319
+ sep_style=SeparatorStyle.ALIGNMENT,
320
+ sep="",
321
+ sep2="",
322
+ stop_token_ids=[2],
323
+ stop_str=["</s>"],
324
+ )
325
+ )
326
+
327
+
328
+ if __name__ == "__main__":
329
+ # print("Llama-2 template:")
330
+ # conv = get_conv_template("llama-2")
331
+ # conv.set_system_message("You are a helpful, respectful and honest assistant.")
332
+ # conv.append_message(conv.roles[0], "Hello!")
333
+ # conv.append_message(conv.roles[1], "Hi!")
334
+ # conv.append_message(conv.roles[0], "How are you?")
335
+ # conv.append_message(conv.roles[1], None)
336
+ # print(conv.get_prompt())
337
+
338
+ # print("\n")
339
+
340
+ print("deepseek template:")
341
+ conv = get_conv_template("deepseek")
342
+ conv.append_message(conv.roles[0], "Hello!")
343
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
344
+ conv.append_message(conv.roles[0], "Who are you?")
345
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
346
+ conv.append_message(conv.roles[0], "How are you?")
347
+ conv.append_message(conv.roles[1], None)
348
+ print(conv.get_prompt())
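
The "deepseek" template registered above joins turns as "Role: message" separated by a blank line and closes assistant turns with <|end▁of▁sentence|>. A short usage sketch mirroring the __main__ block:

```python
# Usage sketch: build a single-round DeepSeek-style prompt from the registry above.
from app_modules.conversation import get_conv_template

conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "Describe this image.")
conv.append_message(conv.roles[1], None)  # placeholder for the model's reply
print(conv.get_prompt())
# -> "User: Describe this image.\n\nAssistant:"
```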
app_modules/gradio_utils.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from functools import wraps
21
+
22
+ import gradio as gr
23
+
24
+
25
+ def wrap_gen_fn(gen_fn):
26
+ @wraps(gen_fn)
27
+ def wrapped_gen_fn(prompt, *args, **kwargs):
28
+ try:
29
+ yield from gen_fn(prompt, *args, **kwargs)
30
+ except gr.Error as g_err:
31
+ raise g_err
32
+ except Exception as e:
33
+ raise gr.Error(f"Failed to generate text: {e}") from e
34
+
35
+ return wrapped_gen_fn
36
+
37
+
38
+ def delete_last_conversation(chatbot, history):
39
+ if len(history) % 2 != 0:
40
+ gr.Error("history length is not even")
41
+ return (
42
+ chatbot,
43
+ history,
44
+ "Delete Done",
45
+ )
46
+
47
+ if len(chatbot) > 0:
48
+ chatbot.pop()
49
+
50
+ if len(history) > 0 and len(history) % 2 == 0:
51
+ history.pop()
52
+ history.pop()
53
+
54
+ return (
55
+ chatbot,
56
+ history,
57
+ "Delete Done",
58
+ )
59
+
60
+
61
+ def reset_state():
62
+ return [], [], None, "Reset Done"
63
+
64
+
65
+ def reset_textbox():
66
+ return gr.update(value=""), ""
67
+
68
+
69
+ def cancel_outputing():
70
+ return "Stop Done"
71
+
72
+
73
+ def transfer_input(input_text, input_image):
74
+ print("transferring input text and input image")
75
+ return (
76
+ input_text,
77
+ input_image,
78
+ gr.update(value=""),
79
+ gr.update(value=None),
80
+ gr.Button(visible=True),
81
+ )
82
+
83
+
84
+ class State:
85
+ interrupted = False
86
+
87
+ def interrupt(self):
88
+ self.interrupted = True
89
+
90
+ def recover(self):
91
+ self.interrupted = False
92
+
93
+
94
+ shared_state = State()
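
wrap_gen_fn above re-raises any failure inside a streaming generator as gr.Error, so the Gradio UI shows the message instead of failing silently. A small sketch of the wrapped behaviour; the stream_tokens function is hypothetical:

```python
# Sketch: how wrap_gen_fn surfaces generator failures (assumes gradio is installed).
from app_modules.gradio_utils import wrap_gen_fn

@wrap_gen_fn
def stream_tokens(prompt):
    for token in prompt.split():
        yield token
    raise RuntimeError("backend went away")

gen = stream_tokens("hello world")
print(next(gen), next(gen))  # hello world
# Advancing once more re-raises the RuntimeError as gr.Error("Failed to generate text: ...").
```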
app_modules/overwrites.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from __future__ import annotations
21
+
22
+ import logging
23
+ from typing import List, Tuple
24
+
25
+ from app_modules.presets import gr
26
+ from app_modules.utils import convert_asis, convert_mdtext, detect_converted_mark
27
+
28
+
29
+ def compact_text_chunks(self, prompt, text_chunks: List[str]) -> List[str]:
30
+ logging.debug("Compacting text chunks...🚀🚀🚀")
31
+ combined_str = [c.strip() for c in text_chunks if c.strip()]
32
+ combined_str = [f"[{index+1}] {c}" for index, c in enumerate(combined_str)]
33
+ combined_str = "\n\n".join(combined_str)
34
+ # resplit based on self.max_chunk_overlap
35
+ text_splitter = self.get_text_splitter_given_prompt(prompt, 1, padding=1)
36
+ return text_splitter.split_text(combined_str)
37
+
38
+
39
+ def postprocess(
40
+ self, y: List[Tuple[str | None, str | None]]
41
+ ) -> List[Tuple[str | None, str | None]]:
42
+ """
43
+ Parameters:
44
+ y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
45
+ Returns:
46
+ List of tuples representing the message and response. Each message and response will be a string of HTML.
47
+ """
48
+ if y is None or y == []:
49
+ return []
50
+ temp = []
51
+ for x in y:
52
+ user, bot = x
53
+ if not detect_converted_mark(user):
54
+ user = convert_asis(user)
55
+ if not detect_converted_mark(bot):
56
+ bot = convert_mdtext(bot)
57
+ temp.append((user, bot))
58
+ return temp
59
+
60
+
61
+ with open("assets/custom.js", "r", encoding="utf-8") as f, open(
62
+ "assets/Kelpy-Codos.js", "r", encoding="utf-8"
63
+ ) as f2:
64
+ customJS = f.read()
65
+ kelpyCodos = f2.read()
66
+
67
+
68
+ def reload_javascript():
69
+ print("Reloading javascript...")
70
+ js = f"<script>{customJS}</script><script>{kelpyCodos}</script>"
71
+
72
+ def template_response(*args, **kwargs):
73
+ res = GradioTemplateResponseOriginal(*args, **kwargs)
74
+ res.body = res.body.replace(b"</html>", f"{js}</html>".encode("utf8"))
75
+ res.init_headers()
76
+ return res
77
+
78
+ gr.routes.templates.TemplateResponse = template_response
79
+
80
+
81
+ GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
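
reload_javascript() works by monkey-patching gr.routes.templates.TemplateResponse so the custom scripts are spliced in just before </html> on every page Gradio serves. A minimal, self-contained sketch of the same injection pattern; the SNIPPET constant is a stand-in for the customJS/kelpyCodos strings read above:

```python
# Minimal sketch of the TemplateResponse injection pattern used by reload_javascript().
import gradio as gr

ORIGINAL_TEMPLATE_RESPONSE = gr.routes.templates.TemplateResponse
SNIPPET = "<script>console.log('injected');</script>"  # stand-in for the real scripts

def patched_template_response(*args, **kwargs):
    res = ORIGINAL_TEMPLATE_RESPONSE(*args, **kwargs)
    res.body = res.body.replace(b"</html>", f"{SNIPPET}</html>".encode("utf8"))
    res.init_headers()
    return res

gr.routes.templates.TemplateResponse = patched_template_response
```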
app_modules/presets.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ # -*- coding:utf-8 -*-
21
+ import gradio as gr
22
+
23
+ title = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with DeepSeek-VL </h1>"""
24
+ description_top = """"""
25
+ description = """"""
26
+ CONCURRENT_COUNT = 10
27
+
28
+
29
+ ALREADY_CONVERTED_MARK = "<!-- ALREADY CONVERTED BY PARSER. -->"
30
+
31
+ small_and_beautiful_theme = gr.themes.Soft(
32
+ primary_hue=gr.themes.Color(
33
+ c50="#EBFAF2",
34
+ c100="#CFF3E1",
35
+ c200="#A8EAC8",
36
+ c300="#77DEA9",
37
+ c400="#3FD086",
38
+ c500="#02C160",
39
+ c600="#06AE56",
40
+ c700="#05974E",
41
+ c800="#057F45",
42
+ c900="#04673D",
43
+ c950="#2E5541",
44
+ name="small_and_beautiful",
45
+ ),
46
+ secondary_hue=gr.themes.Color(
47
+ c50="#576b95",
48
+ c100="#576b95",
49
+ c200="#576b95",
50
+ c300="#576b95",
51
+ c400="#576b95",
52
+ c500="#576b95",
53
+ c600="#576b95",
54
+ c700="#576b95",
55
+ c800="#576b95",
56
+ c900="#576b95",
57
+ c950="#576b95",
58
+ ),
59
+ neutral_hue=gr.themes.Color(
60
+ name="gray",
61
+ c50="#f6f7f8",
62
+ # c100="#f3f4f6",
63
+ c100="#F2F2F2",
64
+ c200="#e5e7eb",
65
+ c300="#d1d5db",
66
+ c400="#B2B2B2",
67
+ c500="#808080",
68
+ c600="#636363",
69
+ c700="#515151",
70
+ c800="#393939",
71
+ # c900="#272727",
72
+ c900="#2B2B2B",
73
+ c950="#171717",
74
+ ),
75
+ radius_size=gr.themes.sizes.radius_sm,
76
+ ).set(
77
+ # button_primary_background_fill="*primary_500",
78
+ button_primary_background_fill_dark="*primary_600",
79
+ # button_primary_background_fill_hover="*primary_400",
80
+ # button_primary_border_color="*primary_500",
81
+ button_primary_border_color_dark="*primary_600",
82
+ button_primary_text_color="white",
83
+ button_primary_text_color_dark="white",
84
+ button_secondary_background_fill="*neutral_100",
85
+ button_secondary_background_fill_hover="*neutral_50",
86
+ button_secondary_background_fill_dark="*neutral_900",
87
+ button_secondary_text_color="*neutral_800",
88
+ button_secondary_text_color_dark="white",
89
+ # background_fill_primary="#F7F7F7",
90
+ # background_fill_primary_dark="#1F1F1F",
91
+ # block_title_text_color="*primary_500",
92
+ block_title_background_fill_dark="*primary_900",
93
+ block_label_background_fill_dark="*primary_900",
94
+ input_background_fill="#F6F6F6",
95
+ # chatbot_code_background_color_dark="*neutral_950",
96
+ )
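
presets.py defines the small_and_beautiful_theme palette, although app.py currently builds its Blocks with gr.themes.Soft(spacing_size="md") directly. A hedged sketch of wiring the custom theme in instead:

```python
# Sketch: applying the custom theme from presets.py when building the UI.
import gradio as gr
from app_modules.presets import small_and_beautiful_theme, title

with gr.Blocks(theme=small_and_beautiful_theme) as demo:
    gr.HTML(title)

# demo.launch()  # uncomment to serve locally
```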
app_modules/utils.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ # -*- coding:utf-8 -*-
21
+ from __future__ import annotations
22
+
23
+ import html
24
+ import logging
25
+ import os
26
+ import re
27
+ import time
28
+
29
+ import mdtex2html
30
+ from app_modules.presets import ALREADY_CONVERTED_MARK
31
+ from markdown import markdown
32
+ from pygments import highlight
33
+ from pygments.formatters import HtmlFormatter
34
+ from pygments.lexers import ClassNotFound, get_lexer_by_name, guess_lexer
35
+
36
+ logger = logging.getLogger("gradio_logger")
37
+
38
+
39
+ def configure_logger():
40
+ logger = logging.getLogger("gradio_logger")
41
+ logger.setLevel(logging.DEBUG)
42
+
43
+ timestr = time.strftime("%Y%m%d-%H%M%S")
44
+ os.makedirs("logs", exist_ok=True)
45
+ file_handler = logging.FileHandler(
46
+ f"logs/{timestr}_gradio_log.log"
47
+ )
48
+ console_handler = logging.StreamHandler()
49
+
50
+ formatter = logging.Formatter(
51
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
52
+ )
53
+ console_handler.setFormatter(formatter)
54
+ file_handler.setFormatter(formatter)
55
+
56
+ console_handler.setLevel(logging.INFO)
57
+ file_handler.setLevel(logging.INFO)
58
+
59
+ logger.addHandler(console_handler)
60
+ logger.addHandler(file_handler)
61
+
62
+ return logger
63
+
64
+
65
+ def strip_stop_words(x, stop_words):
66
+ for w in stop_words:
67
+ if w in x:
68
+ return x[: x.index(w)].strip()
69
+ return x.strip()
70
+
71
+
72
+ def format_output(history, text, x):
73
+ updated_history = history + [[text, x]]
74
+ a = [[y[0], convert_to_markdown(y[1])] for y in updated_history]
75
+ return a, updated_history
76
+
77
+
78
+ def markdown_to_html_with_syntax_highlight(md_str): # deprecated
79
+ def replacer(match):
80
+ lang = match.group(1) or "text"
81
+ code = match.group(2)
82
+
83
+ try:
84
+ lexer = get_lexer_by_name(lang, stripall=True)
85
+ except ValueError:
86
+ lexer = get_lexer_by_name("text", stripall=True)
87
+
88
+ formatter = HtmlFormatter()
89
+ highlighted_code = highlight(code, lexer, formatter)
90
+
91
+ return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
92
+
93
+ code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
94
+ md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
95
+
96
+ html_str = markdown(md_str)
97
+ return html_str
98
+
99
+
100
+ def normalize_markdown(md_text: str) -> str: # deprecated
101
+ lines = md_text.split("\n")
102
+ normalized_lines = []
103
+ inside_list = False
104
+
105
+ for i, line in enumerate(lines):
106
+ if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
107
+ if not inside_list and i > 0 and lines[i - 1].strip() != "":
108
+ normalized_lines.append("")
109
+ inside_list = True
110
+ normalized_lines.append(line)
111
+ elif inside_list and line.strip() == "":
112
+ if i < len(lines) - 1 and not re.match(
113
+ r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
114
+ ):
115
+ normalized_lines.append(line)
116
+ continue
117
+ else:
118
+ inside_list = False
119
+ normalized_lines.append(line)
120
+
121
+ return "\n".join(normalized_lines)
122
+
123
+
124
+ def convert_mdtext(md_text):
125
+ code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
126
+ inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
127
+ code_blocks = code_block_pattern.findall(md_text)
128
+ non_code_parts = code_block_pattern.split(md_text)[::2]
129
+
130
+ result = []
131
+ for non_code, code in zip(non_code_parts, code_blocks + [""]):
132
+ if non_code.strip():
133
+ non_code = normalize_markdown(non_code)
134
+ if inline_code_pattern.search(non_code):
135
+ result.append(markdown(non_code, extensions=["tables"]))
136
+ else:
137
+ result.append(mdtex2html.convert(non_code, extensions=["tables"]))
138
+ if code.strip():
139
+ code = f"\n```{code}\n\n```"
140
+ code = markdown_to_html_with_syntax_highlight(code)
141
+ result.append(code)
142
+ result = "".join(result)
143
+ result += ALREADY_CONVERTED_MARK
144
+ return result
145
+
146
+
147
+ def convert_asis(userinput):
148
+ return f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>{ALREADY_CONVERTED_MARK}'
149
+
150
+
151
+ def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
152
+ return any(s.endswith(stop_word) for stop_word in stop_words)
153
+
154
+
155
+ def detect_converted_mark(userinput):
156
+ return bool(userinput.endswith(ALREADY_CONVERTED_MARK))
157
+
158
+
159
+ def detect_language(code):
160
+ first_line = "" if code.startswith("\n") else code.strip().split("\n", 1)[0]
161
+ language = first_line.lower() if first_line else ""
162
+ code_without_language = code[len(first_line) :].lstrip() if first_line else code
163
+ return language, code_without_language
164
+
165
+
166
+ def convert_to_markdown(text):
167
+ text = text.replace("$", "&#36;")
168
+ text = text.replace("\r\n", "\n")
169
+
170
+ def replace_leading_tabs_and_spaces(line):
171
+ new_line = []
172
+
173
+ for char in line:
174
+ if char == "\t":
175
+ new_line.append("&#9;")
176
+ elif char == " ":
177
+ new_line.append("&nbsp;")
178
+ else:
179
+ break
180
+ return "".join(new_line) + line[len(new_line) :]
181
+
182
+ markdown_text = ""
183
+ lines = text.split("\n")
184
+ in_code_block = False
185
+
186
+ for line in lines:
187
+ if in_code_block is False and line.startswith("```"):
188
+ in_code_block = True
189
+ markdown_text += f"{line}\n"
190
+ elif in_code_block is True and line.startswith("```"):
191
+ in_code_block = False
192
+ markdown_text += f"{line}\n"
193
+ elif in_code_block:
194
+ markdown_text += f"{line}\n"
195
+ else:
196
+ line = replace_leading_tabs_and_spaces(line)
197
+ line = re.sub(r"^(#)", r"\\\1", line)
198
+ markdown_text += f"{line} \n"
199
+
200
+ return markdown_text
201
+
202
+
203
+ def add_language_tag(text):
204
+ def detect_language(code_block):
205
+ try:
206
+ lexer = guess_lexer(code_block)
207
+ return lexer.name.lower()
208
+ except ClassNotFound:
209
+ return ""
210
+
211
+ code_block_pattern = re.compile(r"(```)(\w*\n[^`]+```)", re.MULTILINE)
212
+
213
+ def replacement(match):
214
+ code_block = match.group(2)
215
+ if match.group(2).startswith("\n"):
216
+ language = detect_language(code_block)
217
+ return (
218
+ f"```{language}{code_block}```" if language else f"```\n{code_block}```"
219
+ )
220
+ else:
221
+ return match.group(1) + code_block + "```"
222
+
223
+ text2 = code_block_pattern.sub(replacement, text)
224
+ return text2
225
+
226
+
227
+ def is_variable_assigned(var_name: str) -> bool:
228
+ # Check the caller's local scope; a bare locals() here would only see this helper's own names.
+ import inspect
+ return var_name in inspect.currentframe().f_back.f_locals
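
The helpers above turn raw model output into chat-safe Markdown/HTML: strip_stop_words() truncates at the first stop word and convert_to_markdown() escapes leading whitespace and headings outside code blocks. A brief usage sketch with the stop words used by the "deepseek" template:

```python
# Usage sketch for the text post-processing helpers above.
from app_modules.utils import convert_to_markdown, strip_stop_words

raw = "The count is **3**.\nUser:"
clean = strip_stop_words(raw, ["User:", "<|end▁of▁sentence|>"])
print(convert_to_markdown(clean))  # "The count is **3**." with chat-safe escaping applied
```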
assets/Kelpy-Codos.js ADDED
@@ -0,0 +1,100 @@
1
+ /**
2
+ * Copyright (c) 2023-2024 DeepSeek.
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of
5
+ * this software and associated documentation files (the "Software"), to deal in
6
+ * the Software without restriction, including without limitation the rights to
7
+ * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8
+ * the Software, and to permit persons to whom the Software is furnished to do so,
9
+ * subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in all
12
+ * copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16
+ * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
17
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18
+ * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20
+ */
21
+
22
+ // ==UserScript==
23
+ // @name Kelpy Codos
24
+ // @namespace https://github.com/Keldos-Li/Kelpy-Codos
25
+ // @version 1.0.5
26
+ // @author Keldos; https://keldos.me/
27
+ // @description Add copy button to PRE tags before CODE tag, for Chuanhu ChatGPT especially.
28
+ // Based on Chuanhu ChatGPT version: ac04408 (2023-3-22)
29
+ // @license GPL-3.0
30
+ // @grant none
31
+ // ==/UserScript==
32
+
33
+ (function () {
34
+ "use strict";
35
+
36
+ function addCopyButton(pre) {
37
+ var code = pre.querySelector("code");
38
+ if (!code) {
39
+ return; // Do not add the button if no <code> element is found
40
+ }
41
+ var firstChild = code.firstChild;
42
+ if (!firstChild) {
43
+ return; // Do not add the button if the <code> element has no child nodes
44
+ }
45
+ var button = document.createElement("button");
46
+ button.textContent = "\uD83D\uDCCE"; // 使用 📎 符号作为“复制”按钮的文本
47
+ button.style.position = "relative";
48
+ button.style.float = "right";
49
+ button.style.fontSize = "1em"; // Optional: adjust button size
50
+ button.style.background = "none"; // Optional: remove background color
51
+ button.style.border = "none"; // Optional: remove border
52
+ button.style.cursor = "pointer"; // Optional: show a pointer cursor
53
+ button.addEventListener("click", function () {
54
+ var range = document.createRange();
55
+ range.selectNodeContents(code);
56
+ range.setStartBefore(firstChild); // Set the range start to before the first child node
57
+ var selection = window.getSelection();
58
+ selection.removeAllRanges();
59
+ selection.addRange(range);
60
+
61
+ try {
62
+ var success = document.execCommand("copy");
63
+ if (success) {
64
+ button.textContent = "\u2714";
65
+ setTimeout(function () {
66
+ button.textContent = "\uD83D\uDCCE"; // 恢复按钮为“复制”
67
+ }, 2000);
68
+ } else {
69
+ button.textContent = "\u2716";
70
+ }
71
+ } catch (e) {
72
+ console.error(e);
73
+ button.textContent = "\u2716";
74
+ }
75
+
76
+ selection.removeAllRanges();
77
+ });
78
+ code.insertBefore(button, firstChild); // Insert the button before the first child element
79
+ }
80
+
81
+ function handleNewElements(mutationsList, observer) {
82
+ for (var mutation of mutationsList) {
83
+ if (mutation.type === "childList") {
84
+ for (var node of mutation.addedNodes) {
85
+ if (node.nodeName === "PRE") {
86
+ addCopyButton(node);
87
+ }
88
+ }
89
+ }
90
+ }
91
+ }
92
+
93
+ var observer = new MutationObserver(handleNewElements);
94
+ observer.observe(document.documentElement, {
95
+ childList: true,
96
+ subtree: true,
97
+ });
98
+
99
+ document.querySelectorAll("pre").forEach(addCopyButton);
100
+ })();
assets/avatar.png ADDED
assets/custom.css ADDED
@@ -0,0 +1,355 @@
1
+ /**
2
+ * Copyright (c) 2023-2024 DeepSeek.
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of
5
+ * this software and associated documentation files (the "Software"), to deal in
6
+ * the Software without restriction, including without limitation the rights to
7
+ * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8
+ * the Software, and to permit persons to whom the Software is furnished to do so,
9
+ * subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in all
12
+ * copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16
+ * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
17
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18
+ * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20
+ */
21
+
22
+ :root {
23
+ --chatbot-color-light: #f3f3f3;
24
+ --chatbot-color-dark: #121111;
25
+ }
26
+
27
+ /* status_display */
28
+ #status_display {
29
+ display: flex;
30
+ min-height: 2.5em;
31
+ align-items: flex-end;
32
+ justify-content: flex-end;
33
+ }
34
+ #status_display p {
35
+ font-size: 0.85em;
36
+ font-family: monospace;
37
+ color: var(--body-text-color-subdued);
38
+ }
39
+
40
+ /* usage_display */
41
+ #usage_display {
42
+ height: 1em;
43
+ }
44
+ #usage_display p {
45
+ padding: 0 1em;
46
+ font-size: 0.85em;
47
+ font-family: monospace;
48
+ color: var(--body-text-color-subdued);
49
+ }
50
+ /* list */
51
+ ol:not(.options),
52
+ ul:not(.options) {
53
+ padding-inline-start: 2em !important;
54
+ }
55
+
56
+ /* Thank @Keldos-Li for fixing it */
57
+ /* Light mode (default) */
58
+ #deepseek_chatbot {
59
+ background-color: var(--chatbot-color-light) !important;
60
+ color: #000000 !important;
61
+ }
62
+ [data-testid="bot"] {
63
+ background-color: #ffffff !important;
64
+ }
65
+ [data-testid="user"] {
66
+ background-color: #95ec69 !important;
67
+ }
68
+
69
+ /* Dark mode */
70
+ .dark #deepseek_chatbot {
71
+ background-color: var(--chatbot-color-dark) !important;
72
+ color: #ffffff !important;
73
+ }
74
+ .dark [data-testid="bot"] {
75
+ background-color: #2c2c2c !important;
76
+ }
77
+ .dark [data-testid="user"] {
78
+ background-color: #26b561 !important;
79
+ }
80
+
81
+ #deepseek_chatbot {
82
+ height: 100%;
83
+ min-height: 800px;
84
+ flex-grow: 1;
85
+ overflow: auto;
86
+ }
87
+
88
+ [class*="message"] {
89
+ border-radius: var(--radius-xl) !important;
90
+ border: none;
91
+ padding: var(--spacing-xl) !important;
92
+ font-size: var(--text-md) !important;
93
+ line-height: var(--line-md) !important;
94
+ min-height: calc(var(--text-md) * var(--line-md) + 2 * var(--spacing-xl));
95
+ min-width: calc(var(--text-md) * var(--line-md) + 2 * var(--spacing-xl));
96
+ }
97
+ [data-testid="bot"] {
98
+ max-width: 85%;
99
+ border-bottom-left-radius: 0 !important;
100
+ }
101
+ [data-testid="user"] {
102
+ max-width: 85%;
103
+ width: auto !important;
104
+ border-bottom-right-radius: 0 !important;
105
+ }
106
+ /* Table */
107
+ table {
108
+ margin: 1em 0;
109
+ border-collapse: collapse;
110
+ empty-cells: show;
111
+ }
112
+ td,
113
+ th {
114
+ border: 1.2px solid var(--border-color-primary) !important;
115
+ padding: 0.2em;
116
+ }
117
+ thead {
118
+ background-color: rgba(175, 184, 193, 0.2);
119
+ }
120
+ thead th {
121
+ padding: 0.5em 0.2em;
122
+ }
123
+ /* Inline code */
124
+ #deepseek_chatbot code {
125
+ display: inline;
126
+ white-space: break-spaces;
127
+ border-radius: 6px;
128
+ margin: 0 2px 0 2px;
129
+ padding: 0.2em 0.4em 0.1em 0.4em;
130
+ background-color: rgba(175, 184, 193, 0.2);
131
+ }
132
+ /* Code block */
133
+ #deepseek_chatbot pre code {
134
+ display: block;
135
+ overflow: auto;
136
+ white-space: pre;
137
+ background-color: #1c1d1e !important;
138
+ border-radius: 10px;
139
+ padding: 1.4em 1.2em 0em 1.4em;
140
+ margin: 1.2em 2em 1.2em 0.5em;
141
+ color: #fdf8f8;
142
+ box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
143
+ }
144
+ /* Highlight */
145
+ #deepseek_chatbot .highlight {
146
+ background-color: transparent;
147
+ }
148
+ #deepseek_chatbot .highlight .hll {
149
+ background-color: #49483e;
150
+ }
151
+ #deepseek_chatbot .highlight .c {
152
+ color: #75715e;
153
+ } /* Comment */
154
+ #deepseek_chatbot .highlight .err {
155
+ color: #960050;
156
+ background-color: #1e0010;
157
+ } /* Error */
158
+ #deepseek_chatbot .highlight .k {
159
+ color: #66d9ef;
160
+ } /* Keyword */
161
+ #deepseek_chatbot .highlight .l {
162
+ color: #ae81ff;
163
+ } /* Literal */
164
+ #deepseek_chatbot .highlight .n {
165
+ color: #f8f8f2;
166
+ } /* Name */
167
+ #deepseek_chatbot .highlight .o {
168
+ color: #f92672;
169
+ } /* Operator */
170
+ #deepseek_chatbot .highlight .p {
171
+ color: #f8f8f2;
172
+ } /* Punctuation */
173
+ #deepseek_chatbot .highlight .ch {
174
+ color: #75715e;
175
+ } /* Comment.Hashbang */
176
+ #deepseek_chatbot .highlight .cm {
177
+ color: #75715e;
178
+ } /* Comment.Multiline */
179
+ #deepseek_chatbot .highlight .cp {
180
+ color: #75715e;
181
+ } /* Comment.Preproc */
182
+ #deepseek_chatbot .highlight .cpf {
183
+ color: #75715e;
184
+ } /* Comment.PreprocFile */
185
+ #deepseek_chatbot .highlight .c1 {
186
+ color: #75715e;
187
+ } /* Comment.Single */
188
+ #deepseek_chatbot .highlight .cs {
189
+ color: #75715e;
190
+ } /* Comment.Special */
191
+ #deepseek_chatbot .highlight .gd {
192
+ color: #f92672;
193
+ } /* Generic.Deleted */
194
+ #deepseek_chatbot .highlight .ge {
195
+ font-style: italic;
196
+ } /* Generic.Emph */
197
+ #deepseek_chatbot .highlight .gi {
198
+ color: #a6e22e;
199
+ } /* Generic.Inserted */
200
+ #deepseek_chatbot .highlight .gs {
201
+ font-weight: bold;
202
+ } /* Generic.Strong */
203
+ #deepseek_chatbot .highlight .gu {
204
+ color: #75715e;
205
+ } /* Generic.Subheading */
206
+ #deepseek_chatbot .highlight .kc {
207
+ color: #66d9ef;
208
+ } /* Keyword.Constant */
209
+ #deepseek_chatbot .highlight .kd {
210
+ color: #66d9ef;
211
+ } /* Keyword.Declaration */
212
+ #deepseek_chatbot .highlight .kn {
213
+ color: #f92672;
214
+ } /* Keyword.Namespace */
215
+ #deepseek_chatbot .highlight .kp {
216
+ color: #66d9ef;
217
+ } /* Keyword.Pseudo */
218
+ #deepseek_chatbot .highlight .kr {
219
+ color: #66d9ef;
220
+ } /* Keyword.Reserved */
221
+ #deepseek_chatbot .highlight .kt {
222
+ color: #66d9ef;
223
+ } /* Keyword.Type */
224
+ #deepseek_chatbot .highlight .ld {
225
+ color: #e6db74;
226
+ } /* Literal.Date */
227
+ #deepseek_chatbot .highlight .m {
228
+ color: #ae81ff;
229
+ } /* Literal.Number */
230
+ #deepseek_chatbot .highlight .s {
231
+ color: #e6db74;
232
+ } /* Literal.String */
233
+ #deepseek_chatbot .highlight .na {
234
+ color: #a6e22e;
235
+ } /* Name.Attribute */
236
+ #deepseek_chatbot .highlight .nb {
237
+ color: #f8f8f2;
238
+ } /* Name.Builtin */
239
+ #deepseek_chatbot .highlight .nc {
240
+ color: #a6e22e;
241
+ } /* Name.Class */
242
+ #deepseek_chatbot .highlight .no {
243
+ color: #66d9ef;
244
+ } /* Name.Constant */
245
+ #deepseek_chatbot .highlight .nd {
246
+ color: #a6e22e;
247
+ } /* Name.Decorator */
248
+ #deepseek_chatbot .highlight .ni {
249
+ color: #f8f8f2;
250
+ } /* Name.Entity */
251
+ #deepseek_chatbot .highlight .ne {
252
+ color: #a6e22e;
253
+ } /* Name.Exception */
254
+ #deepseek_chatbot .highlight .nf {
255
+ color: #a6e22e;
256
+ } /* Name.Function */
257
+ #deepseek_chatbot .highlight .nl {
258
+ color: #f8f8f2;
259
+ } /* Name.Label */
260
+ #deepseek_chatbot .highlight .nn {
261
+ color: #f8f8f2;
262
+ } /* Name.Namespace */
263
+ #deepseek_chatbot .highlight .nx {
264
+ color: #a6e22e;
265
+ } /* Name.Other */
266
+ #deepseek_chatbot .highlight .py {
267
+ color: #f8f8f2;
268
+ } /* Name.Property */
269
+ #deepseek_chatbot .highlight .nt {
270
+ color: #f92672;
271
+ } /* Name.Tag */
272
+ #deepseek_chatbot .highlight .nv {
273
+ color: #f8f8f2;
274
+ } /* Name.Variable */
275
+ #deepseek_chatbot .highlight .ow {
276
+ color: #f92672;
277
+ } /* Operator.Word */
278
+ #deepseek_chatbot .highlight .w {
279
+ color: #f8f8f2;
280
+ } /* Text.Whitespace */
281
+ #deepseek_chatbot .highlight .mb {
282
+ color: #ae81ff;
283
+ } /* Literal.Number.Bin */
284
+ #deepseek_chatbot .highlight .mf {
285
+ color: #ae81ff;
286
+ } /* Literal.Number.Float */
287
+ #deepseek_chatbot .highlight .mh {
288
+ color: #ae81ff;
289
+ } /* Literal.Number.Hex */
290
+ #deepseek_chatbot .highlight .mi {
291
+ color: #ae81ff;
292
+ } /* Literal.Number.Integer */
293
+ #deepseek_chatbot .highlight .mo {
294
+ color: #ae81ff;
295
+ } /* Literal.Number.Oct */
296
+ #deepseek_chatbot .highlight .sa {
297
+ color: #e6db74;
298
+ } /* Literal.String.Affix */
299
+ #deepseek_chatbot .highlight .sb {
300
+ color: #e6db74;
301
+ } /* Literal.String.Backtick */
302
+ #deepseek_chatbot .highlight .sc {
303
+ color: #e6db74;
304
+ } /* Literal.String.Char */
305
+ #deepseek_chatbot .highlight .dl {
306
+ color: #e6db74;
307
+ } /* Literal.String.Delimiter */
308
+ #deepseek_chatbot .highlight .sd {
309
+ color: #e6db74;
310
+ } /* Literal.String.Doc */
311
+ #deepseek_chatbot .highlight .s2 {
312
+ color: #e6db74;
313
+ } /* Literal.String.Double */
314
+ #deepseek_chatbot .highlight .se {
315
+ color: #ae81ff;
316
+ } /* Literal.String.Escape */
317
+ #deepseek_chatbot .highlight .sh {
318
+ color: #e6db74;
319
+ } /* Literal.String.Heredoc */
320
+ #deepseek_chatbot .highlight .si {
321
+ color: #e6db74;
322
+ } /* Literal.String.Interpol */
323
+ #deepseek_chatbot .highlight .sx {
324
+ color: #e6db74;
325
+ } /* Literal.String.Other */
326
+ #deepseek_chatbot .highlight .sr {
327
+ color: #e6db74;
328
+ } /* Literal.String.Regex */
329
+ #deepseek_chatbot .highlight .s1 {
330
+ color: #e6db74;
331
+ } /* Literal.String.Single */
332
+ #deepseek_chatbot .highlight .ss {
333
+ color: #e6db74;
334
+ } /* Literal.String.Symbol */
335
+ #deepseek_chatbot .highlight .bp {
336
+ color: #f8f8f2;
337
+ } /* Name.Builtin.Pseudo */
338
+ #deepseek_chatbot .highlight .fm {
339
+ color: #a6e22e;
340
+ } /* Name.Function.Magic */
341
+ #deepseek_chatbot .highlight .vc {
342
+ color: #f8f8f2;
343
+ } /* Name.Variable.Class */
344
+ #deepseek_chatbot .highlight .vg {
345
+ color: #f8f8f2;
346
+ } /* Name.Variable.Global */
347
+ #deepseek_chatbot .highlight .vi {
348
+ color: #f8f8f2;
349
+ } /* Name.Variable.Instance */
350
+ #deepseek_chatbot .highlight .vm {
351
+ color: #f8f8f2;
352
+ } /* Name.Variable.Magic */
353
+ #deepseek_chatbot .highlight .il {
354
+ color: #ae81ff;
355
+ } /* Literal.Number.Integer.Long */
assets/custom.js ADDED
@@ -0,0 +1,22 @@
1
+ /**
2
+ * Copyright (c) 2023-2024 DeepSeek.
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of
5
+ * this software and associated documentation files (the "Software"), to deal in
6
+ * the Software without restriction, including without limitation the rights to
7
+ * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8
+ * the Software, and to permit persons to whom the Software is furnished to do so,
9
+ * subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in all
12
+ * copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16
+ * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
17
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18
+ * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20
+ */
21
+
22
+ // custom javascript here
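The stylesheet above keys off element ids such as #deepseek_chatbot and #status_display, so the Gradio components presumably get matching elem_id values. A minimal sketch of how the two asset files could be wired into a Gradio Blocks app follows; the layout and component choices are assumptions for illustration, not the app.py shipped in this commit.

# Hypothetical wiring of assets/custom.css into a Gradio UI. The elem_id values
# are chosen to match the CSS selectors above; everything else is illustrative.
import gradio as gr

with open("assets/custom.css", "r", encoding="utf-8") as f:
    custom_css = f.read()

with gr.Blocks(css=custom_css) as demo:
    chatbot = gr.Chatbot(elem_id="deepseek_chatbot")  # styled by the #deepseek_chatbot rules
    status = gr.Markdown(elem_id="status_display")    # styled by the #status_display rules

if __name__ == "__main__":
    demo.launch(favicon_path="assets/favicon.ico")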
assets/favicon.ico ADDED
deepseek_vl/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+
21
+ # check if python version is 3.10 or above
22
+ import sys
23
+
24
+ if sys.version_info >= (3, 10):
25
+ print("Python version is 3.10 or above, patching the collections module.")
26
+ # Monkey patch collections
27
+ import collections
28
+ import collections.abc
29
+
30
+ for type_name in collections.abc.__all__:
31
+ setattr(collections, type_name, getattr(collections.abc, type_name))
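The loop above copies every name exported by collections.abc back onto the top-level collections module, because those aliases were removed in Python 3.10 while dependencies imported later in this commit (notably attrdict) still look them up there. A minimal sanity check, assuming Python 3.10 or newer:

# Sanity check for the patch above (assumes Python >= 3.10, where the old
# top-level aliases such as collections.Mapping no longer exist).
import collections
import collections.abc

import deepseek_vl  # importing the package applies the monkey patch

assert collections.Mapping is collections.abc.Mapping  # alias restored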
deepseek_vl/models/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from .image_processing_vlm import VLMImageProcessor
21
+ from .modeling_vlm import MultiModalityCausalLM
22
+ from .processing_vlm import VLChatProcessor
23
+
24
+ __all__ = [
25
+ "VLMImageProcessor",
26
+ "VLChatProcessor",
27
+ "MultiModalityCausalLM",
28
+ ]
deepseek_vl/models/clip_encoder.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import Dict, List, Literal, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torchvision.transforms
25
+ from einops import rearrange
26
+
27
+ from deepseek_vl.models.sam import create_sam_vit
28
+ from deepseek_vl.models.siglip_vit import create_siglip_vit
29
+
30
+
31
+ class CLIPVisionTower(nn.Module):
32
+ def __init__(
33
+ self,
34
+ model_name: str = "siglip_large_patch16_384",
35
+ image_size: Union[Tuple[int, int], int] = 336,
36
+ select_feature: str = "patch",
37
+ select_layer: int = -2,
38
+ select_layers: list = None,
39
+ ckpt_path: str = "",
40
+ pixel_mean: Optional[List[float]] = None,
41
+ pixel_std: Optional[List[float]] = None,
42
+ **kwargs,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.model_name = model_name
47
+ self.select_feature = select_feature
48
+ self.select_layer = select_layer
49
+ self.select_layers = select_layers
50
+
51
+ vision_tower_params = {
52
+ "model_name": model_name,
53
+ "image_size": image_size,
54
+ "ckpt_path": ckpt_path,
55
+ "select_layer": select_layer,
56
+ }
57
+ vision_tower_params.update(kwargs)
58
+ self.vision_tower, self.forward_kwargs = self.build_vision_tower(
59
+ vision_tower_params
60
+ )
61
+
62
+ if pixel_mean is not None and pixel_std is not None:
63
+ image_norm = torchvision.transforms.Normalize(
64
+ mean=pixel_mean, std=pixel_std
65
+ )
66
+ else:
67
+ image_norm = None
68
+
69
+ self.image_norm = image_norm
70
+
71
+ def build_vision_tower(self, vision_tower_params):
72
+ if self.model_name.startswith("siglip"):
73
+ self.select_feature = "same"
74
+ vision_tower = create_siglip_vit(**vision_tower_params)
75
+ forward_kwargs = dict()
76
+
77
+ elif self.model_name.startswith("sam"):
78
+ vision_tower = create_sam_vit(**vision_tower_params)
79
+ forward_kwargs = dict()
80
+
81
+ else: # huggingface
82
+ from transformers import CLIPVisionModel
83
+
84
+ vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
85
+ forward_kwargs = dict(output_hidden_states=True)
86
+
87
+ return vision_tower, forward_kwargs
88
+
89
+ def feature_select(self, image_forward_outs):
90
+ if isinstance(image_forward_outs, torch.Tensor):
91
+ # the output is already the features of self.select_layer
92
+ image_features = image_forward_outs
93
+ else:
94
+ image_features = image_forward_outs.hidden_states[self.select_layer]
95
+
96
+ if self.select_feature == "patch":
97
+ # if the output has cls_token
98
+ image_features = image_features[:, 1:]
99
+ elif self.select_feature == "cls_patch":
100
+ image_features = image_features
101
+ elif self.select_feature == "same":
102
+ image_features = image_features
103
+
104
+ else:
105
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
106
+ return image_features
107
+
108
+ def forward(self, images):
109
+ """
110
+
111
+ Args:
112
+ images (torch.Tensor): [b, 3, H, W]
113
+
114
+ Returns:
115
+ image_features (torch.Tensor): [b, n_patch, d]
116
+ """
117
+
118
+ if self.image_norm is not None:
119
+ images = self.image_norm(images)
120
+
121
+ image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
122
+ image_features = self.feature_select(image_forward_outs)
123
+ return image_features
124
+
125
+
126
+ class HybridVisionTower(nn.Module):
127
+ def __init__(
128
+ self,
129
+ high_res_cfg: Dict,
130
+ low_res_cfg: Dict,
131
+ freeze_high: bool = False,
132
+ freeze_low: bool = False,
133
+ concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple",
134
+ **ignore_kwargs,
135
+ ):
136
+ super().__init__()
137
+
138
+ self.vision_tower_high = CLIPVisionTower(**high_res_cfg)
139
+ self.vision_tower_low = CLIPVisionTower(**low_res_cfg)
140
+ self.low_res_size = low_res_cfg["image_size"]
141
+ self.concat_type = concat_type
142
+
143
+ self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024))
144
+ self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024))
145
+
146
+ if freeze_high:
147
+ for p_name, p in self.vision_tower_high.named_parameters():
148
+ p.requires_grad = False
149
+ self.vision_tower_high = self.vision_tower_high.eval()
150
+ else:
151
+ # train downsamples and neck
152
+ for p_name, p in self.vision_tower_high.named_parameters():
153
+ if "downsamples" in p_name or "neck" in p_name:
154
+ p.requires_grad = True
155
+ else:
156
+ p.requires_grad = False
157
+
158
+ if freeze_low:
159
+ for p in self.vision_tower_low.parameters():
160
+ p.requires_grad = False
161
+ self.vision_tower_low = self.vision_tower_low.eval()
162
+
163
+ self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True)
164
+
165
+ def forward(self, images: torch.Tensor):
166
+ """
167
+
168
+ Args:
169
+ images (torch.Tensor): [bs, 3, H, W]
170
+
171
+ Returns:
172
+ res (torch.Tensor): [bs, t, c]
173
+ """
174
+
175
+ # [bs, c, h, w]
176
+ high_images = images
177
+
178
+ # [bs, c, h_low, w_low]
179
+ low_images = self.resize(images)
180
+
181
+ # separately run two vision towers
182
+ # run high_res vision tower
183
+ high_res = self.vision_tower_high(high_images)
184
+ # [bs, c, h, w] -> [bs, h*w, c]
185
+ high_res = rearrange(high_res, "b c h w -> b (h w) c")
186
+ # run low_res vision tower
187
+ low_res = self.vision_tower_low(low_images)
188
+
189
+ if self.concat_type == "feature":
190
+ images_features = torch.cat([high_res, low_res], dim=-1)
191
+ elif self.concat_type == "sequence":
192
+ images_features = torch.cat([high_res, low_res], dim=1)
193
+ elif self.concat_type == "add":
194
+ images_features = high_res + low_res
195
+ elif self.concat_type == "tuple":
196
+ images_features = (high_res, low_res)
197
+
198
+ else:
199
+ raise ValueError(
200
+ "Currently only support `feature`, `sequence`, `add` and `tuple` concat type."
201
+ )
202
+
203
+ return images_features
204
+
205
+
206
+ if __name__ == "__main__":
207
+ image_size = 1024
208
+ x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda()
209
+
210
+ high_res_cfg = dict(
211
+ model_name="sam_b_downsample",
212
+ select_feature="same",
213
+ image_size=image_size,
214
+ pixel_mean=(0.48145466, 0.4578275, 0.40821073),
215
+ pixel_std=(0.26862954, 0.26130258, 0.27577711),
216
+ select_layer=-1,
217
+ ckpt_path="",
218
+ )
219
+
220
+ low_res_cfg = dict(
221
+ model_name="siglip_large_patch16_384",
222
+ select_feature="same",
223
+ image_size=384,
224
+ pixel_mean=(0.5, 0.5, 0.5),
225
+ pixel_std=(0.5, 0.5, 0.5),
226
+ select_layer=-1,
227
+ ckpt_path="",
228
+ )
229
+
230
+ net = (
231
+ HybridVisionTower(
232
+ high_res_cfg=high_res_cfg,
233
+ low_res_cfg=low_res_cfg,
234
+ freeze_high=True,
235
+ freeze_low=True,
236
+ concat_type="tuple",
237
+ )
238
+ .bfloat16()
239
+ .cuda()
240
+ )
241
+ high_x, low_x = net(x)
242
+ print(x.shape, high_x.shape, low_x.shape)
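CLIPVisionTower can also be used on its own for the low-resolution branch; a CPU-only shape check is sketched below, mirroring the low_res_cfg of the __main__ block above. It assumes create_siglip_vit accepts an empty ckpt_path (as the demo above does), so the weights are randomly initialized and only the output shape is meaningful.

# Standalone shape check for CLIPVisionTower with the SigLIP backbone
# (random weights; mirrors low_res_cfg from the __main__ block above).
import torch

from deepseek_vl.models.clip_encoder import CLIPVisionTower

tower = CLIPVisionTower(
    model_name="siglip_large_patch16_384",
    image_size=384,
    select_feature="same",
    select_layer=-1,
    ckpt_path="",
    pixel_mean=(0.5, 0.5, 0.5),
    pixel_std=(0.5, 0.5, 0.5),
)

with torch.no_grad():
    feats = tower(torch.zeros(1, 3, 384, 384))
print(feats.shape)  # expected [1, n_patch, d], i.e. [1, 576, 1024] for this backbone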
deepseek_vl/models/image_processing_vlm.py ADDED
@@ -0,0 +1,208 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import List, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torchvision
25
+ import torchvision.transforms.functional
26
+ from PIL import Image
27
+ from transformers import AutoImageProcessor, PretrainedConfig
28
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
29
+ from transformers.image_utils import to_numpy_array
30
+ from transformers.utils import logging
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
35
+ IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
36
+ IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
37
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
38
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
39
+
40
+
41
+ def expand2square(pil_img, background_color):
42
+ width, height = pil_img.size
43
+ if width == height:
44
+ return pil_img
45
+ elif width > height:
46
+ result = Image.new(pil_img.mode, (width, width), background_color)
47
+ result.paste(pil_img, (0, (width - height) // 2))
48
+ return result
49
+ else:
50
+ result = Image.new(pil_img.mode, (height, height), background_color)
51
+ result.paste(pil_img, ((height - width) // 2, 0))
52
+ return result
53
+
54
+
55
+ class VLMImageProcessorConfig(PretrainedConfig):
56
+ model_type = "deepseek_vlm"
57
+ image_size: int
58
+ min_size: int
59
+ image_mean: Union[Tuple[float, float, float], List[float]]
60
+ image_std: Union[Tuple[float, float, float], List[float]]
61
+ rescale_factor: float
62
+ do_normalize: bool
63
+
64
+ def __init__(
65
+ self,
66
+ image_size: int,
67
+ min_size: int = 14,
68
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
69
+ 0.48145466,
70
+ 0.4578275,
71
+ 0.40821073,
72
+ ),
73
+ image_std: Union[Tuple[float, float, float], List[float]] = (
74
+ 0.26862954,
75
+ 0.26130258,
76
+ 0.27577711,
77
+ ),
78
+ rescale_factor: float = 1.0 / 255.0,
79
+ do_normalize: bool = True,
80
+ **kwargs,
81
+ ):
82
+ self.image_size = image_size
83
+ self.min_size = min_size
84
+ self.image_mean = image_mean
85
+ self.image_std = image_std
86
+ self.rescale_factor = rescale_factor
87
+ self.do_normalize = do_normalize
88
+
89
+ super().__init__(**kwargs)
90
+
91
+
92
+ class VLMImageProcessor(BaseImageProcessor):
93
+ model_input_names = ["pixel_values"]
94
+
95
+ def __init__(
96
+ self,
97
+ image_size: int,
98
+ min_size: int = 14,
99
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
100
+ 0.48145466,
101
+ 0.4578275,
102
+ 0.40821073,
103
+ ),
104
+ image_std: Union[Tuple[float, float, float], List[float]] = (
105
+ 0.26862954,
106
+ 0.26130258,
107
+ 0.27577711,
108
+ ),
109
+ rescale_factor: float = 1.0 / 255.0,
110
+ do_normalize: bool = True,
111
+ **kwargs,
112
+ ):
113
+ super().__init__(**kwargs)
114
+
115
+ self.image_size = image_size
116
+ self.rescale_factor = rescale_factor
117
+ self.image_mean = image_mean
118
+ self.image_std = image_std
119
+ self.min_size = min_size
120
+ self.do_normalize = do_normalize
121
+
122
+ if image_mean is None:
123
+ self.background_color = (127, 127, 127)
124
+ else:
125
+ self.background_color = tuple([int(x * 255) for x in image_mean])
126
+
127
+ def resize(self, pil_img: Image) -> np.ndarray:
128
+ """
129
+
130
+ Args:
131
+ pil_img (PIL.Image): an RGB PIL.Image of shape [H, W, 3]
132
+
133
+ Returns:
134
+ x (np.ndarray): [3, self.image_size, self.image_size]
135
+ """
136
+
137
+ width, height = pil_img.size
138
+ max_size = max(width, height)
139
+
140
+ size = [
141
+ max(int(height / max_size * self.image_size), self.min_size),
142
+ max(int(width / max_size * self.image_size), self.min_size),
143
+ ]
144
+
145
+ if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
146
+ print(f"orig size = {pil_img.size}, new size = {size}")
147
+ raise ValueError("Invalid size!")
148
+
149
+ pil_img = torchvision.transforms.functional.resize(
150
+ pil_img,
151
+ size,
152
+ interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
153
+ antialias=True,
154
+ )
155
+
156
+ pil_img = expand2square(pil_img, self.background_color)
157
+ x = to_numpy_array(pil_img)
158
+
159
+ # [H, W, 3] -> [3, H, W]
160
+ x = np.transpose(x, (2, 0, 1))
161
+
162
+ return x
163
+
164
+ def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
165
+ # resize and pad to [self.image_size, self.image_size]
166
+ # then convert from [H, W, 3] to [3, H, W]
167
+ images: List[np.ndarray] = [self.resize(image) for image in images]
168
+
169
+ # rescale from [0, 255] -> [0, 1]
170
+ images = [
171
+ self.rescale(
172
+ image=image,
173
+ scale=self.rescale_factor,
174
+ input_data_format="channels_first",
175
+ )
176
+ for image in images
177
+ ]
178
+
179
+ # normalize
180
+ if self.do_normalize:
181
+ images = [
182
+ self.normalize(
183
+ image=image,
184
+ mean=self.image_mean,
185
+ std=self.image_std,
186
+ input_data_format="channels_first",
187
+ )
188
+ for image in images
189
+ ]
190
+
191
+ data = {"pixel_values": images}
192
+ return BatchFeature(data=data, tensor_type=return_tensors)
193
+
194
+ @property
195
+ def default_shape(self):
196
+ return [3, self.image_size, self.image_size]
197
+
198
+
199
+ AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
200
+
201
+
202
+ if __name__ == "__main__":
203
+ image_processor = VLMImageProcessor(
204
+ image_size=1024,
205
+ image_mean=IMAGENET_INCEPTION_MEAN,
206
+ image_std=IMAGENET_INCEPTION_STD,
207
+ do_normalize=True,
208
+ )
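A short usage sketch of VLMImageProcessor: a non-square PIL image is resized so that its long side becomes image_size, padded to a square with the background colour derived from image_mean, rescaled to [0, 1] and normalized. The dummy image below is an assumption for illustration.

# Usage sketch: pad-to-square preprocessing of a single dummy image.
import numpy as np
from PIL import Image

from deepseek_vl.models.image_processing_vlm import (
    IMAGENET_INCEPTION_MEAN,
    IMAGENET_INCEPTION_STD,
    VLMImageProcessor,
)

processor = VLMImageProcessor(
    image_size=1024,
    image_mean=IMAGENET_INCEPTION_MEAN,
    image_std=IMAGENET_INCEPTION_STD,
    do_normalize=True,
)

img = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))  # stand-in for Image.open(...)
batch = processor([img], return_tensors="pt")
print(batch.pixel_values.shape)  # expected [1, 3, 1024, 1024]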
deepseek_vl/models/modeling_vlm.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import torch
21
+ from attrdict import AttrDict
22
+ from einops import rearrange
23
+ from transformers import (
24
+ AutoConfig,
25
+ AutoModelForCausalLM,
26
+ LlamaConfig,
27
+ LlamaForCausalLM,
28
+ PreTrainedModel,
29
+ )
30
+ from transformers.configuration_utils import PretrainedConfig
31
+
32
+ from deepseek_vl.models.clip_encoder import CLIPVisionTower, HybridVisionTower
33
+ from deepseek_vl.models.projector import MlpProjector
34
+
35
+
36
+ def model_name_to_cls(cls_name):
37
+ if "MlpProjector" in cls_name:
38
+ cls = MlpProjector
39
+
40
+ elif "CLIPVisionTower" in cls_name:
41
+ cls = CLIPVisionTower
42
+
43
+ elif "HybridVisionTower" in cls_name:
44
+ cls = HybridVisionTower
45
+
46
+ else:
47
+ raise ValueError(f"class_name {cls_name} is invalid.")
48
+
49
+ return cls
50
+
51
+
52
+ class VisionConfig(PretrainedConfig):
53
+ model_type = "vision"
54
+ cls: str = ""
55
+ params: AttrDict = {}
56
+
57
+ def __init__(self, **kwargs):
58
+ super().__init__(**kwargs)
59
+
60
+ self.cls = kwargs.get("cls", "")
61
+ if not isinstance(self.cls, str):
62
+ self.cls = self.cls.__name__
63
+
64
+ self.params = AttrDict(kwargs.get("params", {}))
65
+
66
+
67
+ class AlignerConfig(PretrainedConfig):
68
+ model_type = "aligner"
69
+ cls: str = ""
70
+ params: AttrDict = {}
71
+
72
+ def __init__(self, **kwargs):
73
+ super().__init__(**kwargs)
74
+
75
+ self.cls = kwargs.get("cls", "")
76
+ if not isinstance(self.cls, str):
77
+ self.cls = self.cls.__name__
78
+
79
+ self.params = AttrDict(kwargs.get("params", {}))
80
+
81
+
82
+ class MultiModalityConfig(PretrainedConfig):
83
+ model_type = "multi_modality"
84
+ vision_config: VisionConfig
85
+ aligner_config: AlignerConfig
86
+ language_config: LlamaConfig
87
+
88
+ def __init__(self, **kwargs):
89
+ super().__init__(**kwargs)
90
+ vision_config = kwargs.get("vision_config", {})
91
+ self.vision_config = VisionConfig(**vision_config)
92
+
93
+ aligner_config = kwargs.get("aligner_config", {})
94
+ self.aligner_config = AlignerConfig(**aligner_config)
95
+
96
+ language_config = kwargs.get("language_config", {})
97
+ if isinstance(language_config, LlamaConfig):
98
+ self.language_config = language_config
99
+ else:
100
+ self.language_config = LlamaConfig(**language_config)
101
+
102
+
103
+ class MultiModalityPreTrainedModel(PreTrainedModel):
104
+ config_class = MultiModalityConfig
105
+ base_model_prefix = "multi_modality"
106
+ _no_split_modules = []
107
+ _skip_keys_device_placement = "past_key_values"
108
+
109
+
110
+ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
111
+ def __init__(self, config: MultiModalityConfig):
112
+ super().__init__(config)
113
+
114
+ vision_config = config.vision_config
115
+ vision_cls = model_name_to_cls(vision_config.cls)
116
+ self.vision_model = vision_cls(**vision_config.params)
117
+
118
+ aligner_config = config.aligner_config
119
+ aligner_cls = model_name_to_cls(aligner_config.cls)
120
+ self.aligner = aligner_cls(aligner_config.params)
121
+
122
+ language_config = config.language_config
123
+ self.language_model = LlamaForCausalLM(language_config)
124
+
125
+ def prepare_inputs_embeds(
126
+ self,
127
+ input_ids: torch.LongTensor,
128
+ pixel_values: torch.FloatTensor,
129
+ images_seq_mask: torch.LongTensor,
130
+ images_emb_mask: torch.LongTensor,
131
+ **kwargs,
132
+ ):
133
+ """
134
+
135
+ Args:
136
+ input_ids (torch.LongTensor): [b, T]
137
+ pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
138
+ images_seq_mask (torch.BoolTensor): [b, T]
139
+ images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
140
+
141
+ assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
142
+
143
+ Returns:
144
+ input_embeds (torch.Tensor): [b, T, D]
145
+ """
146
+
147
+ bs, n = pixel_values.shape[0:2]
148
+ images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
149
+ # [b x n, T2, D]
150
+ images_embeds = self.aligner(self.vision_model(images))
151
+
152
+ # [b x n, T2, D] -> [b, n x T2, D]
153
+ images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
154
+ # [b, n, T2] -> [b, n x T2]
155
+ images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
156
+
157
+ # [b, T, D]
158
+ input_ids[input_ids < 0] = 0 # ignore the image embeddings
159
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
160
+
161
+ # replace with the image embeddings
162
+ inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
163
+
164
+ return inputs_embeds
165
+
166
+
167
+ AutoConfig.register("vision", VisionConfig)
168
+ AutoConfig.register("aligner", AlignerConfig)
169
+ AutoConfig.register("multi_modality", MultiModalityConfig)
170
+ AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
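Because the config and model classes are registered with the transformers auto classes above, a checkpoint that bundles this code can be loaded through AutoModelForCausalLM. A hedged loading sketch; the repository id is an assumption, not something defined in this file.

# Loading sketch. "deepseek-ai/deepseek-vl-7b-chat" is an assumed checkpoint id;
# whether trust_remote_code is required depends on how the checkpoint is packaged.
import torch
from transformers import AutoModelForCausalLM

from deepseek_vl.models import MultiModalityCausalLM

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-vl-7b-chat",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# Generation then goes through the wrapped language model, fed with the output
# of prepare_inputs_embeds (see VLChatProcessor below), e.g.:
#   inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
#   outputs = vl_gpt.language_model.generate(inputs_embeds=inputs_embeds, ...)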
deepseek_vl/models/processing_vlm.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Dict, List
22
+
23
+ import torch
24
+ from PIL.Image import Image
25
+ from transformers import LlamaTokenizerFast
26
+ from transformers.processing_utils import ProcessorMixin
27
+
28
+ from deepseek_vl.models.image_processing_vlm import VLMImageProcessor
29
+ from deepseek_vl.utils.conversation import get_conv_template
30
+
31
+
32
+ class DictOutput(object):
33
+ def keys(self):
34
+ return self.__dict__.keys()
35
+
36
+ def __getitem__(self, item):
37
+ return self.__dict__[item]
38
+
39
+ def __setitem__(self, key, value):
40
+ self.__dict__[key] = value
41
+
42
+
43
+ @dataclass
44
+ class VLChatProcessorOutput(DictOutput):
45
+ sft_format: str
46
+ input_ids: torch.Tensor
47
+ pixel_values: torch.Tensor
48
+ num_image_tokens: torch.IntTensor
49
+
50
+ def __len__(self):
51
+ return len(self.input_ids)
52
+
53
+
54
+ @dataclass
55
+ class BatchedVLChatProcessorOutput(DictOutput):
56
+ sft_format: List[str]
57
+ input_ids: torch.Tensor
58
+ pixel_values: torch.Tensor
59
+ attention_mask: torch.Tensor
60
+ images_seq_mask: torch.BoolTensor
61
+ images_emb_mask: torch.BoolTensor
62
+
63
+ def to(self, device, dtype=torch.bfloat16):
64
+ self.input_ids = self.input_ids.to(device)
65
+ self.attention_mask = self.attention_mask.to(device)
66
+ self.images_seq_mask = self.images_seq_mask.to(device)
67
+ self.images_emb_mask = self.images_emb_mask.to(device)
68
+ self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
69
+ return self
70
+
71
+
72
+ class VLChatProcessor(ProcessorMixin):
73
+ image_processor_class = "AutoImageProcessor"
74
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
75
+
76
+ attributes = ["image_processor", "tokenizer"]
77
+
78
+ system_prompt = (
79
+ "You are a helpful language and vision assistant. "
80
+ "You are able to understand the visual content that the user provides, "
81
+ "and assist the user with a variety of tasks using natural language."
82
+ )
83
+
84
+ def __init__(
85
+ self,
86
+ image_processor: VLMImageProcessor,
87
+ tokenizer: LlamaTokenizerFast,
88
+ image_tag: str = "<image_placeholder>",
89
+ num_image_tokens: int = 576,
90
+ add_special_token: bool = False,
91
+ sft_format: str = "deepseek",
92
+ mask_prompt: bool = True,
93
+ ignore_id: int = -100,
94
+ **kwargs,
95
+ ):
96
+ self.image_processor = image_processor
97
+ self.tokenizer = tokenizer
98
+
99
+ image_id = self.tokenizer.vocab.get(image_tag)
100
+ if image_id is None:
101
+ special_tokens = [image_tag]
102
+ special_tokens_dict = {"additional_special_tokens": special_tokens}
103
+ self.tokenizer.add_special_tokens(special_tokens_dict)
104
+ print(f"Add image tag = {image_tag} to the tokenizer")
105
+
106
+ self.image_tag = image_tag
107
+ self.num_image_tokens = num_image_tokens
108
+ self.add_special_token = add_special_token
109
+ self.sft_format = sft_format
110
+ self.mask_prompt = mask_prompt
111
+ self.ignore_id = ignore_id
112
+
113
+ super().__init__(
114
+ image_processor,
115
+ tokenizer,
116
+ image_tag,
117
+ num_image_tokens,
118
+ add_special_token,
119
+ sft_format,
120
+ mask_prompt,
121
+ ignore_id,
122
+ **kwargs,
123
+ )
124
+
125
+ def new_chat_template(self):
126
+ conv = get_conv_template(self.sft_format)
127
+ conv.set_system_message(self.system_prompt)
128
+ return conv
129
+
130
+ def apply_sft_template_for_multi_turn_prompts(
131
+ self,
132
+ conversations: List[Dict[str, str]],
133
+ sft_format: str = "deepseek",
134
+ system_prompt: str = "",
135
+ ):
136
+ """
137
+ Applies the SFT template to conversation.
138
+
139
+ An example of conversation:
140
+ conversation = [
141
+ {
142
+ "role": "User",
143
+ "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
144
+ "images": [
145
+ "./multi-images/attribute_comparison_1.png",
146
+ "./multi-images/attribute_comparison_2.png"
147
+ ]
148
+ },
149
+ {
150
+ "role": "Assistant",
151
+ "content": ""
152
+ }
153
+ ]
154
+
155
+ Args:
156
+ conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
157
+ sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
158
+ system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
159
+
160
+ Returns:
161
+ sft_prompt (str): The formatted text.
162
+ """
163
+
164
+ conv = get_conv_template(sft_format)
165
+ conv.set_system_message(system_prompt)
166
+ for message in conversations:
167
+ conv.append_message(message["role"], message["content"].strip())
168
+ sft_prompt = conv.get_prompt().strip()
169
+
170
+ return sft_prompt
171
+
172
+ @property
173
+ def image_token(self):
174
+ return self.image_tag
175
+
176
+ @property
177
+ def image_id(self):
178
+ image_id = self.tokenizer.vocab.get(self.image_tag)
179
+ return image_id
180
+
181
+ @property
182
+ def pad_id(self):
183
+ pad_id = self.tokenizer.pad_token_id
184
+ if pad_id is None:
185
+ pad_id = self.tokenizer.eos_token_id
186
+
187
+ return pad_id
188
+
189
+ def add_image_token(
190
+ self,
191
+ image_indices: List[int],
192
+ input_ids: torch.LongTensor,
193
+ ):
194
+ """
195
+
196
+ Args:
197
+ image_indices (List[int]): [index_0, index_1, ..., index_j]
198
+ input_ids (torch.LongTensor): [N]
199
+
200
+ Returns:
201
+ input_ids (torch.LongTensor): [N + image tokens]
202
+ num_image_tokens (torch.IntTensor): [n_images]
203
+ """
204
+
205
+ input_slices = []
206
+
207
+ start = 0
208
+ for index in image_indices:
209
+ if self.add_special_token:
210
+ end = index + 1
211
+ else:
212
+ end = index
213
+
214
+ # original text tokens
215
+ input_slices.append(input_ids[start:end])
216
+
217
+ # add image tokens, and set the mask as False
218
+ input_slices.append(
219
+ self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
220
+ )
221
+ start = index + 1
222
+
223
+ # the left part
224
+ input_slices.append(input_ids[start:])
225
+
226
+ # concat all slices
227
+ input_ids = torch.cat(input_slices, dim=0)
228
+ num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
229
+
230
+ return input_ids, num_image_tokens
231
+
232
+ def process_one(
233
+ self,
234
+ prompt: str = None,
235
+ conversations: List[Dict[str, str]] = None,
236
+ images: List[Image] = None,
237
+ **kwargs,
238
+ ):
239
+ """
240
+
241
+ Args:
242
+ prompt (str): the formatted prompt;
243
+ conversations (List[Dict]): conversations with a list of messages;
244
+ images (List[ImageType]): the list of images;
245
+ **kwargs:
246
+
247
+ Returns:
248
+ outputs (BaseProcessorOutput): the output of the processor,
249
+ - input_ids (torch.LongTensor): [N + image tokens]
250
+ - target_ids (torch.LongTensor): [N + image tokens]
251
+ - images (torch.FloatTensor): [n_images, 3, H, W]
252
+ - image_id (int): the id of the image token
253
+ - num_image_tokens (List[int]): the number of image tokens
254
+ """
255
+
256
+ assert (
257
+ prompt is None or conversations is None
258
+ ), "prompt and conversations cannot be used at the same time."
259
+
260
+ if prompt is None:
261
+ # apply sft format
262
+ sft_format = self.apply_sft_template_for_multi_turn_prompts(
263
+ conversations=conversations,
264
+ sft_format=self.sft_format,
265
+ system_prompt=self.system_prompt,
266
+ )
267
+ else:
268
+ sft_format = prompt
269
+
270
+ # tokenize
271
+ input_ids = self.tokenizer.encode(sft_format)
272
+ input_ids = torch.LongTensor(input_ids)
273
+
274
+ # add image tokens to the input_ids
275
+ image_token_mask: torch.BoolTensor = input_ids == self.image_id
276
+ image_indices = image_token_mask.nonzero()
277
+ input_ids, num_image_tokens = self.add_image_token(
278
+ image_indices=image_indices,
279
+ input_ids=input_ids,
280
+ )
281
+
282
+ # load images
283
+ images_outputs = self.image_processor(images, return_tensors="pt")
284
+
285
+ prepare = VLChatProcessorOutput(
286
+ sft_format=sft_format,
287
+ input_ids=input_ids,
288
+ pixel_values=images_outputs.pixel_values,
289
+ num_image_tokens=num_image_tokens,
290
+ )
291
+
292
+ return prepare
293
+
294
+ def __call__(
295
+ self,
296
+ *,
297
+ prompt: str = None,
298
+ conversations: List[Dict[str, str]] = None,
299
+ images: List[Image] = None,
300
+ force_batchify: bool = True,
301
+ **kwargs,
302
+ ):
303
+ """
304
+
305
+ Args:
306
+ prompt (str): the formatted prompt;
307
+ conversations (List[Dict]): conversations with a list of messages;
308
+ images (List[ImageType]): the list of images;
309
+ force_batchify (bool): force batchify the inputs;
310
+ **kwargs:
311
+
312
+ Returns:
313
+ outputs (BaseProcessorOutput): the output of the processor,
314
+ - input_ids (torch.LongTensor): [N + image tokens]
315
+ - images (torch.FloatTensor): [n_images, 3, H, W]
316
+ - image_id (int): the id of the image token
317
+ - num_image_tokens (List[int]): the number of image tokens
318
+ """
319
+
320
+ prepare = self.process_one(
321
+ prompt=prompt, conversations=conversations, images=images
322
+ )
323
+
324
+ if force_batchify:
325
+ prepare = self.batchify([prepare])
326
+
327
+ return prepare
328
+
329
+ def batchify(
330
+ self, prepare_list: List[VLChatProcessorOutput]
331
+ ) -> BatchedVLChatProcessorOutput:
332
+ """
333
+ Preprocesses the inputs for multimodal inference.
334
+
335
+ Args:
336
+ prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
337
+
338
+ Returns:
339
+ BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
340
+ """
341
+
342
+ batch_size = len(prepare_list)
343
+ sft_format = []
344
+ n_images = []
345
+ seq_lens = []
346
+ for prepare in prepare_list:
347
+ n_images.append(len(prepare.num_image_tokens))
348
+ seq_lens.append(len(prepare))
349
+
350
+ input_token_max_len = max(seq_lens)
351
+ max_n_images = max(1, max(n_images))
352
+
353
+ batched_input_ids = torch.full(
354
+ (batch_size, input_token_max_len), self.pad_id
355
+ ).long() # FIXME
356
+ batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
357
+ batched_pixel_values = torch.zeros(
358
+ (batch_size, max_n_images, *self.image_processor.default_shape)
359
+ ).float()
360
+ batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
361
+ batched_images_emb_mask = torch.zeros(
362
+ (batch_size, max_n_images, self.num_image_tokens)
363
+ ).bool()
364
+
365
+ for i, prepare in enumerate(prepare_list):
366
+ input_ids = prepare.input_ids
367
+ seq_len = len(prepare)
368
+ n_image = len(prepare.num_image_tokens)
369
+ # left-padding
370
+ batched_attention_mask[i, -seq_len:] = 1
371
+ batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
372
+ batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
373
+
374
+ if n_image > 0:
375
+ batched_pixel_values[i, :n_image] = prepare.pixel_values
376
+ for j, n_image_tokens in enumerate(prepare.num_image_tokens):
377
+ batched_images_emb_mask[i, j, :n_image_tokens] = True
378
+
379
+ sft_format.append(prepare.sft_format)
380
+
381
+ batched_prepares = BatchedVLChatProcessorOutput(
382
+ input_ids=batched_input_ids,
383
+ attention_mask=batched_attention_mask,
384
+ pixel_values=batched_pixel_values,
385
+ images_seq_mask=batched_images_seq_mask,
386
+ images_emb_mask=batched_images_emb_mask,
387
+ sft_format=sft_format,
388
+ )
389
+
390
+ return batched_prepares
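End to end, the processor turns a conversation plus PIL images into the left-padded, batched tensors that MultiModalityCausalLM.prepare_inputs_embeds expects. A sketch, assuming a checkpoint that bundles both the tokenizer and the image-processor config (the repository id and image path are hypothetical):

# Processor usage sketch; the checkpoint id and image path are assumptions.
from PIL import Image

from deepseek_vl.models import VLChatProcessor

processor = VLChatProcessor.from_pretrained("deepseek-ai/deepseek-vl-7b-chat")

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder> Describe this image.",
        "images": ["./images/example.png"],
    },
    {"role": "Assistant", "content": ""},
]

pil_images = [
    Image.open(path).convert("RGB")
    for message in conversation
    for path in message.get("images", [])
]
prepare_inputs = processor(conversations=conversation, images=pil_images, force_batchify=True)
print(prepare_inputs.input_ids.shape, prepare_inputs.pixel_values.shape)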
deepseek_vl/models/projector.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from attrdict import AttrDict
25
+
26
+
27
+ class MlpProjector(nn.Module):
28
+ def __init__(self, cfg):
29
+ super().__init__()
30
+
31
+ self.cfg = cfg
32
+
33
+ if cfg.projector_type == "identity":
34
+ modules = nn.Identity()
35
+
36
+ elif cfg.projector_type == "linear":
37
+ modules = nn.Linear(cfg.input_dim, cfg.n_embed)
38
+
39
+ elif cfg.projector_type == "mlp_gelu":
40
+ mlp_depth = cfg.get("depth", 1)
41
+ modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
42
+ for _ in range(1, mlp_depth):
43
+ modules.append(nn.GELU())
44
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
45
+ modules = nn.Sequential(*modules)
46
+
47
+ elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
48
+ mlp_depth = cfg.get("depth", 1)
49
+ self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
50
+ self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
51
+
52
+ modules = []
53
+ for _ in range(1, mlp_depth):
54
+ modules.append(nn.GELU())
55
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
56
+ modules = nn.Sequential(*modules)
57
+
58
+ else:
59
+ raise ValueError(f"Unknown projector type: {cfg.projector_type}")
60
+
61
+ self.layers = modules
62
+
63
+ def forward(
64
+ self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
65
+ ):
66
+ """
67
+
68
+ Args:
69
+ x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of torch.Tensor,
70
+ then it comes from the hybrid vision encoder, and x = (high_res_x, low_res_x);
71
+ otherwise it is the feature from the single vision encoder.
72
+
73
+ Returns:
74
+ x (torch.Tensor): [b, s, c]
75
+ """
76
+
77
+ if isinstance(x_or_tuple, tuple):
78
+ # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
79
+ high_x, low_x = x_or_tuple
80
+ high_x = self.high_up_proj(high_x)
81
+ low_x = self.low_up_proj(low_x)
82
+ x = torch.concat([high_x, low_x], dim=-1)
83
+ else:
84
+ x = x_or_tuple
85
+
86
+ return self.layers(x)
87
+
88
+
89
+ if __name__ == "__main__":
90
+ cfg = AttrDict(
91
+ input_dim=1024,
92
+ n_embed=2048,
93
+ depth=2,
94
+ projector_type="low_high_hybrid_split_mlp_gelu",
95
+ )
96
+ inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
97
+
98
+ m = MlpProjector(cfg)
99
+ out = m(inputs)
100
+ print(out.shape)
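For a single vision tower the plain "mlp_gelu" branch is the relevant one; a small shape check mirroring the __main__ block above:

# Shape check for the "mlp_gelu" projector branch (single vision tower input).
import torch
from attrdict import AttrDict

from deepseek_vl.models.projector import MlpProjector

cfg = AttrDict(input_dim=1024, n_embed=2048, depth=2, projector_type="mlp_gelu")
proj = MlpProjector(cfg)

x = torch.rand(4, 576, 1024)  # [b, n_patch, input_dim]
print(proj(x).shape)          # expected torch.Size([4, 576, 2048])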
deepseek_vl/models/sam.py ADDED
@@ -0,0 +1,593 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import copy
8
+ from dataclasses import dataclass
9
+ from functools import partial
10
+ from typing import List, Optional, Tuple, Type, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class MLPBlock(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embedding_dim: int,
21
+ mlp_dim: int,
22
+ act: Type[nn.Module] = nn.GELU,
23
+ ) -> None:
24
+ super().__init__()
25
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
26
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
27
+ self.act = act()
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ return self.lin2(self.act(self.lin1(x)))
31
+
32
+
33
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
34
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
35
+ class LayerNorm2d(nn.Module):
36
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
37
+ super().__init__()
38
+ self.weight = nn.Parameter(torch.ones(num_channels))
39
+ self.bias = nn.Parameter(torch.zeros(num_channels))
40
+ self.eps = eps
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ u = x.mean(1, keepdim=True)
44
+ s = (x - u).pow(2).mean(1, keepdim=True)
45
+ x = (x - u) / torch.sqrt(s + self.eps)
46
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
47
+ return x
48
+
49
+
50
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
51
+ class ImageEncoderViT(nn.Module):
52
+ def __init__(
53
+ self,
54
+ img_size: int = 1024,
55
+ patch_size: int = 16,
56
+ in_chans: int = 3,
57
+ embed_dim: int = 768,
58
+ depth: int = 12,
59
+ num_heads: int = 12,
60
+ mlp_ratio: float = 4.0,
61
+ out_chans: int = 256,
62
+ qkv_bias: bool = True,
63
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
64
+ act_layer: Type[nn.Module] = nn.GELU,
65
+ use_abs_pos: bool = True,
66
+ use_rel_pos: bool = False,
67
+ rel_pos_zero_init: bool = True,
68
+ window_size: int = 0,
69
+ global_attn_indexes: Tuple[int, ...] = (),
70
+ downsample_channels: Tuple[int, ...] = (512, 1024),
71
+ ) -> None:
72
+ """
73
+ Args:
74
+ img_size (int): Input image size.
75
+ patch_size (int): Patch size.
76
+ in_chans (int): Number of input image channels.
77
+ embed_dim (int): Patch embedding dimension.
78
+ depth (int): Depth of ViT.
79
+ num_heads (int): Number of attention heads in each ViT block.
80
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
81
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
82
+ norm_layer (nn.Module): Normalization layer.
83
+ act_layer (nn.Module): Activation layer.
84
+ use_abs_pos (bool): If True, use absolute positional embeddings.
85
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
86
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
87
+ window_size (int): Window size for window attention blocks.
88
+ global_attn_indexes (list): Indexes for blocks using global attention.
89
+ downsample_channels (list): Channels for downsampling layers.
90
+ """
91
+ super().__init__()
92
+ self.img_size = img_size
93
+
94
+ self.patch_embed = PatchEmbed(
95
+ kernel_size=(patch_size, patch_size),
96
+ stride=(patch_size, patch_size),
97
+ in_chans=in_chans,
98
+ embed_dim=embed_dim,
99
+ )
100
+
101
+ self.pos_embed: Optional[nn.Parameter] = None
102
+ if use_abs_pos:
103
+ # Initialize absolute positional embedding with pretrain image size.
104
+ self.pos_embed = nn.Parameter(
105
+ torch.zeros(
106
+ 1, img_size // patch_size, img_size // patch_size, embed_dim
107
+ )
108
+ )
109
+
110
+ self.blocks = nn.ModuleList()
111
+ for i in range(depth):
112
+ block = Block(
113
+ dim=embed_dim,
114
+ num_heads=num_heads,
115
+ mlp_ratio=mlp_ratio,
116
+ qkv_bias=qkv_bias,
117
+ norm_layer=norm_layer,
118
+ act_layer=act_layer,
119
+ use_rel_pos=use_rel_pos,
120
+ rel_pos_zero_init=rel_pos_zero_init,
121
+ window_size=window_size if i not in global_attn_indexes else 0,
122
+ input_size=(img_size // patch_size, img_size // patch_size),
123
+ )
124
+ self.blocks.append(block)
125
+
126
+ self.neck = nn.Sequential(
127
+ nn.Conv2d(
128
+ embed_dim,
129
+ out_chans,
130
+ kernel_size=1,
131
+ bias=False,
132
+ ),
133
+ LayerNorm2d(out_chans),
134
+ nn.Conv2d(
135
+ out_chans,
136
+ out_chans,
137
+ kernel_size=3,
138
+ padding=1,
139
+ bias=False,
140
+ ),
141
+ LayerNorm2d(out_chans),
142
+ )
143
+
144
+ in_channels = out_chans
145
+ downsamples = []
146
+ for i in range(len(downsample_channels)):
147
+ out_channels = downsample_channels[i]
148
+ downsamples.append(
149
+ nn.Conv2d(
150
+ in_channels,
151
+ out_channels,
152
+ kernel_size=3,
153
+ stride=2,
154
+ padding=1,
155
+ bias=False,
156
+ )
157
+ )
158
+ in_channels = out_channels
159
+ self.downsamples = nn.Sequential(*downsamples)
160
+
161
+ self.sam_hd = True
162
+ if self.sam_hd:
163
+ self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1))
164
+ # self.neck_hd = nn.Linear(embed_dim, embed_dim)
165
+ self.neck_hd = copy.deepcopy(self.neck)
166
+ # self.downsamples_hd = copy.deepcopy(self.downsamples)
167
+
168
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
169
+ x = self.patch_embed(x)
170
+ if self.pos_embed is not None:
171
+ x = x + self.pos_embed
172
+
173
+ global_features = []
174
+ for i, blk in enumerate(self.blocks):
175
+ x = blk(x)
176
+ if self.sam_hd and blk.window_size == 0:
177
+ global_features.append(x)
178
+
179
+ x = self.neck(x.permute(0, 3, 1, 2))
180
+ x_dtype = x.dtype
181
+ x = F.interpolate(
182
+ x.float(), size=(96, 96), mode="bilinear", align_corners=False
183
+ ).to(x_dtype)
184
+ x = self.downsamples(x)
185
+
186
+ if self.sam_hd:
187
+ first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2))
188
+ x_dtype = first_global_feature.dtype
189
+ first_global_feature = F.interpolate(
190
+ first_global_feature.float(),
191
+ size=(96, 96),
192
+ mode="bilinear",
193
+ align_corners=False,
194
+ )
195
+ first_global_feature = self.downsamples(first_global_feature.to(x_dtype))
196
+ x = x + first_global_feature * self.hd_alpha_downsamples
197
+
198
+ return x
199
+
200
+
201
+ class Block(nn.Module):
202
+ """Transformer block with support for window attention and residual propagation blocks."""
203
+
204
+ def __init__(
205
+ self,
206
+ dim: int,
207
+ num_heads: int,
208
+ mlp_ratio: float = 4.0,
209
+ qkv_bias: bool = True,
210
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
211
+ act_layer: Type[nn.Module] = nn.GELU,
212
+ use_rel_pos: bool = False,
213
+ rel_pos_zero_init: bool = True,
214
+ window_size: int = 0,
215
+ input_size: Optional[Tuple[int, int]] = None,
216
+ ) -> None:
217
+ """
218
+ Args:
219
+ dim (int): Number of input channels.
220
+ num_heads (int): Number of attention heads in each ViT block.
221
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
222
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
223
+ norm_layer (nn.Module): Normalization layer.
224
+ act_layer (nn.Module): Activation layer.
225
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
226
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
227
+ window_size (int): Window size for window attention blocks. If it equals 0, then
228
+ use global attention.
229
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
230
+ positional parameter size.
231
+ """
232
+ super().__init__()
233
+ self.norm1 = norm_layer(dim)
234
+ self.attn = Attention(
235
+ dim,
236
+ num_heads=num_heads,
237
+ qkv_bias=qkv_bias,
238
+ use_rel_pos=use_rel_pos,
239
+ rel_pos_zero_init=rel_pos_zero_init,
240
+ input_size=input_size if window_size == 0 else (window_size, window_size),
241
+ )
242
+
243
+ self.norm2 = norm_layer(dim)
244
+ self.mlp = MLPBlock(
245
+ embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
246
+ )
247
+
248
+ self.window_size = window_size
249
+
250
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
251
+ shortcut = x
252
+ x = self.norm1(x)
253
+ # Window partition
254
+ if self.window_size > 0:
255
+ H, W = x.shape[1], x.shape[2]
256
+ x, pad_hw = window_partition(x, self.window_size)
257
+
258
+ x = self.attn(x)
259
+ # Reverse window partition
260
+ if self.window_size > 0:
261
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
262
+
263
+ x = shortcut + x
264
+ x = x + self.mlp(self.norm2(x))
265
+
266
+ return x
267
+
268
+
269
+ class Attention(nn.Module):
270
+ """Multi-head Attention block with relative position embeddings."""
271
+
272
+ def __init__(
273
+ self,
274
+ dim: int,
275
+ num_heads: int = 8,
276
+ qkv_bias: bool = True,
277
+ use_rel_pos: bool = False,
278
+ rel_pos_zero_init: bool = True,
279
+ input_size: Optional[Tuple[int, int]] = None,
280
+ ) -> None:
281
+ """
282
+ Args:
283
+ dim (int): Number of input channels.
284
+ num_heads (int): Number of attention heads.
285
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
286
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
287
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
288
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
289
+ positional parameter size.
290
+ """
291
+ super().__init__()
292
+ self.num_heads = num_heads
293
+ head_dim = dim // num_heads
294
+ self.scale = head_dim**-0.5
295
+
296
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
297
+ self.proj = nn.Linear(dim, dim)
298
+
299
+ self.use_rel_pos = use_rel_pos
300
+ if self.use_rel_pos:
301
+ assert (
302
+ input_size is not None
303
+ ), "Input size must be provided if using relative positional encoding."
304
+ # initialize relative positional embeddings
305
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
306
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
307
+
308
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
309
+ B, H, W, _ = x.shape
310
+ # qkv with shape (3, B, nHead, H * W, C)
311
+ qkv = (
312
+ self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
313
+ )
314
+ # q, k, v with shape (B * nHead, H * W, C)
315
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
316
+
317
+ def do_attention(q, k, v):
318
+ attn = (q * self.scale) @ k.transpose(-2, -1)
319
+ if self.use_rel_pos:
320
+ attn = add_decomposed_rel_pos(
321
+ attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
322
+ )
323
+
324
+ attn = attn.softmax(dim=-1)
325
+ x = (
326
+ (attn @ v)
327
+ .view(B, self.num_heads, H, W, -1)
328
+ .permute(0, 2, 3, 1, 4)
329
+ .reshape(B, H, W, -1)
330
+ )
331
+
332
+ return x
333
+
334
+ # from haiscale.utils import on_demand_checkpoint
335
+ # x = on_demand_checkpoint(do_attention, q, k, v)
336
+ x = do_attention(q, k, v)
337
+ x = self.proj(x)
338
+
339
+ return x
340
+
341
+
342
+ def window_partition(
343
+ x: torch.Tensor, window_size: int
344
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
345
+ """
346
+ Partition into non-overlapping windows with padding if needed.
347
+ Args:
348
+ x (tensor): input tokens with [B, H, W, C].
349
+ window_size (int): window size.
350
+
351
+ Returns:
352
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
353
+ (Hp, Wp): padded height and width before partition
354
+ """
355
+ B, H, W, C = x.shape
356
+
357
+ pad_h = (window_size - H % window_size) % window_size
358
+ pad_w = (window_size - W % window_size) % window_size
359
+ if pad_h > 0 or pad_w > 0:
360
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
361
+ Hp, Wp = H + pad_h, W + pad_w
362
+
363
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
364
+ windows = (
365
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
366
+ )
367
+ return windows, (Hp, Wp)
368
+
369
+
370
+ def window_unpartition(
371
+ windows: torch.Tensor,
372
+ window_size: int,
373
+ pad_hw: Tuple[int, int],
374
+ hw: Tuple[int, int],
375
+ ) -> torch.Tensor:
376
+ """
377
+ Window unpartition into original sequences and removing padding.
378
+ Args:
379
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
380
+ window_size (int): window size.
381
+ pad_hw (Tuple): padded height and width (Hp, Wp).
382
+ hw (Tuple): original height and width (H, W) before padding.
383
+
384
+ Returns:
385
+ x: unpartitioned sequences with [B, H, W, C].
386
+ """
387
+ Hp, Wp = pad_hw
388
+ H, W = hw
389
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
390
+ x = windows.view(
391
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
392
+ )
393
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
394
+
395
+ if Hp > H or Wp > W:
396
+ x = x[:, :H, :W, :].contiguous()
397
+ return x
398
+
399
+
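A small round-trip sketch (illustrative, not part of the diff) of the two window helpers above: padding is added when H or W is not a multiple of the window size and stripped again on the way back.

x = torch.randn(2, 30, 30, 64)                          # [B, H, W, C]
windows, pad_hw = window_partition(x, window_size=14)   # pads 30 -> 42, giving a 3x3 grid of windows
assert windows.shape == (2 * 3 * 3, 14, 14, 64)
restored = window_unpartition(windows, 14, pad_hw, (30, 30))
assert torch.equal(restored, x)                         # original region recovered exactly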
400
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
401
+ """
402
+ Get relative positional embeddings according to the relative positions of
403
+ query and key sizes.
404
+ Args:
405
+ q_size (int): size of query q.
406
+ k_size (int): size of key k.
407
+ rel_pos (Tensor): relative position embeddings (L, C).
408
+
409
+ Returns:
410
+ Extracted positional embeddings according to relative positions.
411
+ """
412
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
413
+ # Interpolate rel pos if needed.
414
+ if rel_pos.shape[0] != max_rel_dist:
415
+ # Interpolate rel pos.
416
+ rel_pos_resized = F.interpolate(
417
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
418
+ size=max_rel_dist,
419
+ mode="linear",
420
+ )
421
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
422
+ else:
423
+ rel_pos_resized = rel_pos
424
+
425
+ # Scale the coords with short length if shapes for q and k are different.
426
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
427
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
428
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
429
+
430
+ return rel_pos_resized[relative_coords.long()]
431
+
432
+
433
+ def add_decomposed_rel_pos(
434
+ attn: torch.Tensor,
435
+ q: torch.Tensor,
436
+ rel_pos_h: torch.Tensor,
437
+ rel_pos_w: torch.Tensor,
438
+ q_size: Tuple[int, int],
439
+ k_size: Tuple[int, int],
440
+ ) -> torch.Tensor:
441
+ """
442
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
443
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
444
+ Args:
445
+ attn (Tensor): attention map.
446
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
447
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
448
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
449
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
450
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
451
+
452
+ Returns:
453
+ attn (Tensor): attention map with added relative positional embeddings.
454
+ """
455
+ q_h, q_w = q_size
456
+ k_h, k_w = k_size
457
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
458
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
459
+
460
+ B, _, dim = q.shape
461
+ r_q = q.reshape(B, q_h, q_w, dim)
462
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
463
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
464
+
465
+ attn = (
466
+ attn.view(B, q_h, q_w, k_h, k_w)
467
+ + rel_h[:, :, :, :, None]
468
+ + rel_w[:, :, :, None, :]
469
+ ).view(B, q_h * q_w, k_h * k_w)
470
+
471
+ return attn
472
+
473
+
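A quick shape walkthrough (illustrative) of the decomposed relative-position bias above for a square 14x14 window, confirming that the per-axis terms broadcast over the key grid and fold back into a flat attention map.

B, q_h, q_w, C = 2, 14, 14, 64
q = torch.randn(B, q_h * q_w, C)
attn = torch.randn(B, q_h * q_w, q_h * q_w)
rel_pos_h = torch.randn(2 * q_h - 1, C)      # one learned row per relative offset along height
rel_pos_w = torch.randn(2 * q_w - 1, C)
out = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (q_h, q_w), (q_h, q_w))
assert out.shape == (B, q_h * q_w, q_h * q_w)   # rel_h / rel_w added per axis, then re-flattened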
474
+ class PatchEmbed(nn.Module):
475
+ """
476
+ Image to Patch Embedding.
477
+ """
478
+
479
+ def __init__(
480
+ self,
481
+ kernel_size: Tuple[int, int] = (16, 16),
482
+ stride: Tuple[int, int] = (16, 16),
483
+ padding: Tuple[int, int] = (0, 0),
484
+ in_chans: int = 3,
485
+ embed_dim: int = 768,
486
+ ) -> None:
487
+ """
488
+ Args:
489
+ kernel_size (Tuple): kernel size of the projection layer.
490
+ stride (Tuple): stride of the projection layer.
491
+ padding (Tuple): padding size of the projection layer.
492
+ in_chans (int): Number of input image channels.
493
+ embed_dim (int): Patch embedding dimension.
494
+ """
495
+ super().__init__()
496
+
497
+ self.proj = nn.Conv2d(
498
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
499
+ )
500
+
501
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
502
+ x = self.proj(x)
503
+ # B C H W -> B H W C
504
+ x = x.permute(0, 2, 3, 1)
505
+ return x
506
+
507
+
508
+ @dataclass
509
+ class SAMViTCfg:
510
+ image_size: Union[Tuple[int, int], int] = 1024
511
+ width: int = 1024
512
+ layers: int = 23
513
+ heads: int = 16
514
+ patch_size: int = 16
515
+ window_size: int = 14
516
+ prompt_embed_dim: int = 256
517
+ global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23)
518
+ downsample_channels: Union[List[int], Tuple[int]] = (512, 1024)
519
+
520
+
521
+ SAM_MODEL_CONFIG = {
522
+ "sam_vit_b": {
523
+ "width": 768,
524
+ "layers": 12,
525
+ "heads": 12,
526
+ "global_attn_indexes": [2, 5, 8, 11],
527
+ "downsample_channels": (),
528
+ },
529
+ "sam_b_downsample": {
530
+ "width": 768,
531
+ "layers": 12,
532
+ "heads": 12,
533
+ "global_attn_indexes": [2, 5, 8, 11],
534
+ "downsample_channels": (512, 1024),
535
+ },
536
+ "sam_vit_l": {
537
+ "width": 1024,
538
+ "layers": 24,
539
+ "heads": 16,
540
+ "global_attn_indexes": [5, 11, 17, 23],
541
+ "downsample_channels": (),
542
+ },
543
+ "sam_vit_h": {
544
+ "width": 1280,
545
+ "layers": 32,
546
+ "heads": 16,
547
+ "global_attn_indexes": [7, 15, 23, 31],
548
+ "downsample_channels": (),
549
+ },
550
+ }
551
+
552
+
553
+ def create_sam_vit(
554
+ model_name: str = "sam_b_downsample",
555
+ image_size: int = 1024,
556
+ ckpt_path: str = "",
557
+ **kwargs,
558
+ ):
559
+ assert (
560
+ model_name in SAM_MODEL_CONFIG.keys()
561
+ ), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}"
562
+
563
+ sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name])
564
+ image_encoder = ImageEncoderViT(
565
+ depth=sam_cfg.layers,
566
+ embed_dim=sam_cfg.width,
567
+ img_size=image_size,
568
+ mlp_ratio=4,
569
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
570
+ num_heads=sam_cfg.heads,
571
+ patch_size=sam_cfg.patch_size,
572
+ qkv_bias=True,
573
+ use_rel_pos=True,
574
+ global_attn_indexes=sam_cfg.global_attn_indexes,
575
+ window_size=14,
576
+ out_chans=sam_cfg.prompt_embed_dim,
577
+ downsample_channels=sam_cfg.downsample_channels,
578
+ )
579
+
580
+ if ckpt_path:
581
+ state_dict = torch.load(ckpt_path)
582
+ image_encoder.load_state_dict(state_dict, strict=False)
583
+ print(f"SAM-ViT restored from {ckpt_path}")
584
+
585
+ return image_encoder
586
+
587
+
588
+ if __name__ == "__main__":
589
+ x = torch.zeros(2, 3, 1024, 1024).bfloat16()
590
+ # x.permute(0, 3, 1, 2)
591
+ net = create_sam_vit().bfloat16()
592
+ out = net(x)
593
+ print(x.shape, out.shape)
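For reference, a rough shape trace of the default "sam_b_downsample" configuration exercised by the __main__ block above (an editorial sketch, not part of the committed file):

net = create_sam_vit("sam_b_downsample", image_size=1024)
out = net(torch.zeros(2, 3, 1024, 1024))
# 1024 / 16 = 64x64 patch tokens -> neck (B, 256, 64, 64) -> interpolated to 96x96
# -> two stride-2 downsample convs (512, then 1024 channels): 96 -> 48 -> 24
print(out.shape)   # expected: torch.Size([2, 1024, 24, 24])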
deepseek_vl/models/siglip_vit.py ADDED
@@ -0,0 +1,681 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
21
+ import math
22
+ import warnings
23
+ from dataclasses import dataclass
24
+ from functools import partial
25
+ from typing import (
26
+ Callable,
27
+ Dict,
28
+ Final,
29
+ List,
30
+ Literal,
31
+ Optional,
32
+ Sequence,
33
+ Set,
34
+ Tuple,
35
+ Type,
36
+ Union,
37
+ )
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ from timm.layers import (
43
+ AttentionPoolLatent,
44
+ DropPath,
45
+ LayerType,
46
+ Mlp,
47
+ PatchDropout,
48
+ PatchEmbed,
49
+ resample_abs_pos_embed,
50
+ )
51
+ from timm.models._manipulate import checkpoint_seq, named_apply
52
+
53
+
54
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
55
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
56
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
57
+ def norm_cdf(x):
58
+ # Computes standard normal cumulative distribution function
59
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
60
+
61
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
62
+ warnings.warn(
63
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
64
+ "The distribution of values may be incorrect.",
65
+ stacklevel=2,
66
+ )
67
+
68
+ with torch.no_grad():
69
+ # Values are generated by using a truncated uniform distribution and
70
+ # then using the inverse CDF for the normal distribution.
71
+ # Get upper and lower cdf values
72
+ l = norm_cdf((a - mean) / std) # noqa: E741
73
+ u = norm_cdf((b - mean) / std)
74
+
75
+ # Uniformly fill tensor with values from [l, u], then translate to
76
+ # [2l-1, 2u-1].
77
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
78
+
79
+ # Use inverse cdf transform for normal distribution to get truncated
80
+ # standard normal
81
+ tensor.erfinv_()
82
+
83
+ # Transform to proper mean, std
84
+ tensor.mul_(std * math.sqrt(2.0))
85
+ tensor.add_(mean)
86
+
87
+ # Clamp to ensure it's in the proper range
88
+ tensor.clamp_(min=a, max=b)
89
+ return tensor
90
+
91
+
92
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
93
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
94
+ r"""The original timm.models.layers.weight_init.trunc_normal_ cannot handle bfloat16 yet, so here we first
95
+ convert the tensor to float32, apply trunc_normal_() in float32, and then convert it back to its original dtype.
96
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
97
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
98
+ with values outside :math:`[a, b]` redrawn until they are within
99
+ the bounds. The method used for generating the random values works
100
+ best when :math:`a \leq \text{mean} \leq b`.
101
+ Args:
102
+ tensor: an n-dimensional `torch.Tensor`
103
+ mean: the mean of the normal distribution
104
+ std: the standard deviation of the normal distribution
105
+ a: the minimum cutoff value
106
+ b: the maximum cutoff value
107
+ Examples:
108
+ >>> w = torch.empty(3, 5)
109
+ >>> nn.init.trunc_normal_(w)
110
+ """
111
+
112
+ with torch.no_grad():
113
+ dtype = tensor.dtype
114
+ tensor_fp32 = tensor.float()
115
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
116
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
117
+ tensor.copy_(tensor_dtype)
118
+
119
+
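A minimal usage sketch of the wrapper above (illustrative): it fills a bfloat16 tensor in place by doing the math in float32.

w = torch.empty(3, 5, dtype=torch.bfloat16)
trunc_normal_(w, std=0.02)                  # computed in float32, copied back as bfloat16
assert w.dtype == torch.bfloat16
assert w.abs().max() <= 2.0                 # values are clamped to the [a, b] = [-2, 2] cutoffs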
120
+ def init_weights(self):
121
+ if self.pos_embed is not None:
122
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
123
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
124
+
125
+
126
+ def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
127
+ """ViT weight initialization, original timm impl (for reproducibility)"""
128
+ if isinstance(module, nn.Linear):
129
+ trunc_normal_(module.weight, std=0.02)
130
+ if module.bias is not None:
131
+ nn.init.zeros_(module.bias)
132
+ elif hasattr(module, "init_weights"):
133
+ module.init_weights()
134
+
135
+
136
+ class Attention(nn.Module):
137
+ fused_attn: Final[bool]
138
+
139
+ def __init__(
140
+ self,
141
+ dim: int,
142
+ num_heads: int = 8,
143
+ qkv_bias: bool = False,
144
+ qk_norm: bool = False,
145
+ attn_drop: float = 0.0,
146
+ proj_drop: float = 0.0,
147
+ norm_layer: nn.Module = nn.LayerNorm,
148
+ ) -> None:
149
+ super().__init__()
150
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
151
+ self.num_heads = num_heads
152
+ self.head_dim = dim // num_heads
153
+ self.scale = self.head_dim**-0.5
154
+ # self.fused_attn = use_fused_attn()
155
+ self.fused_attn = True
156
+
157
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
158
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
159
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
160
+ self.attn_drop = nn.Dropout(attn_drop)
161
+ self.proj = nn.Linear(dim, dim)
162
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ B, N, C = x.shape
166
+ qkv = (
167
+ self.qkv(x)
168
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
169
+ .permute(2, 0, 3, 1, 4)
170
+ )
171
+ q, k, v = qkv.unbind(0)
172
+ q, k = self.q_norm(q), self.k_norm(k)
173
+
174
+ if self.fused_attn:
175
+ x = F.scaled_dot_product_attention(
176
+ q,
177
+ k,
178
+ v,
179
+ dropout_p=self.attn_drop.p if self.training else 0.0,
180
+ )
181
+ else:
182
+ q = q * self.scale
183
+ attn = q @ k.transpose(-2, -1)
184
+ attn = attn.softmax(dim=-1)
185
+ attn = self.attn_drop(attn)
186
+ x = attn @ v
187
+
188
+ x = x.transpose(1, 2).reshape(B, N, C)
189
+ x = self.proj(x)
190
+ x = self.proj_drop(x)
191
+ return x
192
+
193
+
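A sanity-check sketch (illustrative) that the fused scaled_dot_product_attention path and the manual softmax path in the Attention block above agree in eval mode; the flag is flipped here only for the comparison.

attn = Attention(dim=64, num_heads=8, qkv_bias=True).eval()
x = torch.randn(2, 10, 64)
with torch.no_grad():
    y_fused = attn(x)
    attn.fused_attn = False                 # fall back to the explicit q @ k.T / softmax path
    y_manual = attn(x)
assert torch.allclose(y_fused, y_manual, atol=1e-5)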
194
+ class LayerScale(nn.Module):
195
+ def __init__(
196
+ self,
197
+ dim: int,
198
+ init_values: float = 1e-5,
199
+ inplace: bool = False,
200
+ ) -> None:
201
+ super().__init__()
202
+ self.inplace = inplace
203
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
204
+
205
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
206
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
207
+
208
+
209
+ class Block(nn.Module):
210
+ def __init__(
211
+ self,
212
+ dim: int,
213
+ num_heads: int,
214
+ mlp_ratio: float = 4.0,
215
+ qkv_bias: bool = False,
216
+ qk_norm: bool = False,
217
+ proj_drop: float = 0.0,
218
+ attn_drop: float = 0.0,
219
+ init_values: Optional[float] = None,
220
+ drop_path: float = 0.0,
221
+ act_layer: nn.Module = nn.GELU,
222
+ norm_layer: nn.Module = nn.LayerNorm,
223
+ mlp_layer: nn.Module = Mlp,
224
+ ) -> None:
225
+ super().__init__()
226
+ self.norm1 = norm_layer(dim)
227
+ self.attn = Attention(
228
+ dim,
229
+ num_heads=num_heads,
230
+ qkv_bias=qkv_bias,
231
+ qk_norm=qk_norm,
232
+ attn_drop=attn_drop,
233
+ proj_drop=proj_drop,
234
+ norm_layer=norm_layer,
235
+ )
236
+ self.ls1 = (
237
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
238
+ )
239
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
240
+
241
+ self.norm2 = norm_layer(dim)
242
+ self.mlp = mlp_layer(
243
+ in_features=dim,
244
+ hidden_features=int(dim * mlp_ratio),
245
+ act_layer=act_layer,
246
+ drop=proj_drop,
247
+ )
248
+ self.ls2 = (
249
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
250
+ )
251
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
252
+
253
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
254
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
255
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
256
+ return x
257
+
258
+
259
+ class VisionTransformer(nn.Module):
260
+ """Vision Transformer
261
+
262
+ A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
263
+ - https://arxiv.org/abs/2010.11929
264
+ """
265
+
266
+ dynamic_img_size: Final[bool]
267
+
268
+ def __init__(
269
+ self,
270
+ img_size: Union[int, Tuple[int, int]] = 224,
271
+ patch_size: Union[int, Tuple[int, int]] = 16,
272
+ in_chans: int = 3,
273
+ num_classes: int = 1000,
274
+ global_pool: Literal["", "avg", "token", "map"] = "token",
275
+ embed_dim: int = 768,
276
+ depth: int = 12,
277
+ num_heads: int = 12,
278
+ mlp_ratio: float = 4.0,
279
+ qkv_bias: bool = True,
280
+ qk_norm: bool = False,
281
+ init_values: Optional[float] = None,
282
+ class_token: bool = True,
283
+ no_embed_class: bool = False,
284
+ reg_tokens: int = 0,
285
+ pre_norm: bool = False,
286
+ fc_norm: Optional[bool] = None,
287
+ dynamic_img_size: bool = False,
288
+ dynamic_img_pad: bool = False,
289
+ drop_rate: float = 0.0,
290
+ pos_drop_rate: float = 0.0,
291
+ patch_drop_rate: float = 0.0,
292
+ proj_drop_rate: float = 0.0,
293
+ attn_drop_rate: float = 0.0,
294
+ drop_path_rate: float = 0.0,
295
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
296
+ embed_layer: Callable = PatchEmbed,
297
+ norm_layer: Optional[LayerType] = None,
298
+ act_layer: Optional[LayerType] = None,
299
+ block_fn: Type[nn.Module] = Block,
300
+ mlp_layer: Type[nn.Module] = Mlp,
301
+ ignore_head: bool = False,
302
+ ) -> None:
303
+ """
304
+ Args:
305
+ img_size: Input image size.
306
+ patch_size: Patch size.
307
+ in_chans: Number of image input channels.
308
+ num_classes: Number of classes for the classification head.
309
+ global_pool: Type of global pooling for final sequence (default: 'token').
310
+ embed_dim: Transformer embedding dimension.
311
+ depth: Depth of transformer.
312
+ num_heads: Number of attention heads.
313
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
314
+ qkv_bias: Enable bias for qkv projections if True.
315
+ init_values: Layer-scale init values (layer-scale enabled if not None).
316
+ class_token: Use class token.
317
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
318
+ reg_tokens: Number of register tokens.
319
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
320
+ drop_rate: Head dropout rate.
321
+ pos_drop_rate: Position embedding dropout rate.
322
+ attn_drop_rate: Attention dropout rate.
323
+ drop_path_rate: Stochastic depth rate.
324
+ weight_init: Weight initialization scheme.
325
+ embed_layer: Patch embedding layer.
326
+ norm_layer: Normalization layer.
327
+ act_layer: MLP activation layer.
328
+ block_fn: Transformer block layer.
329
+ """
330
+ super().__init__()
331
+ assert global_pool in ("", "avg", "token", "map")
332
+ assert class_token or global_pool != "token"
333
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
334
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
335
+ # act_layer = get_act_layer(act_layer) or nn.GELU
336
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
337
+ act_layer = nn.GELU
338
+
339
+ self.num_classes = num_classes
340
+ self.global_pool = global_pool
341
+ self.num_features = self.embed_dim = (
342
+ embed_dim # num_features for consistency with other models
343
+ )
344
+ self.num_prefix_tokens = 1 if class_token else 0
345
+ self.num_prefix_tokens += reg_tokens
346
+ self.num_reg_tokens = reg_tokens
347
+ self.has_class_token = class_token
348
+ self.no_embed_class = (
349
+ no_embed_class # don't embed prefix positions (includes reg)
350
+ )
351
+ self.dynamic_img_size = dynamic_img_size
352
+ self.grad_checkpointing = False
353
+ self.ignore_head = ignore_head
354
+
355
+ embed_args = {}
356
+ if dynamic_img_size:
357
+ # flatten deferred until after pos embed
358
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
359
+ self.patch_embed = embed_layer(
360
+ img_size=img_size,
361
+ patch_size=patch_size,
362
+ in_chans=in_chans,
363
+ embed_dim=embed_dim,
364
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
365
+ dynamic_img_pad=dynamic_img_pad,
366
+ **embed_args,
367
+ )
368
+ num_patches = self.patch_embed.num_patches
369
+
370
+ self.cls_token = (
371
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
372
+ )
373
+ self.reg_token = (
374
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
375
+ )
376
+ embed_len = (
377
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
378
+ )
379
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
380
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
381
+ if patch_drop_rate > 0:
382
+ self.patch_drop = PatchDropout(
383
+ patch_drop_rate,
384
+ num_prefix_tokens=self.num_prefix_tokens,
385
+ )
386
+ else:
387
+ self.patch_drop = nn.Identity()
388
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
389
+
390
+ dpr = [
391
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
392
+ ] # stochastic depth decay rule
393
+ self.blocks = nn.Sequential(
394
+ *[
395
+ block_fn(
396
+ dim=embed_dim,
397
+ num_heads=num_heads,
398
+ mlp_ratio=mlp_ratio,
399
+ qkv_bias=qkv_bias,
400
+ qk_norm=qk_norm,
401
+ init_values=init_values,
402
+ proj_drop=proj_drop_rate,
403
+ attn_drop=attn_drop_rate,
404
+ drop_path=dpr[i],
405
+ norm_layer=norm_layer,
406
+ act_layer=act_layer,
407
+ mlp_layer=mlp_layer,
408
+ )
409
+ for i in range(depth)
410
+ ]
411
+ )
412
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
413
+
414
+ # Classifier Head
415
+ if global_pool == "map":
416
+ AttentionPoolLatent.init_weights = init_weights
417
+ self.attn_pool = AttentionPoolLatent(
418
+ self.embed_dim,
419
+ num_heads=num_heads,
420
+ mlp_ratio=mlp_ratio,
421
+ norm_layer=norm_layer,
422
+ )
423
+ else:
424
+ self.attn_pool = None
425
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
426
+ self.head_drop = nn.Dropout(drop_rate)
427
+ self.head = (
428
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
429
+ )
430
+
431
+ if weight_init != "skip":
432
+ self.init_weights(weight_init)
433
+
434
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
435
+ assert mode in ("jax", "jax_nlhb", "moco", "")
436
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
437
+ trunc_normal_(self.pos_embed, std=0.02)
438
+ if self.cls_token is not None:
439
+ nn.init.normal_(self.cls_token, std=1e-6)
440
+ named_apply(init_weights_vit_timm, self)
441
+
442
+ @torch.jit.ignore
443
+ def no_weight_decay(self) -> Set:
444
+ return {"pos_embed", "cls_token", "dist_token"}
445
+
446
+ @torch.jit.ignore
447
+ def group_matcher(self, coarse: bool = False) -> Dict:
448
+ return dict(
449
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
450
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
451
+ )
452
+
453
+ @torch.jit.ignore
454
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
455
+ self.grad_checkpointing = enable
456
+
457
+ @torch.jit.ignore
458
+ def get_classifier(self) -> nn.Module:
459
+ return self.head
460
+
461
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
462
+ self.num_classes = num_classes
463
+ if global_pool is not None:
464
+ assert global_pool in ("", "avg", "token", "map")
465
+ if global_pool == "map" and self.attn_pool is None:
466
+ assert (
467
+ False
468
+ ), "Cannot currently add attention pooling in reset_classifier()."
469
+ elif global_pool != "map" and self.attn_pool is not None:
470
+ self.attn_pool = None # remove attention pooling
471
+ self.global_pool = global_pool
472
+ self.head = (
473
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
474
+ )
475
+
476
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
477
+ if self.dynamic_img_size:
478
+ B, H, W, C = x.shape
479
+ pos_embed = resample_abs_pos_embed(
480
+ self.pos_embed,
481
+ (H, W),
482
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
483
+ )
484
+ x = x.view(B, -1, C)
485
+ else:
486
+ pos_embed = self.pos_embed
487
+
488
+ to_cat = []
489
+ if self.cls_token is not None:
490
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
491
+ if self.reg_token is not None:
492
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
493
+
494
+ if self.no_embed_class:
495
+ # deit-3, updated JAX (big vision)
496
+ # position embedding does not overlap with class token, add then concat
497
+ x = x + pos_embed
498
+ if to_cat:
499
+ x = torch.cat(to_cat + [x], dim=1)
500
+ else:
501
+ # original timm, JAX, and deit vit impl
502
+ # pos_embed has entry for class token, concat then add
503
+ if to_cat:
504
+ x = torch.cat(to_cat + [x], dim=1)
505
+ x = x + pos_embed
506
+
507
+ return self.pos_drop(x)
508
+
509
+ def _intermediate_layers(
510
+ self,
511
+ x: torch.Tensor,
512
+ n: Union[int, Sequence] = 1,
513
+ ) -> List[torch.Tensor]:
514
+ outputs, num_blocks = [], len(self.blocks)
515
+ take_indices = set(
516
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
517
+ )
518
+
519
+ # forward pass
520
+ x = self.patch_embed(x)
521
+ x = self._pos_embed(x)
522
+ x = self.patch_drop(x)
523
+ x = self.norm_pre(x)
524
+ for i, blk in enumerate(self.blocks):
525
+ x = blk(x)
526
+ if i in take_indices:
527
+ outputs.append(x)
528
+
529
+ return outputs
530
+
531
+ def get_intermediate_layers(
532
+ self,
533
+ x: torch.Tensor,
534
+ n: Union[int, Sequence] = 1,
535
+ reshape: bool = False,
536
+ return_prefix_tokens: bool = False,
537
+ norm: bool = False,
538
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
539
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
540
+ Inspired by DINO / DINOv2 interface
541
+ """
542
+ # take last n blocks if n is an int; if n is a sequence, select blocks by matching indices
543
+ outputs = self._intermediate_layers(x, n)
544
+ if norm:
545
+ outputs = [self.norm(out) for out in outputs]
546
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
547
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
548
+
549
+ if reshape:
550
+ grid_size = self.patch_embed.grid_size
551
+ outputs = [
552
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
553
+ .permute(0, 3, 1, 2)
554
+ .contiguous()
555
+ for out in outputs
556
+ ]
557
+
558
+ if return_prefix_tokens:
559
+ return tuple(zip(outputs, prefix_tokens))
560
+ return tuple(outputs)
561
+
562
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
563
+ x = self.patch_embed(x)
564
+ x = self._pos_embed(x)
565
+ x = self.patch_drop(x)
566
+ x = self.norm_pre(x)
567
+ if self.grad_checkpointing and not torch.jit.is_scripting():
568
+ x = checkpoint_seq(self.blocks, x)
569
+ else:
570
+ x = self.blocks(x)
571
+ x = self.norm(x)
572
+ return x
573
+
574
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
575
+ if self.attn_pool is not None:
576
+ x = self.attn_pool(x)
577
+ elif self.global_pool == "avg":
578
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
579
+ elif self.global_pool:
580
+ x = x[:, 0] # class token
581
+ x = self.fc_norm(x)
582
+ x = self.head_drop(x)
583
+ return x if pre_logits else self.head(x)
584
+
585
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
586
+ x = self.forward_features(x)
587
+ if not self.ignore_head:
588
+ x = self.forward_head(x)
589
+ return x
590
+
591
+
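A small shape sketch (illustrative) for the class above: with a 224x224 input, 16x16 patches, and a class token, forward_features returns 14*14 + 1 = 197 tokens of width embed_dim.

vit = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3, weight_init="skip"
)
feats = vit.forward_features(torch.randn(1, 3, 224, 224))
assert feats.shape == (1, 197, 192)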
592
+ @dataclass
593
+ class SigLIPVisionCfg:
594
+ width: int = 1152
595
+ layers: Union[Tuple[int, int, int, int], int] = 27
596
+ heads: int = 16
597
+ patch_size: int = 14
598
+ image_size: Union[Tuple[int, int], int] = 336
599
+ global_pool: str = "map"
600
+ mlp_ratio: float = 3.7362
601
+ class_token: bool = False
602
+ num_classes: int = 0
603
+ use_checkpoint: bool = False
604
+
605
+
606
+ SigLIP_MODEL_CONFIG = {
607
+ "siglip_so400m_patch14_384": {
608
+ "image_size": 336,
609
+ "patch_size": 14,
610
+ "width": 1152,
611
+ "layers": 27,
612
+ "heads": 16,
613
+ "mlp_ratio": 3.7362,
614
+ "global_pool": "map",
615
+ "use_checkpoint": False,
616
+ },
617
+ "siglip_so400m_patch14_224": {
618
+ "image_size": 224,
619
+ "patch_size": 14,
620
+ "width": 1152,
621
+ "layers": 27,
622
+ "heads": 16,
623
+ "mlp_ratio": 3.7362,
624
+ "global_pool": "map",
625
+ "use_checkpoint": False,
626
+ },
627
+ "siglip_large_patch16_384": {
628
+ "image_size": 384,
629
+ "patch_size": 16,
630
+ "width": 1024,
631
+ "layers": 24,
632
+ "heads": 16,
633
+ "mlp_ratio": 4,
634
+ "global_pool": "map",
635
+ "use_checkpoint": False,
636
+ },
637
+ }
638
+
639
+
640
+ def create_siglip_vit(
641
+ model_name: str = "siglip_so400m_patch14_384",
642
+ image_size: int = 384,
643
+ select_layer: int = -1,
644
+ ckpt_path: str = "",
645
+ **kwargs,
646
+ ):
647
+ assert (
648
+ model_name in SigLIP_MODEL_CONFIG.keys()
649
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
650
+
651
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
652
+
653
+ if select_layer <= 0:
654
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
655
+ else:
656
+ layers = min(vision_cfg.layers, select_layer)
657
+
658
+ model = VisionTransformer(
659
+ img_size=image_size,
660
+ patch_size=vision_cfg.patch_size,
661
+ embed_dim=vision_cfg.width,
662
+ depth=layers,
663
+ num_heads=vision_cfg.heads,
664
+ mlp_ratio=vision_cfg.mlp_ratio,
665
+ class_token=vision_cfg.class_token,
666
+ global_pool=vision_cfg.global_pool,
667
+ ignore_head=kwargs.get("ignore_head", True),
668
+ weight_init=kwargs.get("weight_init", "skip"),
669
+ num_classes=0,
670
+ )
671
+
672
+ if ckpt_path:
673
+ state_dict = torch.load(ckpt_path, map_location="cpu")
674
+
675
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
676
+ print(
677
+ f"SigLIP-ViT restored from {ckpt_path},\n"
678
+ f"\tincompatible_keys: {incompatible_keys}."
679
+ )
680
+
681
+ return model
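A usage sketch (illustrative) for the factory above. select_layer trims blocks from the top of the tower: with the 27-layer so400m config, select_layer=-1 keeps all 27 blocks (27 + (-1) + 1), while select_layer=-2 would keep 26. Because ignore_head defaults to True here, the forward pass returns patch features rather than pooled logits.

model = create_siglip_vit("siglip_so400m_patch14_384", image_size=384, select_layer=-1)
feats = model(torch.randn(1, 3, 384, 384))
print(len(model.blocks), feats.shape)   # 27 blocks; expected (1, 729, 1152) for the 27x27 patch grid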
deepseek_vl/utils/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
deepseek_vl/utils/conversation.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ """
21
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
22
+ """
23
+
24
+ import dataclasses
25
+ from enum import IntEnum, auto
26
+ from typing import Dict, List
27
+
28
+
29
+ class SeparatorStyle(IntEnum):
30
+ """Separator styles."""
31
+
32
+ ADD_COLON_SINGLE = auto()
33
+ ADD_COLON_TWO = auto()
34
+ ADD_COLON_SPACE_SINGLE = auto()
35
+ NO_COLON_SINGLE = auto()
36
+ NO_COLON_TWO = auto()
37
+ ADD_NEW_LINE_SINGLE = auto()
38
+ LLAMA2 = auto()
39
+ CHATGLM = auto()
40
+ CHATML = auto()
41
+ CHATINTERN = auto()
42
+ DOLLY = auto()
43
+ RWKV = auto()
44
+ PHOENIX = auto()
45
+ ROBIN = auto()
46
+ DeepSeek = auto()
47
+ PLAIN = auto()
48
+ ALIGNMENT = auto()
49
+
50
+
51
+ @dataclasses.dataclass
52
+ class Conversation:
53
+ """A class that manages prompt templates and keeps all conversation history."""
54
+
55
+ # The name of this template
56
+ name: str
57
+ # The template of the system prompt
58
+ system_template: str = "{system_message}"
59
+ # The system message
60
+ system_message: str = ""
61
+ # The names of two roles
62
+ roles: List[str] = (("USER", "ASSISTANT"),)
63
+ # All messages. Each item is (role, message).
64
+ messages: List[List[str]] = ()
65
+ # The number of few shot examples
66
+ offset: int = 0
67
+ # The separator style and configurations
68
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
69
+ sep: str = "\n"
70
+ sep2: str = None
71
+ # Stop criteria (the default one is EOS token)
72
+ stop_str: str = None
73
+ # Stops generation if meeting any token in this list
74
+ stop_token_ids: List[int] = None
75
+
76
+ def get_prompt(self) -> str:
77
+ """Get the prompt for generation."""
78
+ system_prompt = self.system_template.format(system_message=self.system_message)
79
+
80
+ if self.sep_style == SeparatorStyle.DeepSeek:
81
+ seps = [self.sep, self.sep2]
82
+ if system_prompt == "" or system_prompt is None:
83
+ ret = ""
84
+ else:
85
+ ret = system_prompt + seps[0]
86
+ for i, (role, message) in enumerate(self.messages):
87
+ if message:
88
+ ret += role + ": " + message + seps[i % 2]
89
+ else:
90
+ ret += role + ":"
91
+ return ret
92
+ elif self.sep_style == SeparatorStyle.LLAMA2:
93
+ seps = [self.sep, self.sep2]
94
+ if self.system_message:
95
+ ret = system_prompt
96
+ else:
97
+ ret = "[INST] "
98
+ for i, (role, message) in enumerate(self.messages):
99
+ tag = self.roles[i % 2]
100
+ if message:
101
+ if type(message) is tuple: # multimodal message
102
+ message, _ = message
103
+ if i == 0:
104
+ ret += message + " "
105
+ else:
106
+ ret += tag + " " + message + seps[i % 2]
107
+ else:
108
+ ret += tag
109
+ return ret
110
+ elif self.sep_style == SeparatorStyle.PLAIN:
111
+ seps = [self.sep, self.sep2]
112
+ ret = ""
113
+ for i, (role, message) in enumerate(self.messages):
114
+ if message:
115
+ if type(message) is tuple:
116
+ message, _, _ = message
117
+ if i % 2 == 0:
118
+ ret += message + seps[i % 2]
119
+ else:
120
+ ret += message + seps[i % 2]
121
+ else:
122
+ ret += ""
123
+ return ret
124
+ elif self.sep_style == SeparatorStyle.ALIGNMENT:
125
+ seps = [self.sep, self.sep2]
126
+ ret = ""
127
+ for i, (role, message) in enumerate(self.messages):
128
+ if message:
129
+ if type(message) is tuple:
130
+ message, _, _ = message
131
+ if i % 2 == 0:
132
+ ret += "<image>\n" + seps[i % 2]
133
+ else:
134
+ ret += message + seps[i % 2]
135
+ else:
136
+ ret += ""
137
+ return ret
138
+ else:
139
+ raise ValueError(f"Invalid style: {self.sep_style}")
140
+
141
+ def get_prompt_for_current_round(self, content=None):
142
+ """Get the formatted question prompt for the current round during SFT training."""
143
+ if self.sep_style == SeparatorStyle.PLAIN:
144
+ formatted_question = "<image>\n"
145
+ elif self.sep_style == SeparatorStyle.DeepSeek:
146
+ formatted_question = (
147
+ f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
148
+ )
149
+ else:
150
+ raise ValueError(f"Unsupported sep_style: {self.sep_style}")
151
+ return formatted_question
152
+
153
+ def set_system_message(self, system_message: str):
154
+ """Set the system message."""
155
+ self.system_message = system_message
156
+
157
+ def append_message(self, role: str, message: str):
158
+ """Append a new message."""
159
+ self.messages.append([role, message])
160
+
161
+ def reset_message(self):
162
+ """Reset the message history."""
163
+ self.messages = []
164
+
165
+ def update_last_message(self, message: str):
166
+ """Update the last output.
167
+
168
+ The last message is typically set to be None when constructing the prompt,
169
+ so we need to update it in-place after getting the response from a model.
170
+ """
171
+ self.messages[-1][1] = message
172
+
173
+ def to_gradio_chatbot(self):
174
+ """Convert the conversation to gradio chatbot format."""
175
+ ret = []
176
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
177
+ if i % 2 == 0:
178
+ ret.append([msg, None])
179
+ else:
180
+ ret[-1][-1] = msg
181
+ return ret
182
+
183
+ def to_openai_api_messages(self):
184
+ """Convert the conversation to OpenAI chat completion format."""
185
+ system_prompt = self.system_template.format(system_message=self.system_message)
186
+ ret = [{"role": "system", "content": system_prompt}]
187
+
188
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
189
+ if i % 2 == 0:
190
+ ret.append({"role": "user", "content": msg})
191
+ else:
192
+ if msg is not None:
193
+ ret.append({"role": "assistant", "content": msg})
194
+ return ret
195
+
196
+ def copy(self):
197
+ return Conversation(
198
+ name=self.name,
199
+ system_template=self.system_template,
200
+ system_message=self.system_message,
201
+ roles=self.roles,
202
+ messages=[[x, y] for x, y in self.messages],
203
+ offset=self.offset,
204
+ sep_style=self.sep_style,
205
+ sep=self.sep,
206
+ sep2=self.sep2,
207
+ stop_str=self.stop_str,
208
+ stop_token_ids=self.stop_token_ids,
209
+ )
210
+
211
+ def dict(self):
212
+ return {
213
+ "template_name": self.name,
214
+ "system_message": self.system_message,
215
+ "roles": self.roles,
216
+ "messages": self.messages,
217
+ "offset": self.offset,
218
+ }
219
+
220
+
221
+ # A global registry for all conversation templates
222
+ conv_templates: Dict[str, Conversation] = {}
223
+
224
+
225
+ def register_conv_template(template: Conversation, override: bool = False):
226
+ """Register a new conversation template."""
227
+ if not override:
228
+ assert (
229
+ template.name not in conv_templates
230
+ ), f"{template.name} has been registered."
231
+
232
+ conv_templates[template.name] = template
233
+
234
+
235
+ def get_conv_template(name: str) -> Conversation:
236
+ """Get a conversation template."""
237
+ return conv_templates[name].copy()
238
+
239
+
240
+ # llava_llama2 template
241
+ register_conv_template(
242
+ Conversation(
243
+ name="llava_llama2",
244
+ system_message="You are a helpful language and vision assistant. "
245
+ "You are able to understand the visual content that the user provides, "
246
+ "and assist the user with a variety of tasks using natural language.",
247
+ system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
248
+ roles=("[INST]", "[/INST]"),
249
+ messages=(),
250
+ offset=0,
251
+ sep_style=SeparatorStyle.LLAMA2,
252
+ sep=" ",
253
+ sep2=" </s><s>",
254
+ stop_token_ids=[2],
255
+ )
256
+ )
257
+
258
+ # llama2 template
259
+ # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
260
+ register_conv_template(
261
+ Conversation(
262
+ name="llama-2",
263
+ system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
264
+ roles=("[INST]", "[/INST]"),
265
+ messages=(),
266
+ offset=0,
267
+ sep_style=SeparatorStyle.LLAMA2,
268
+ sep=" ",
269
+ sep2=" </s><s>",
270
+ stop_token_ids=[2],
271
+ )
272
+ )
273
+
274
+
275
+ # deepseek template
276
+ register_conv_template(
277
+ Conversation(
278
+ name="deepseek",
279
+ system_template="{system_message}",
280
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
281
+ # "thinking step by step to be sure you get the right answer.",
282
+ system_message="",
283
+ roles=("User", "Assistant"),
284
+ messages=(),
285
+ offset=0,
286
+ sep_style=SeparatorStyle.DeepSeek,
287
+ sep="\n\n",
288
+ sep2="<|end▁of▁sentence|>",
289
+ stop_token_ids=[100001],
290
+ stop_str=["User:", "<|end▁of▁sentence|>"],
291
+ )
292
+ )
293
+
294
+ register_conv_template(
295
+ Conversation(
296
+ name="plain",
297
+ system_template="",
298
+ system_message="",
299
+ roles=("", ""),
300
+ messages=(),
301
+ offset=0,
302
+ sep_style=SeparatorStyle.PLAIN,
303
+ sep="",
304
+ sep2="",
305
+ stop_token_ids=[2],
306
+ stop_str=["</s>"],
307
+ )
308
+ )
309
+
310
+
311
+ register_conv_template(
312
+ Conversation(
313
+ name="alignment",
314
+ system_template="",
315
+ system_message="",
316
+ roles=("", ""),
317
+ messages=(),
318
+ offset=0,
319
+ sep_style=SeparatorStyle.ALIGNMENT,
320
+ sep="",
321
+ sep2="",
322
+ stop_token_ids=[2],
323
+ stop_str=["</s>"],
324
+ )
325
+ )
326
+
327
+
328
+ if __name__ == "__main__":
329
+ # print("Llama-2 template:")
330
+ # conv = get_conv_template("llama-2")
331
+ # conv.set_system_message("You are a helpful, respectful and honest assistant.")
332
+ # conv.append_message(conv.roles[0], "Hello!")
333
+ # conv.append_message(conv.roles[1], "Hi!")
334
+ # conv.append_message(conv.roles[0], "How are you?")
335
+ # conv.append_message(conv.roles[1], None)
336
+ # print(conv.get_prompt())
337
+
338
+ # print("\n")
339
+
340
+ print("deepseek template:")
341
+ conv = get_conv_template("deepseek")
342
+ conv.append_message(conv.roles[0], "Hello!")
343
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
344
+ conv.append_message(conv.roles[0], "Who are you?")
345
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
346
+ conv.append_message(conv.roles[0], "How are you?")
347
+ conv.append_message(conv.roles[1], None)
348
+ print(conv.get_prompt())
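For reference, the deepseek conversation built above should render (via get_prompt, with an empty system message) roughly as:

    User: Hello!

    Assistant: Hi! This is Tony.<|end▁of▁sentence|>User: Who are you?

    Assistant: I am a helpful assistant.<|end▁of▁sentence|>User: How are you?

    Assistant: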
deepseek_vl/utils/io.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import json
21
+ from typing import Dict, List
22
+
23
+ import PIL.Image
24
+ import torch
25
+ from transformers import AutoModelForCausalLM
26
+
27
+ from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
28
+
29
+
30
+ def load_pretrained_model(model_path: str):
31
+ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
32
+ tokenizer = vl_chat_processor.tokenizer
33
+
34
+ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
35
+ model_path, trust_remote_code=True
36
+ )
37
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
38
+
39
+ return tokenizer, vl_chat_processor, vl_gpt
40
+
41
+
42
+ def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
43
+ """
44
+
45
+ Args:
46
+ conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
47
+ [
48
+ {
49
+ "role": "User",
50
+ "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
51
+ "images": ["./examples/table_datasets.png"]
52
+ },
53
+ {"role": "Assistant", "content": ""},
54
+ ]
55
+
56
+ Returns:
57
+ pil_images (List[PIL.Image.Image]): the list of PIL images.
58
+
59
+ """
60
+
61
+ pil_images = []
62
+
63
+ for message in conversations:
64
+ if "images" not in message:
65
+ continue
66
+
67
+ for image_path in message["images"]:
68
+ pil_img = PIL.Image.open(image_path)
69
+ pil_img = pil_img.convert("RGB")
70
+ pil_images.append(pil_img)
71
+
72
+ return pil_images
73
+
74
+
75
+ def load_json(filepath):
76
+ with open(filepath, "r") as f:
77
+ data = json.load(f)
78
+ return data
examples/app.png ADDED
examples/chart.png ADDED
examples/mirror.png ADDED
examples/pipeline.png ADDED
examples/puzzle.png ADDED
examples/rap.jpeg ADDED
inference.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from threading import Thread
21
+ from typing import List
22
+
23
+ import torch
24
+ import transformers
25
+ from transformers import (
26
+ AutoModelForCausalLM,
27
+ StoppingCriteria,
28
+ StoppingCriteriaList,
29
+ TextIteratorStreamer,
30
+ )
31
+
32
+ from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
33
+ from deepseek_vl.utils.conversation import Conversation
34
+
35
+
36
+ def load_model(model_path):
37
+ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
38
+ tokenizer = vl_chat_processor.tokenizer
39
+ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
40
+ model_path, trust_remote_code=True
41
+ )
42
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
43
+ return tokenizer, vl_gpt, vl_chat_processor
44
+
45
+
46
+ def convert_conversation_to_prompts(conversation: Conversation):
47
+ prompts = []
48
+ messages = conversation.messages
49
+
50
+ for i in range(0, len(messages), 2):
51
+ prompt = {
52
+ "role": messages[i][0],
53
+ "content": (
54
+ messages[i][1][0]
55
+ if isinstance(messages[i][1], tuple)
56
+ else messages[i][1]
57
+ ),
58
+ "images": [messages[i][1][1]] if isinstance(messages[i][1], tuple) else [],
59
+ }
60
+ response = {"role": messages[i + 1][0], "content": messages[i + 1][1]}
61
+ prompts.extend([prompt, response])
62
+
63
+ return prompts
64
+
65
+
66
+ class StoppingCriteriaSub(StoppingCriteria):
67
+ def __init__(self, stops=[], encounters=1):
68
+ super().__init__()
69
+ self.stops = [stop.to("cuda") for stop in stops]
70
+
71
+ def __call__(
72
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
73
+ ):
74
+ for stop in self.stops:
75
+ if input_ids.shape[-1] < len(stop):
76
+ continue
77
+ if torch.all((stop == input_ids[0][-len(stop) :])).item():
78
+ return True
79
+
80
+ return False
81
+
82
+
83
+ @torch.inference_mode()
84
+ def deepseek_generate(
85
+ prompts: list,
86
+ vl_gpt: torch.nn.Module,
87
+ vl_chat_processor,
88
+ tokenizer: transformers.PreTrainedTokenizer,
89
+ stop_words: list,
90
+ max_length: int = 256,
91
+ temperature: float = 1.0,
92
+ top_p: float = 1.0,
93
+ repetition_penalty=1.1,
94
+ ):
95
+ # gather the PIL images attached to the prompt messages
96
+ pil_images = list()
97
+ for message in prompts:
98
+ if "images" not in message:
99
+ continue
100
+ for pil_img in message["images"]:
101
+ pil_images.append(pil_img)
102
+
103
+ prepare_inputs = vl_chat_processor(
104
+ conversations=prompts, images=pil_images, force_batchify=True
105
+ ).to(vl_gpt.device)
106
+
107
+ return generate(
108
+ vl_gpt,
109
+ tokenizer,
110
+ prepare_inputs,
111
+ max_length,
112
+ temperature,
113
+ repetition_penalty,
114
+ top_p,
115
+ stop_words,
116
+ )
117
+
118
+
119
+ @torch.inference_mode()
120
+ def generate(
121
+ vl_gpt,
122
+ tokenizer,
123
+ prepare_inputs,
124
+ max_gen_len: int = 256,
125
+ temperature: float = 0,
126
+ repetition_penalty=1.1,
127
+ top_p: float = 0.95,
128
+ stop_words: List[str] = [],
129
+ ):
130
+ """Stream the text output from the multimodality model with prompt and image inputs."""
131
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
132
+
133
+ streamer = TextIteratorStreamer(tokenizer)
134
+
135
+ stop_words_ids = [
136
+ torch.tensor(tokenizer.encode(stop_word)) for stop_word in stop_words
137
+ ]
138
+ stopping_criteria = StoppingCriteriaList(
139
+ [StoppingCriteriaSub(stops=stop_words_ids)]
140
+ )
141
+
142
+ generation_config = dict(
143
+ inputs_embeds=inputs_embeds,
144
+ attention_mask=prepare_inputs.attention_mask,
145
+ pad_token_id=tokenizer.eos_token_id,
146
+ bos_token_id=tokenizer.bos_token_id,
147
+ eos_token_id=tokenizer.eos_token_id,
148
+ max_new_tokens=max_gen_len,
149
+ do_sample=True,
150
+ use_cache=True,
151
+ streamer=streamer,
152
+ stopping_criteria=stopping_criteria,
153
+ )
154
+
155
+ if temperature > 0:
156
+ generation_config.update(
157
+ {
158
+ "do_sample": True,
159
+ "top_p": top_p,
160
+ "temperature": temperature,
161
+ "repetition_penalty": repetition_penalty,
162
+ }
163
+ )
164
+ else:
165
+ generation_config["do_sample"] = False
166
+
167
+ thread = Thread(target=vl_gpt.language_model.generate, kwargs=generation_config)
168
+ thread.start()
169
+
170
+ yield from streamer
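
Taken together, deepseek_generate() builds the processor inputs from the prompt list and generate() streams tokens from a background thread through a TextIteratorStreamer. The sketch below shows how that stream would typically be consumed; it assumes the model has already been loaded with load_model() above, and the checkpoint id, image path, sampling settings, and stop word are illustrative assumptions.

import PIL.Image

# Illustrative sketch only -- not part of the committed file.
tokenizer, vl_gpt, vl_chat_processor = load_model("deepseek-ai/deepseek-vl-7b-chat")  # assumed checkpoint id

prompts = [
    {
        "role": "User",
        "content": "<image_placeholder>\nWhat is shown in this picture?",
        # Unlike load_pil_images(), deepseek_generate() expects already-opened
        # PIL images here; they are forwarded straight to the processor.
        "images": [PIL.Image.open("./examples/mirror.png").convert("RGB")],
    },
    {"role": "Assistant", "content": ""},
]

answer = ""
for chunk in deepseek_generate(
    prompts=prompts,
    vl_gpt=vl_gpt,
    vl_chat_processor=vl_chat_processor,
    tokenizer=tokenizer,
    stop_words=["User:"],  # assumed stop word; adjust to the conversation template actually used
    max_length=512,
    temperature=0.2,
    top_p=0.95,
):
    answer += chunk  # chunks arrive incrementally from the TextIteratorStreamer
    print(chunk, end="", flush=True)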
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ torch==2.0.1
2
+ transformers>=4.38.2
3
+ timm>=0.9.16
4
+ accelerate
5
+ sentencepiece
6
+ attrdict
7
+ einops
8
+
9
+ # for gradio demo
10
+ gradio==3.48.0
11
+ gradio-client==0.6.1
12
+ mdtex2html==1.3.0
13
+ pypinyin==0.50.0
14
+ tiktoken==0.5.2
15
+ tqdm==4.64.0
16
+ colorama==0.4.5
17
+ Pygments==2.12.0
18
+ markdown==3.4.1
19
+ SentencePiece==0.1.96