Spaces: Running on Zero
Commit · 79cf446
0 Parent(s)
chore: rebase commits
- .gitattributes +35 -0
- README.md +13 -0
- app.py +511 -0
- app_modules/conversation.py +348 -0
- app_modules/gradio_utils.py +94 -0
- app_modules/overwrites.py +81 -0
- app_modules/presets.py +96 -0
- app_modules/utils.py +228 -0
- assets/Kelpy-Codos.js +100 -0
- assets/avatar.png +0 -0
- assets/custom.css +355 -0
- assets/custom.js +22 -0
- assets/favicon.ico +0 -0
- deepseek_vl/__init__.py +31 -0
- deepseek_vl/models/__init__.py +28 -0
- deepseek_vl/models/clip_encoder.py +242 -0
- deepseek_vl/models/image_processing_vlm.py +208 -0
- deepseek_vl/models/modeling_vlm.py +170 -0
- deepseek_vl/models/processing_vlm.py +390 -0
- deepseek_vl/models/projector.py +100 -0
- deepseek_vl/models/sam.py +593 -0
- deepseek_vl/models/siglip_vit.py +681 -0
- deepseek_vl/utils/__init__.py +18 -0
- deepseek_vl/utils/conversation.py +348 -0
- deepseek_vl/utils/io.py +78 -0
- examples/app.png +0 -0
- examples/chart.png +0 -0
- examples/mirror.png +0 -0
- examples/pipeline.png +0 -0
- examples/puzzle.png +0 -0
- examples/rap.jpeg +0 -0
- inference.py +170 -0
- requirements.txt +19 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Chat with DeepSeek VL 7B
emoji: 🐬
colorFrom: blue
colorTo: red
sdk: gradio
sdk_version: 4.21.0
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,511 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# -*- coding:utf-8 -*-

import base64
from io import BytesIO

import gradio as gr
import torch
from app_modules.gradio_utils import (
    cancel_outputing,
    delete_last_conversation,
    reset_state,
    reset_textbox,
    transfer_input,
    wrap_gen_fn,
)
from app_modules.overwrites import reload_javascript
from app_modules.presets import CONCURRENT_COUNT, description, description_top, title
from app_modules.utils import configure_logger, is_variable_assigned, strip_stop_words

from inference import (
    convert_conversation_to_prompts,
    deepseek_generate,
    load_model,
)
from app_modules.conversation import SeparatorStyle


def load_models():
    models = {
        "DeepSeek-VL 7B": "deepseek-ai/deepseek-vl-7b-chat",
    }

    for model_name in models:
        models[model_name] = load_model(models[model_name])

    return models


logger = configure_logger()
models = load_models()
MODELS = sorted(list(models.keys()))


def generate_prompt_with_history(
    text, image, history, vl_chat_processor, tokenizer, max_length=2048
):
    """
    Generate a prompt with history for the deepseek application.

    Args:
        text (str): The text prompt.
        image (str): The image prompt.
        history (list): List of previous conversation messages.
        tokenizer: The tokenizer used for encoding the prompt.
        max_length (int): The maximum length of the prompt.

    Returns:
        tuple: A tuple containing the generated prompt, image list, conversation, and conversation copy. If the prompt could not be generated within the max_length limit, returns None.
    """

    sft_format = "deepseek"
    user_role_ind = 0
    bot_role_ind = 1

    # Initialize conversation
    conversation = vl_chat_processor.new_chat_template()

    if history:
        conversation.messages = history

    if image is not None:
        if "<image_placeholder>" not in text:
            text = (
                "<image_placeholder>" + "\n" + text
            )  # append the <image_placeholder> in a new line after the text prompt
        text = (text, image)

    conversation.append_message(conversation.roles[user_role_ind], text)
    conversation.append_message(conversation.roles[bot_role_ind], "")

    # Create a copy of the conversation to avoid history truncation in the UI
    conversation_copy = conversation.copy()
    logger.info("=" * 80)
    logger.info(get_prompt(conversation))

    rounds = len(conversation.messages) // 2

    for _ in range(rounds):
        current_prompt = get_prompt(conversation)
        current_prompt = (
            current_prompt.replace("</s>", "")
            if sft_format == "deepseek"
            else current_prompt
        )

        if torch.tensor(tokenizer.encode(current_prompt)).size(-1) <= max_length:
            return conversation_copy

        if len(conversation.messages) % 2 != 0:
            gr.Error("The messages between user and assistant are not paired.")
            return

        try:
            for _ in range(2):  # pop out two messages in a row
                conversation.messages.pop(0)
        except IndexError:
            gr.Error("Input text processing failed, unable to respond in this round.")
            return None

    gr.Error("Prompt could not be generated within max_length limit.")
    return None


def to_gradio_chatbot(conv):
    """Convert the conversation to gradio chatbot format."""
    ret = []
    for i, (role, msg) in enumerate(conv.messages[conv.offset :]):
        if i % 2 == 0:
            if type(msg) is tuple:
                msg, image = msg
                if isinstance(image, str):
                    with open(image, "rb") as f:
                        data = f.read()
                    img_b64_str = base64.b64encode(data).decode()
                    image_str = f'<video src="data:video/mp4;base64,{img_b64_str}" controls width="426" height="240"></video>'
                    msg = msg.replace("\n".join(["<image_placeholder>"] * 4), image_str)
                else:
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace("<image_placeholder>", img_str)
            ret.append([msg, None])
        else:
            ret[-1][-1] = msg
    return ret


def to_gradio_history(conv):
    """Convert the conversation to gradio history state."""
    return conv.messages[conv.offset :]


def get_prompt(conv) -> str:
    """Get the prompt for generation."""
    system_prompt = conv.system_template.format(system_message=conv.system_message)
    if conv.sep_style == SeparatorStyle.DeepSeek:
        seps = [conv.sep, conv.sep2]
        if system_prompt == "" or system_prompt is None:
            ret = ""
        else:
            ret = system_prompt + seps[0]
        for i, (role, message) in enumerate(conv.messages):
            if message:
                if type(message) is tuple:  # multimodal message
                    message, _ = message
                ret += role + ": " + message + seps[i % 2]
            else:
                ret += role + ":"
        return ret
    else:
        return conv.get_prompt


@wrap_gen_fn
def predict(
    text,
    image,
    chatbot,
    history,
    top_p,
    temperature,
    repetition_penalty,
    max_length_tokens,
    max_context_length_tokens,
    model_select_dropdown,
):
    """
    Function to predict the response based on the user's input and selected model.

    Parameters:
    user_text (str): The input text from the user.
    user_image (str): The input image from the user.
    chatbot (str): The chatbot's name.
    history (str): The history of the chat.
    top_p (float): The top-p parameter for the model.
    temperature (float): The temperature parameter for the model.
    max_length_tokens (int): The maximum length of tokens for the model.
    max_context_length_tokens (int): The maximum length of context tokens for the model.
    model_select_dropdown (str): The selected model from the dropdown.

    Returns:
    generator: A generator that yields the chatbot outputs, history, and status.
    """
    print("running the prediction function")
    try:
        tokenizer, vl_gpt, vl_chat_processor = models[model_select_dropdown]

        if text == "":
            yield chatbot, history, "Empty context."
            return
    except KeyError:
        yield [[text, "No Model Found"]], [], "No Model Found"
        return

    conversation = generate_prompt_with_history(
        text,
        image,
        history,
        vl_chat_processor,
        tokenizer,
        max_length=max_context_length_tokens,
    )
    prompts = convert_conversation_to_prompts(conversation)

    stop_words = conversation.stop_str
    gradio_chatbot_output = to_gradio_chatbot(conversation)

    full_response = ""
    with torch.no_grad():
        for x in deepseek_generate(
            prompts=prompts,
            vl_gpt=vl_gpt,
            vl_chat_processor=vl_chat_processor,
            tokenizer=tokenizer,
            stop_words=stop_words,
            max_length=max_length_tokens,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
        ):
            full_response += x
            response = strip_stop_words(full_response, stop_words)
            conversation.update_last_message(response)
            gradio_chatbot_output[-1][1] = response
            yield gradio_chatbot_output, to_gradio_history(
                conversation
            ), "Generating..."

    print("flushed result to gradio")
    torch.cuda.empty_cache()

    if is_variable_assigned("x"):
        print(f"{model_select_dropdown}:\n{text}\n{'-' * 80}\n{x}\n{'=' * 80}")
        print(
            f"temperature: {temperature}, top_p: {top_p}, repetition_penalty: {repetition_penalty}, max_length_tokens: {max_length_tokens}"
        )

    yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"


def retry(
    text,
    image,
    chatbot,
    history,
    top_p,
    temperature,
    repetition_penalty,
    max_length_tokens,
    max_context_length_tokens,
    model_select_dropdown,
):
    if len(history) == 0:
        yield (chatbot, history, "Empty context")
        return

    chatbot.pop()
    history.pop()
    text = history.pop()[-1]
    if type(text) is tuple:
        text, image = text

    yield from predict(
        text,
        image,
        chatbot,
        history,
        top_p,
        temperature,
        repetition_penalty,
        max_length_tokens,
        max_context_length_tokens,
        model_select_dropdown,
    )


def build_demo(MODELS):
    with open("assets/custom.css", "r", encoding="utf-8") as f:
        customCSS = f.read()

    with gr.Blocks(theme=gr.themes.Soft(spacing_size="md")) as demo:
        history = gr.State([])
        input_text = gr.State()
        input_image = gr.State()

        with gr.Row():
            gr.HTML(title)
            status_display = gr.Markdown("Success", elem_id="status_display")
        gr.Markdown(description_top)

        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="deepseek_chatbot",
                        show_share_button=True,
                        likeable=True,
                        bubble_full_width=False,
                        height=600,
                    )
                with gr.Row():
                    with gr.Column(scale=4):
                        text_box = gr.Textbox(
                            show_label=False, placeholder="Enter text", container=False
                        )
                    with gr.Column(
                        min_width=70,
                    ):
                        submitBtn = gr.Button("Send")
                    with gr.Column(
                        min_width=70,
                    ):
                        cancelBtn = gr.Button("Stop")
                with gr.Row():
                    emptyBtn = gr.Button(
                        "🧹 New Conversation",
                    )
                    retryBtn = gr.Button("🔄 Regenerate")
                    delLastBtn = gr.Button("🗑️ Remove Last Turn")

            with gr.Column():
                image_box = gr.Image(type="pil")

                with gr.Tab(label="Parameter Setting") as parameter_row:
                    top_p = gr.Slider(
                        minimum=-0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    repetition_penalty = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        interactive=True,
                        label="Repetition penalty",
                    )
                    max_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=2048,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=4096,
                        step=128,
                        interactive=True,
                        label="Max History Tokens",
                    )
                    model_select_dropdown = gr.Dropdown(
                        label="Select Models",
                        choices=MODELS,
                        multiselect=False,
                        value=MODELS[0],
                        interactive=True,
                    )

        examples_list = [
            [
                "examples/rap.jpeg",
                "Can you write me a master rap song that rhymes very well based on this image?",
            ],
            [
                "examples/app.png",
                "What is this app about?",
            ],
            [
                "examples/pipeline.png",
                "Help me write a python code based on the image.",
            ],
            [
                "examples/chart.png",
                "Could you help me to re-draw this picture with python codes?",
            ],
            [
                "examples/mirror.png",
                "How many people are there in the image. Why?",
            ],
            [
                "examples/puzzle.png",
                "Can this 2 pieces combine together?",
            ],
        ]
        gr.Examples(examples=examples_list, inputs=[image_box, text_box])
        gr.Markdown(description)

        input_widgets = [
            input_text,
            input_image,
            chatbot,
            history,
            top_p,
            temperature,
            repetition_penalty,
            max_length_tokens,
            max_context_length_tokens,
            model_select_dropdown,
        ]
        output_widgets = [chatbot, history, status_display]

        transfer_input_args = dict(
            fn=transfer_input,
            inputs=[text_box, image_box],
            outputs=[input_text, input_image, text_box, image_box, submitBtn],
            show_progress=True,
        )

        predict_args = dict(
            fn=predict,
            inputs=input_widgets,
            outputs=output_widgets,
            show_progress=True,
        )

        retry_args = dict(
            fn=retry,
            inputs=input_widgets,
            outputs=output_widgets,
            show_progress=True,
        )

        reset_args = dict(
            fn=reset_textbox, inputs=[], outputs=[text_box, status_display]
        )

        predict_events = [
            text_box.submit(**transfer_input_args).then(**predict_args),
            submitBtn.click(**transfer_input_args).then(**predict_args),
        ]

        emptyBtn.click(reset_state, outputs=output_widgets, show_progress=True)
        emptyBtn.click(**reset_args)
        retryBtn.click(**retry_args)

        delLastBtn.click(
            delete_last_conversation,
            [chatbot, history],
            output_widgets,
            show_progress=True,
        )

        cancelBtn.click(cancel_outputing, [], [status_display], cancels=predict_events)

    return demo


if __name__ == "__main__":
    demo = build_demo(MODELS)
    demo.title = "DeepSeek-VL Chatbot"

    reload_javascript()
    demo.queue(max_size=20).launch(
        share=False,
        favicon_path="assets/favicon.ico",
    )
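
A minimal sketch (not part of the commit) of how the same generation path used by predict() in app.py could be exercised outside Gradio. It assumes load_model, convert_conversation_to_prompts, and deepseek_generate from inference.py behave exactly as app.py calls them; the example image and prompt are made up.

# Hypothetical smoke test for the Space's generation path.
# Assumes inference.load_model returns (tokenizer, vl_gpt, vl_chat_processor)
# and deepseek_generate streams text chunks, as predict() in app.py expects.
import torch
from PIL import Image

from app_modules.utils import strip_stop_words
from inference import convert_conversation_to_prompts, deepseek_generate, load_model

tokenizer, vl_gpt, vl_chat_processor = load_model("deepseek-ai/deepseek-vl-7b-chat")

# Build a single-turn multimodal conversation the same way generate_prompt_with_history does.
conversation = vl_chat_processor.new_chat_template()
image = Image.open("examples/rap.jpeg")
conversation.append_message(
    conversation.roles[0], ("<image_placeholder>\nDescribe this image.", image)
)
conversation.append_message(conversation.roles[1], "")

prompts = convert_conversation_to_prompts(conversation)

full_response = ""
with torch.no_grad():
    for chunk in deepseek_generate(
        prompts=prompts,
        vl_gpt=vl_gpt,
        vl_chat_processor=vl_chat_processor,
        tokenizer=tokenizer,
        stop_words=conversation.stop_str,
        max_length=2048,
        temperature=0.1,
        repetition_penalty=1.1,
        top_p=0.95,
    ):
        full_response += chunk

# Trim trailing stop words, as predict() does before showing the response.
print(strip_stop_words(full_response, conversation.stop_str))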
app_modules/conversation.py
ADDED
@@ -0,0 +1,348 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""

import dataclasses
from enum import IntEnum, auto
from typing import Dict, List


class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    DeepSeek = auto()
    PLAIN = auto()
    ALIGNMENT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of two roles
    roles: List[str] = (("USER", "ASSISTANT"),)
    # All messages. Each item is (role, message).
    messages: List[List[str]] = ()
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)

        if self.sep_style == SeparatorStyle.DeepSeek:
            seps = [self.sep, self.sep2]
            if system_prompt == "" or system_prompt is None:
                ret = ""
            else:
                ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            if self.system_message:
                ret = system_prompt
            else:
                ret = "[INST] "
            for i, (role, message) in enumerate(self.messages):
                tag = self.roles[i % 2]
                if message:
                    if type(message) is tuple:  # multimodal message
                        message, _ = message
                    if i == 0:
                        ret += message + " "
                    else:
                        ret += tag + " " + message + seps[i % 2]
                else:
                    ret += tag
            return ret
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += message + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        elif self.sep_style == SeparatorStyle.ALIGNMENT:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += "<image>\n" + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def get_prompt_for_current_round(self, content=None):
        """Get current round formatted question prompt during sft training"""
        if self.sep_style == SeparatorStyle.PLAIN:
            formatted_question = "<image>\n"
        elif self.sep_style == SeparatorStyle.DeepSeek:
            formatted_question = (
                f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
            )
        else:
            raise ValueError(f"Unsupported sep_style: {self.sep_style}")
        return formatted_question

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def reset_message(self):
        """Reset a new message."""
        self.messages = []

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert (
            template.name not in conv_templates
        ), f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


# llava_llama2 template
register_conv_template(
    Conversation(
        name="llava_llama2",
        system_message="You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language.",
        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_token_ids=[2],
    )
)

# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
    Conversation(
        name="llama-2",
        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_token_ids=[2],
    )
)


# deepseek template
register_conv_template(
    Conversation(
        name="deepseek",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("User", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"],
    )
)

register_conv_template(
    Conversation(
        name="plain",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        stop_token_ids=[2],
        stop_str=["</s>"],
    )
)


register_conv_template(
    Conversation(
        name="alignment",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        stop_token_ids=[2],
        stop_str=["</s>"],
    )
)


if __name__ == "__main__":
    # print("Llama-2 template:")
    # conv = get_conv_template("llama-2")
    # conv.set_system_message("You are a helpful, respectful and honest assistant.")
    # conv.append_message(conv.roles[0], "Hello!")
    # conv.append_message(conv.roles[1], "Hi!")
    # conv.append_message(conv.roles[0], "How are you?")
    # conv.append_message(conv.roles[1], None)
    # print(conv.get_prompt())

    # print("\n")

    print("deepseek template:")
    conv = get_conv_template("deepseek")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
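
As a small illustration (not part of the commit) of the registry API above, a custom template can be registered and rendered as follows; the "deepseek-demo" name and its system message are invented for the example.

from app_modules.conversation import (
    Conversation,
    SeparatorStyle,
    get_conv_template,
    register_conv_template,
)

# Hypothetical template, registered only for this example.
register_conv_template(
    Conversation(
        name="deepseek-demo",
        system_template="{system_message}",
        system_message="You are a concise assistant.",
        roles=("User", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_str=["User:", "<|end▁of▁sentence|>"],
    )
)

# get_conv_template returns a copy, so the registered template itself stays empty.
conv = get_conv_template("deepseek-demo")
conv.append_message(conv.roles[0], "What is DeepSeek-VL?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())         # system message, then "User: ..." and a trailing "Assistant:"
print(conv.to_gradio_chatbot())  # [["What is DeepSeek-VL?", None]]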
app_modules/gradio_utils.py
ADDED
@@ -0,0 +1,94 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from functools import wraps

import gradio as gr


def wrap_gen_fn(gen_fn):
    @wraps(gen_fn)
    def wrapped_gen_fn(prompt, *args, **kwargs):
        try:
            yield from gen_fn(prompt, *args, **kwargs)
        except gr.Error as g_err:
            raise g_err
        except Exception as e:
            raise gr.Error(f"Failed to generate text: {e}") from e

    return wrapped_gen_fn


def delete_last_conversation(chatbot, history):
    if len(history) % 2 != 0:
        gr.Error("history length is not even")
        return (
            chatbot,
            history,
            "Delete Done",
        )

    if len(chatbot) > 0:
        chatbot.pop()

    if len(history) > 0 and len(history) % 2 == 0:
        history.pop()
        history.pop()

    return (
        chatbot,
        history,
        "Delete Done",
    )


def reset_state():
    return [], [], None, "Reset Done"


def reset_textbox():
    return gr.update(value=""), ""


def cancel_outputing():
    return "Stop Done"


def transfer_input(input_text, input_image):
    print("transferring input text and input image")
    return (
        input_text,
        input_image,
        gr.update(value=""),
        gr.update(value=None),
        gr.Button(visible=True),
    )


class State:
    interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False


shared_state = State()
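
A brief sketch (not part of the commit) of how wrap_gen_fn above converts unexpected exceptions raised by a streaming generator into gr.Error, which Gradio can surface in the UI; the failing generator here is invented for the example.

import gradio as gr

from app_modules.gradio_utils import wrap_gen_fn


@wrap_gen_fn
def flaky_stream(prompt):
    # Hypothetical generator used only to demonstrate the wrapper.
    yield "partial output"
    raise RuntimeError("model backend went away")


try:
    for chunk in flaky_stream("hello"):
        print(chunk)
except gr.Error as err:
    # The RuntimeError is re-raised as a gr.Error ("Failed to generate text: ...").
    print(err)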
app_modules/overwrites.py
ADDED
@@ -0,0 +1,81 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from __future__ import annotations

import logging
from typing import List, Tuple

from app_modules.presets import gr
from app_modules.utils import convert_asis, convert_mdtext, detect_converted_mark


def compact_text_chunks(self, prompt, text_chunks: List[str]) -> List[str]:
    logging.debug("Compacting text chunks...🚀🚀🚀")
    combined_str = [c.strip() for c in text_chunks if c.strip()]
    combined_str = [f"[{index+1}] {c}" for index, c in enumerate(combined_str)]
    combined_str = "\n\n".join(combined_str)
    # resplit based on self.max_chunk_overlap
    text_splitter = self.get_text_splitter_given_prompt(prompt, 1, padding=1)
    return text_splitter.split_text(combined_str)


def postprocess(
    self, y: List[Tuple[str | None, str | None]]
) -> List[Tuple[str | None, str | None]]:
    """
    Parameters:
        y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
    Returns:
        List of tuples representing the message and response. Each message and response will be a string of HTML.
    """
    if y is None or y == []:
        return []
    temp = []
    for x in y:
        user, bot = x
        if not detect_converted_mark(user):
            user = convert_asis(user)
        if not detect_converted_mark(bot):
            bot = convert_mdtext(bot)
        temp.append((user, bot))
    return temp


with open("assets/custom.js", "r", encoding="utf-8") as f, open(
    "assets/Kelpy-Codos.js", "r", encoding="utf-8"
) as f2:
    customJS = f.read()
    kelpyCodos = f2.read()


def reload_javascript():
    print("Reloading javascript...")
    js = f"<script>{customJS}</script><script>{kelpyCodos}</script>"

    def template_response(*args, **kwargs):
        res = GradioTemplateResponseOriginal(*args, **kwargs)
        res.body = res.body.replace(b"</html>", f"{js}</html>".encode("utf8"))
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response


GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
app_modules/presets.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# -*- coding:utf-8 -*-
import gradio as gr

title = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with DeepSeek-VL </h1>"""
description_top = """"""
description = """"""
CONCURRENT_COUNT = 10


ALREADY_CONVERTED_MARK = "<!-- ALREADY CONVERTED BY PARSER. -->"

small_and_beautiful_theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        c50="#EBFAF2",
        c100="#CFF3E1",
        c200="#A8EAC8",
        c300="#77DEA9",
        c400="#3FD086",
        c500="#02C160",
        c600="#06AE56",
        c700="#05974E",
        c800="#057F45",
        c900="#04673D",
        c950="#2E5541",
        name="small_and_beautiful",
    ),
    secondary_hue=gr.themes.Color(
        c50="#576b95",
        c100="#576b95",
        c200="#576b95",
        c300="#576b95",
        c400="#576b95",
        c500="#576b95",
        c600="#576b95",
        c700="#576b95",
        c800="#576b95",
        c900="#576b95",
        c950="#576b95",
    ),
    neutral_hue=gr.themes.Color(
        name="gray",
        c50="#f6f7f8",
        # c100="#f3f4f6",
        c100="#F2F2F2",
        c200="#e5e7eb",
        c300="#d1d5db",
        c400="#B2B2B2",
        c500="#808080",
        c600="#636363",
        c700="#515151",
        c800="#393939",
        # c900="#272727",
        c900="#2B2B2B",
        c950="#171717",
    ),
    radius_size=gr.themes.sizes.radius_sm,
).set(
    # button_primary_background_fill="*primary_500",
    button_primary_background_fill_dark="*primary_600",
    # button_primary_background_fill_hover="*primary_400",
    # button_primary_border_color="*primary_500",
    button_primary_border_color_dark="*primary_600",
    button_primary_text_color="white",
    button_primary_text_color_dark="white",
    button_secondary_background_fill="*neutral_100",
    button_secondary_background_fill_hover="*neutral_50",
    button_secondary_background_fill_dark="*neutral_900",
    button_secondary_text_color="*neutral_800",
    button_secondary_text_color_dark="white",
    # background_fill_primary="#F7F7F7",
    # background_fill_primary_dark="#1F1F1F",
    # block_title_text_color="*primary_500",
    block_title_background_fill_dark="*primary_900",
    block_label_background_fill_dark="*primary_900",
    input_background_fill="#F6F6F6",
    # chatbot_code_background_color_dark="*neutral_950",
)
app_modules/utils.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
# -*- coding:utf-8 -*-
|
21 |
+
from __future__ import annotations
|
22 |
+
|
23 |
+
import html
|
24 |
+
import logging
|
25 |
+
import os
|
26 |
+
import re
|
27 |
+
import time
|
28 |
+
|
29 |
+
import mdtex2html
|
30 |
+
from app_modules.presets import ALREADY_CONVERTED_MARK
|
31 |
+
from markdown import markdown
|
32 |
+
from pygments import highlight
|
33 |
+
from pygments.formatters import HtmlFormatter
|
34 |
+
from pygments.lexers import ClassNotFound, get_lexer_by_name, guess_lexer
|
35 |
+
|
36 |
+
logger = logging.getLogger("gradio_logger")
|
37 |
+
|
38 |
+
|
39 |
+
def configure_logger():
|
40 |
+
logger = logging.getLogger("gradio_logger")
|
41 |
+
logger.setLevel(logging.DEBUG)
|
42 |
+
|
43 |
+
timestr = time.strftime("%Y%m%d-%H%M%S")
|
44 |
+
os.makedirs("logs", exist_ok=True)
|
45 |
+
file_handler = logging.FileHandler(
|
46 |
+
f"logs/{timestr}_gradio_log.log"
|
47 |
+
)
|
48 |
+
console_handler = logging.StreamHandler()
|
49 |
+
|
50 |
+
formatter = logging.Formatter(
|
51 |
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
52 |
+
)
|
53 |
+
console_handler.setFormatter(formatter)
|
54 |
+
file_handler.setFormatter(formatter)
|
55 |
+
|
56 |
+
console_handler.setLevel(logging.INFO)
|
57 |
+
file_handler.setLevel(logging.INFO)
|
58 |
+
|
59 |
+
logger.addHandler(console_handler)
|
60 |
+
logger.addHandler(file_handler)
|
61 |
+
|
62 |
+
return logger
|
63 |
+
|
64 |
+
|
65 |
+
def strip_stop_words(x, stop_words):
|
66 |
+
for w in stop_words:
|
67 |
+
if w in x:
|
68 |
+
return x[: x.index(w)].strip()
|
69 |
+
return x.strip()
|
70 |
+
|
71 |
+
|
72 |
+
def format_output(history, text, x):
|
73 |
+
updated_history = history + [[text, x]]
|
74 |
+
a = [[y[0], convert_to_markdown(y[1])] for y in updated_history]
|
75 |
+
return a, updated_history
|
76 |
+
|
77 |
+
|
78 |
+
def markdown_to_html_with_syntax_highlight(md_str): # deprecated
|
79 |
+
def replacer(match):
|
80 |
+
lang = match.group(1) or "text"
|
81 |
+
code = match.group(2)
|
82 |
+
|
83 |
+
try:
|
84 |
+
lexer = get_lexer_by_name(lang, stripall=True)
|
85 |
+
except ValueError:
|
86 |
+
            lexer = get_lexer_by_name("text", stripall=True)

        formatter = HtmlFormatter()
        highlighted_code = highlight(code, lexer, formatter)

        return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'

    code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
    md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)

    html_str = markdown(md_str)
    return html_str


def normalize_markdown(md_text: str) -> str:  # deprecated
    lines = md_text.split("\n")
    normalized_lines = []
    inside_list = False

    for i, line in enumerate(lines):
        if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
            if not inside_list and i > 0 and lines[i - 1].strip() != "":
                normalized_lines.append("")
            inside_list = True
            normalized_lines.append(line)
        elif inside_list and line.strip() == "":
            if i < len(lines) - 1 and not re.match(
                r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
            ):
                normalized_lines.append(line)
            continue
        else:
            inside_list = False
            normalized_lines.append(line)

    return "\n".join(normalized_lines)


def convert_mdtext(md_text):
    # Split the text into fenced code blocks and non-code parts, render each part
    # to HTML (markdown for plain text, mdtex2html for text with LaTeX), then re-join.
    code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
    inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
    code_blocks = code_block_pattern.findall(md_text)
    non_code_parts = code_block_pattern.split(md_text)[::2]

    result = []
    for non_code, code in zip(non_code_parts, code_blocks + [""]):
        if non_code.strip():
            non_code = normalize_markdown(non_code)
            if inline_code_pattern.search(non_code):
                result.append(markdown(non_code, extensions=["tables"]))
            else:
                result.append(mdtex2html.convert(non_code, extensions=["tables"]))
        if code.strip():
            code = f"\n```{code}\n\n```"
            code = markdown_to_html_with_syntax_highlight(code)
            result.append(code)
    result = "".join(result)
    result += ALREADY_CONVERTED_MARK
    return result


def convert_asis(userinput):
    return f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>{ALREADY_CONVERTED_MARK}'


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    return any(s.endswith(stop_word) for stop_word in stop_words)


def detect_converted_mark(userinput):
    return bool(userinput.endswith(ALREADY_CONVERTED_MARK))


def detect_language(code):
    first_line = "" if code.startswith("\n") else code.strip().split("\n", 1)[0]
    language = first_line.lower() if first_line else ""
    code_without_language = code[len(first_line) :].lstrip() if first_line else code
    return language, code_without_language


def convert_to_markdown(text):
    # Escape "$" as an HTML entity so stray dollar signs are not treated as math delimiters.
    text = text.replace("$", "&#36;")
    text = text.replace("\r\n", "\n")

    def replace_leading_tabs_and_spaces(line):
        new_line = []

        for char in line:
            if char == "\t":
                new_line.append("&#9;")
            elif char == " ":
                new_line.append("&nbsp;")
            else:
                break
        return "".join(new_line) + line[len(new_line) :]

    markdown_text = ""
    lines = text.split("\n")
    in_code_block = False

    for line in lines:
        if in_code_block is False and line.startswith("```"):
            in_code_block = True
            markdown_text += f"{line}\n"
        elif in_code_block is True and line.startswith("```"):
            in_code_block = False
            markdown_text += f"{line}\n"
        elif in_code_block:
            markdown_text += f"{line}\n"
        else:
            line = replace_leading_tabs_and_spaces(line)
            line = re.sub(r"^(#)", r"\\\1", line)
            markdown_text += f"{line}  \n"

    return markdown_text


def add_language_tag(text):
    def detect_language(code_block):
        try:
            lexer = guess_lexer(code_block)
            return lexer.name.lower()
        except ClassNotFound:
            return ""

    code_block_pattern = re.compile(r"(```)(\w*\n[^`]+```)", re.MULTILINE)

    def replacement(match):
        code_block = match.group(2)
        if match.group(2).startswith("\n"):
            language = detect_language(code_block)
            return (
                f"```{language}{code_block}```" if language else f"```\n{code_block}```"
            )
        else:
            return match.group(1) + code_block + "```"

    text2 = code_block_pattern.sub(replacement, text)
    return text2


def is_variable_assigned(var_name: str) -> bool:
    return var_name in locals()
assets/Kelpy-Codos.js
ADDED
@@ -0,0 +1,100 @@
/**
 * Copyright (c) 2023-2024 DeepSeek.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

// ==UserScript==
// @name         Kelpy Codos
// @namespace    https://github.com/Keldos-Li/Kelpy-Codos
// @version      1.0.5
// @author       Keldos; https://keldos.me/
// @description  Add copy button to PRE tags before CODE tag, for Chuanhu ChatGPT especially.
//               Based on Chuanhu ChatGPT version: ac04408 (2023-3-22)
// @license      GPL-3.0
// @grant        none
// ==/UserScript==

(function () {
  "use strict";

  function addCopyButton(pre) {
    var code = pre.querySelector("code");
    if (!code) {
      return; // if no <code> element is found, do not add a button
    }
    var firstChild = code.firstChild;
    if (!firstChild) {
      return; // if the <code> element has no child nodes, do not add a button
    }
    var button = document.createElement("button");
    button.textContent = "\uD83D\uDCCE"; // use the 📎 symbol as the text of the "copy" button
    button.style.position = "relative";
    button.style.float = "right";
    button.style.fontSize = "1em"; // optional: adjust button size
    button.style.background = "none"; // optional: remove background color
    button.style.border = "none"; // optional: remove border
    button.style.cursor = "pointer"; // optional: show pointer cursor
    button.addEventListener("click", function () {
      var range = document.createRange();
      range.selectNodeContents(code);
      range.setStartBefore(firstChild); // start the range before the first child node
      var selection = window.getSelection();
      selection.removeAllRanges();
      selection.addRange(range);

      try {
        var success = document.execCommand("copy");
        if (success) {
          button.textContent = "\u2714";
          setTimeout(function () {
            button.textContent = "\uD83D\uDCCE"; // restore the "copy" button
          }, 2000);
        } else {
          button.textContent = "\u2716";
        }
      } catch (e) {
        console.error(e);
        button.textContent = "\u2716";
      }

      selection.removeAllRanges();
    });
    code.insertBefore(button, firstChild); // insert the button before the first child element
  }

  function handleNewElements(mutationsList, observer) {
    for (var mutation of mutationsList) {
      if (mutation.type === "childList") {
        for (var node of mutation.addedNodes) {
          if (node.nodeName === "PRE") {
            addCopyButton(node);
          }
        }
      }
    }
  }

  var observer = new MutationObserver(handleNewElements);
  observer.observe(document.documentElement, {
    childList: true,
    subtree: true,
  });

  document.querySelectorAll("pre").forEach(addCopyButton);
})();
assets/avatar.png
ADDED
assets/custom.css
ADDED
@@ -0,0 +1,355 @@
/**
 * Copyright (c) 2023-2024 DeepSeek.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

:root {
  --chatbot-color-light: #f3f3f3;
  --chatbot-color-dark: #121111;
}

/* status_display */
#status_display {
  display: flex;
  min-height: 2.5em;
  align-items: flex-end;
  justify-content: flex-end;
}
#status_display p {
  font-size: 0.85em;
  font-family: monospace;
  color: var(--body-text-color-subdued);
}

/* usage_display */
#usage_display {
  height: 1em;
}
#usage_display p {
  padding: 0 1em;
  font-size: 0.85em;
  font-family: monospace;
  color: var(--body-text-color-subdued);
}
/* list */
ol:not(.options),
ul:not(.options) {
  padding-inline-start: 2em !important;
}

/* Thank @Keldos-Li for fixing it */
/* Light mode (default) */
#deepseek_chatbot {
  background-color: var(--chatbot-color-light) !important;
  color: #000000 !important;
}
[data-testid="bot"] {
  background-color: #ffffff !important;
}
[data-testid="user"] {
  background-color: #95ec69 !important;
}

/* Dark mode */
.dark #deepseek_chatbot {
  background-color: var(--chatbot-color-dark) !important;
  color: #ffffff !important;
}
.dark [data-testid="bot"] {
  background-color: #2c2c2c !important;
}
.dark [data-testid="user"] {
  background-color: #26b561 !important;
}

#deepseek_chatbot {
  height: 100%;
  min-height: 800px;
  flex-grow: 1;
  overflow: auto;
}

[class*="message"] {
  border-radius: var(--radius-xl) !important;
  border: none;
  padding: var(--spacing-xl) !important;
  font-size: var(--text-md) !important;
  line-height: var(--line-md) !important;
  min-height: calc(var(--text-md) * var(--line-md) + 2 * var(--spacing-xl));
  min-width: calc(var(--text-md) * var(--line-md) + 2 * var(--spacing-xl));
}
[data-testid="bot"] {
  max-width: 85%;
  border-bottom-left-radius: 0 !important;
}
[data-testid="user"] {
  max-width: 85%;
  width: auto !important;
  border-bottom-right-radius: 0 !important;
}
/* Table */
table {
  margin: 1em 0;
  border-collapse: collapse;
  empty-cells: show;
}
td,
th {
  border: 1.2px solid var(--border-color-primary) !important;
  padding: 0.2em;
}
thead {
  background-color: rgba(175, 184, 193, 0.2);
}
thead th {
  padding: 0.5em 0.2em;
}
/* Inline code */
#deepseek_chatbot code {
  display: inline;
  white-space: break-spaces;
  border-radius: 6px;
  margin: 0 2px 0 2px;
  padding: 0.2em 0.4em 0.1em 0.4em;
  background-color: rgba(175, 184, 193, 0.2);
}
/* Code block */
#deepseek_chatbot pre code {
  display: block;
  overflow: auto;
  white-space: pre;
  background-color: #1c1d1e !important;
  border-radius: 10px;
  padding: 1.4em 1.2em 0em 1.4em;
  margin: 1.2em 2em 1.2em 0.5em;
  color: #fdf8f8;
  box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
}
/* Highlight */
#deepseek_chatbot .highlight { background-color: transparent; }
#deepseek_chatbot .highlight .hll { background-color: #49483e; }
#deepseek_chatbot .highlight .c { color: #75715e; } /* Comment */
#deepseek_chatbot .highlight .err { color: #960050; background-color: #1e0010; } /* Error */
#deepseek_chatbot .highlight .k { color: #66d9ef; } /* Keyword */
#deepseek_chatbot .highlight .l { color: #ae81ff; } /* Literal */
#deepseek_chatbot .highlight .n { color: #f8f8f2; } /* Name */
#deepseek_chatbot .highlight .o { color: #f92672; } /* Operator */
#deepseek_chatbot .highlight .p { color: #f8f8f2; } /* Punctuation */
#deepseek_chatbot .highlight .ch { color: #75715e; } /* Comment.Hashbang */
#deepseek_chatbot .highlight .cm { color: #75715e; } /* Comment.Multiline */
#deepseek_chatbot .highlight .cp { color: #75715e; } /* Comment.Preproc */
#deepseek_chatbot .highlight .cpf { color: #75715e; } /* Comment.PreprocFile */
#deepseek_chatbot .highlight .c1 { color: #75715e; } /* Comment.Single */
#deepseek_chatbot .highlight .cs { color: #75715e; } /* Comment.Special */
#deepseek_chatbot .highlight .gd { color: #f92672; } /* Generic.Deleted */
#deepseek_chatbot .highlight .ge { font-style: italic; } /* Generic.Emph */
#deepseek_chatbot .highlight .gi { color: #a6e22e; } /* Generic.Inserted */
#deepseek_chatbot .highlight .gs { font-weight: bold; } /* Generic.Strong */
#deepseek_chatbot .highlight .gu { color: #75715e; } /* Generic.Subheading */
#deepseek_chatbot .highlight .kc { color: #66d9ef; } /* Keyword.Constant */
#deepseek_chatbot .highlight .kd { color: #66d9ef; } /* Keyword.Declaration */
#deepseek_chatbot .highlight .kn { color: #f92672; } /* Keyword.Namespace */
#deepseek_chatbot .highlight .kp { color: #66d9ef; } /* Keyword.Pseudo */
#deepseek_chatbot .highlight .kr { color: #66d9ef; } /* Keyword.Reserved */
#deepseek_chatbot .highlight .kt { color: #66d9ef; } /* Keyword.Type */
#deepseek_chatbot .highlight .ld { color: #e6db74; } /* Literal.Date */
#deepseek_chatbot .highlight .m { color: #ae81ff; } /* Literal.Number */
#deepseek_chatbot .highlight .s { color: #e6db74; } /* Literal.String */
#deepseek_chatbot .highlight .na { color: #a6e22e; } /* Name.Attribute */
#deepseek_chatbot .highlight .nb { color: #f8f8f2; } /* Name.Builtin */
#deepseek_chatbot .highlight .nc { color: #a6e22e; } /* Name.Class */
#deepseek_chatbot .highlight .no { color: #66d9ef; } /* Name.Constant */
#deepseek_chatbot .highlight .nd { color: #a6e22e; } /* Name.Decorator */
#deepseek_chatbot .highlight .ni { color: #f8f8f2; } /* Name.Entity */
#deepseek_chatbot .highlight .ne { color: #a6e22e; } /* Name.Exception */
#deepseek_chatbot .highlight .nf { color: #a6e22e; } /* Name.Function */
#deepseek_chatbot .highlight .nl { color: #f8f8f2; } /* Name.Label */
#deepseek_chatbot .highlight .nn { color: #f8f8f2; } /* Name.Namespace */
#deepseek_chatbot .highlight .nx { color: #a6e22e; } /* Name.Other */
#deepseek_chatbot .highlight .py { color: #f8f8f2; } /* Name.Property */
#deepseek_chatbot .highlight .nt { color: #f92672; } /* Name.Tag */
#deepseek_chatbot .highlight .nv { color: #f8f8f2; } /* Name.Variable */
#deepseek_chatbot .highlight .ow { color: #f92672; } /* Operator.Word */
#deepseek_chatbot .highlight .w { color: #f8f8f2; } /* Text.Whitespace */
#deepseek_chatbot .highlight .mb { color: #ae81ff; } /* Literal.Number.Bin */
#deepseek_chatbot .highlight .mf { color: #ae81ff; } /* Literal.Number.Float */
#deepseek_chatbot .highlight .mh { color: #ae81ff; } /* Literal.Number.Hex */
#deepseek_chatbot .highlight .mi { color: #ae81ff; } /* Literal.Number.Integer */
#deepseek_chatbot .highlight .mo { color: #ae81ff; } /* Literal.Number.Oct */
#deepseek_chatbot .highlight .sa { color: #e6db74; } /* Literal.String.Affix */
#deepseek_chatbot .highlight .sb { color: #e6db74; } /* Literal.String.Backtick */
#deepseek_chatbot .highlight .sc { color: #e6db74; } /* Literal.String.Char */
#deepseek_chatbot .highlight .dl { color: #e6db74; } /* Literal.String.Delimiter */
#deepseek_chatbot .highlight .sd { color: #e6db74; } /* Literal.String.Doc */
#deepseek_chatbot .highlight .s2 { color: #e6db74; } /* Literal.String.Double */
#deepseek_chatbot .highlight .se { color: #ae81ff; } /* Literal.String.Escape */
#deepseek_chatbot .highlight .sh { color: #e6db74; } /* Literal.String.Heredoc */
#deepseek_chatbot .highlight .si { color: #e6db74; } /* Literal.String.Interpol */
#deepseek_chatbot .highlight .sx { color: #e6db74; } /* Literal.String.Other */
#deepseek_chatbot .highlight .sr { color: #e6db74; } /* Literal.String.Regex */
#deepseek_chatbot .highlight .s1 { color: #e6db74; } /* Literal.String.Single */
#deepseek_chatbot .highlight .ss { color: #e6db74; } /* Literal.String.Symbol */
#deepseek_chatbot .highlight .bp { color: #f8f8f2; } /* Name.Builtin.Pseudo */
#deepseek_chatbot .highlight .fm { color: #a6e22e; } /* Name.Function.Magic */
#deepseek_chatbot .highlight .vc { color: #f8f8f2; } /* Name.Variable.Class */
#deepseek_chatbot .highlight .vg { color: #f8f8f2; } /* Name.Variable.Global */
#deepseek_chatbot .highlight .vi { color: #f8f8f2; } /* Name.Variable.Instance */
#deepseek_chatbot .highlight .vm { color: #f8f8f2; } /* Name.Variable.Magic */
#deepseek_chatbot .highlight .il { color: #ae81ff; } /* Literal.Number.Integer.Long */
assets/custom.js
ADDED
@@ -0,0 +1,22 @@
/**
 * Copyright (c) 2023-2024 DeepSeek.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

// custom javascript here
assets/favicon.ico
ADDED
deepseek_vl/__init__.py
ADDED
@@ -0,0 +1,31 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


# check if python version is above 3.10
import sys

if sys.version_info >= (3, 10):
    print("Python version is above 3.10, patching the collections module.")
    # Monkey patch collections
    import collections
    import collections.abc

    for type_name in collections.abc.__all__:
        setattr(collections, type_name, getattr(collections.abc, type_name))
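The patch above restores the pre-3.10 aliases (for example `collections.Mapping`) that some pinned dependencies still import. A minimal, standalone sketch of the effect, outside this package:

# Minimal check of what the monkey patch provides on Python >= 3.10,
# where `collections.Mapping` would otherwise raise AttributeError.
import collections
import collections.abc

for type_name in collections.abc.__all__:
    setattr(collections, type_name, getattr(collections.abc, type_name))

print(isinstance({}, collections.Mapping))  # True once the alias is restored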
deepseek_vl/models/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from .image_processing_vlm import VLMImageProcessor
from .modeling_vlm import MultiModalityCausalLM
from .processing_vlm import VLChatProcessor

__all__ = [
    "VLMImageProcessor",
    "VLChatProcessor",
    "MultiModalityCausalLM",
]
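A quick, hedged sanity check of the public surface this package re-exports (names taken from `__all__` above); nothing here loads checkpoints, and it assumes the repo is on the Python path:

# Import the three public classes re-exported above; no weights are loaded.
from deepseek_vl.models import (
    MultiModalityCausalLM,
    VLChatProcessor,
    VLMImageProcessor,
)

print(VLMImageProcessor.model_input_names)  # ['pixel_values']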
deepseek_vl/models/clip_encoder.py
ADDED
@@ -0,0 +1,242 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from typing import Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchvision.transforms
from einops import rearrange

from deepseek_vl.models.sam import create_sam_vit
from deepseek_vl.models.siglip_vit import create_siglip_vit


class CLIPVisionTower(nn.Module):
    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: list = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            image_norm = None

        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):
        if self.model_name.startswith("siglip"):
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            vision_tower = create_sam_vit(**vision_tower_params)
            forward_kwargs = dict()

        else:  # huggingface
            from transformers import CLIPVisionModel

            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        if isinstance(image_forward_outs, torch.Tensor):
            # the output is already the features from self.select_layer
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # if the output has cls_token
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        elif self.select_feature == "same":
            image_features = image_features

        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward(self, images):
        """

        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """

        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features


class HybridVisionTower(nn.Module):
    def __init__(
        self,
        high_res_cfg: Dict,
        low_res_cfg: Dict,
        freeze_high: bool = False,
        freeze_low: bool = False,
        concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple",
        **ignore_kwargs,
    ):
        super().__init__()

        self.vision_tower_high = CLIPVisionTower(**high_res_cfg)
        self.vision_tower_low = CLIPVisionTower(**low_res_cfg)
        self.low_res_size = low_res_cfg["image_size"]
        self.concat_type = concat_type

        self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024))
        self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024))

        if freeze_high:
            for p_name, p in self.vision_tower_high.named_parameters():
                p.requires_grad = False
            self.vision_tower_high = self.vision_tower_high.eval()
        else:
            # train downsamples and neck
            for p_name, p in self.vision_tower_high.named_parameters():
                if "downsamples" in p_name or "neck" in p_name:
                    p.requires_grad = True
                else:
                    p.requires_grad = False

        if freeze_low:
            for p in self.vision_tower_low.parameters():
                p.requires_grad = False
            self.vision_tower_low = self.vision_tower_low.eval()

        self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True)

    def forward(self, images: torch.Tensor):
        """

        Args:
            images (torch.Tensor): [bs, 3, H, W]

        Returns:
            res (torch.Tensor): [bs, t, c]
        """

        # [bs, c, h, w]
        high_images = images

        # [bs, c, h_low, w_low]
        low_images = self.resize(images)

        # separately run two vision towers
        # run high_res vision tower
        high_res = self.vision_tower_high(high_images)
        # [bs, c, h, w] -> [bs, h*w, c]
        high_res = rearrange(high_res, "b c h w -> b (h w) c")
        # run low_res vision tower
        low_res = self.vision_tower_low(low_images)

        if self.concat_type == "feature":
            images_features = torch.cat([high_res, low_res], dim=-1)
        elif self.concat_type == "sequence":
            images_features = torch.cat([high_res, low_res], dim=1)
        elif self.concat_type == "add":
            images_features = high_res + low_res
        elif self.concat_type == "tuple":
            images_features = (high_res, low_res)

        else:
            raise ValueError(
                "Currently only support `feature`, `sequence`, `add` and `tuple` concat type."
            )

        return images_features


if __name__ == "__main__":
    image_size = 1024
    x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda()

    high_res_cfg = dict(
        model_name="sam_b_downsample",
        select_feature="same",
        image_size=image_size,
        pixel_mean=(0.48145466, 0.4578275, 0.40821073),
        pixel_std=(0.26862954, 0.26130258, 0.27577711),
        select_layer=-1,
        ckpt_path="",
    )

    low_res_cfg = dict(
        model_name="siglip_large_patch16_384",
        select_feature="same",
        image_size=384,
        pixel_mean=(0.5, 0.5, 0.5),
        pixel_std=(0.5, 0.5, 0.5),
        select_layer=-1,
        ckpt_path="",
    )

    net = (
        HybridVisionTower(
            high_res_cfg=high_res_cfg,
            low_res_cfg=low_res_cfg,
            freeze_high=True,
            freeze_low=True,
            concat_type="tuple",
        )
        .bfloat16()
        .cuda()
    )
    high_x, low_x = net(x)
    print(x.shape, high_x.shape, low_x.shape)
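To make the `concat_type` options above concrete, here is a small standalone sketch with dummy tensors in place of real SAM/SigLIP branch outputs; the token count and width (576 and 1024) are assumptions chosen only for illustration:

# Standalone illustration of HybridVisionTower's concat_type options,
# using random tensors instead of real vision-tower features.
import torch

bs, t, d = 2, 576, 1024              # assumed token count / width, for illustration
high_res = torch.randn(bs, t, d)     # stands in for the high-res (SAM) branch output
low_res = torch.randn(bs, t, d)      # stands in for the low-res (SigLIP) branch output

feature = torch.cat([high_res, low_res], dim=-1)   # "feature":  [2, 576, 2048]
sequence = torch.cat([high_res, low_res], dim=1)   # "sequence": [2, 1152, 1024]
added = high_res + low_res                         # "add":      [2, 576, 1024]
as_tuple = (high_res, low_res)                     # "tuple":    both kept separate

print(feature.shape, sequence.shape, added.shape, as_tuple[0].shape)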
deepseek_vl/models/image_processing_vlm.py
ADDED
@@ -0,0 +1,208 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from typing import List, Tuple, Union

import numpy as np
import torch
import torchvision
import torchvision.transforms.functional
from PIL import Image
from transformers import AutoImageProcessor, PretrainedConfig
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import to_numpy_array
from transformers.utils import logging

logger = logging.get_logger(__name__)

ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


class VLMImageProcessorConfig(PretrainedConfig):
    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)


class VLMImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """

        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """

        width, height = pil_img.size
        max_size = max(width, height)

        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        pil_img = torchvision.transforms.functional.resize(
            pil_img,
            size,
            interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
            antialias=True,
        )

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        images: List[np.ndarray] = [self.resize(image) for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        return [3, self.image_size, self.image_size]


AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)


if __name__ == "__main__":
    image_processor = VLMImageProcessor(
        image_size=1024,
        image_mean=IMAGENET_INCEPTION_MEAN,
        image_std=IMAGENET_INCEPTION_STD,
        do_normalize=True,
    )
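A short usage sketch of the processor defined above on a synthetic PIL image; the 1024 target size mirrors the `__main__` block, while the dummy 640x480 black image is an assumption made purely for illustration:

# Resize-and-pad a dummy RGB image into a square 1024x1024 pixel_values batch.
import numpy as np
from PIL import Image

from deepseek_vl.models.image_processing_vlm import VLMImageProcessor

processor = VLMImageProcessor(image_size=1024)
dummy = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))  # H=480, W=640
batch = processor.preprocess([dummy], return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 1024, 1024])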
deepseek_vl/models/modeling_vlm.py
ADDED
@@ -0,0 +1,170 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch
from attrdict import AttrDict
from einops import rearrange
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
)
from transformers.configuration_utils import PretrainedConfig

from deepseek_vl.models.clip_encoder import CLIPVisionTower, HybridVisionTower
from deepseek_vl.models.projector import MlpProjector


def model_name_to_cls(cls_name):
    if "MlpProjector" in cls_name:
        cls = MlpProjector

    elif "CLIPVisionTower" in cls_name:
        cls = CLIPVisionTower

    elif "HybridVisionTower" in cls_name:
        cls = HybridVisionTower

    else:
        raise ValueError(f"class_name {cls_name} is invalid.")

    return cls


class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionConfig(**vision_config)

        aligner_config = kwargs.get("aligner_config", {})
        self.aligner_config = AlignerConfig(**aligner_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)


class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_emb_mask: torch.LongTensor,
        **kwargs,
    ):
        """

        Args:
            input_ids (torch.LongTensor): [b, T]
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        # [b, n, T2] -> [b, n x T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # [b, T, D]
        input_ids[input_ids < 0] = 0  # ignore the image embeddings
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # replace with the image embeddings
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

        return inputs_embeds


AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
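The key step in `prepare_inputs_embeds` is the boolean-mask assignment that swaps image-placeholder positions in the text embeddings for the projected image features. A toy, model-free sketch of that mechanism (all sizes are made up for illustration, and the image-embedding mask is simplified to one dimension per image token):

# Toy illustration of the masked replacement used in prepare_inputs_embeds.
import torch

b, T, D = 1, 6, 4                       # batch, sequence length, hidden size (made up)
inputs_embeds = torch.zeros(b, T, D)    # stands in for the text token embeddings
images_embeds = torch.ones(b, 3, D)     # stands in for aligner(vision_model(images))

images_seq_mask = torch.tensor([[0, 1, 1, 1, 0, 0]], dtype=torch.bool)  # 3 image slots
images_emb_mask = torch.ones(b, 3, dtype=torch.bool)                    # all 3 used

# Positions flagged in images_seq_mask receive the image embeddings.
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 1., 0., 0.])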
deepseek_vl/models/processing_vlm.py
ADDED
@@ -0,0 +1,390 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from dataclasses import dataclass
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin

from deepseek_vl.models.image_processing_vlm import VLMImageProcessor
from deepseek_vl.utils.conversation import get_conv_template


class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor

    def to(self, device, dtype=torch.bfloat16):
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.images_seq_mask = self.images_seq_mask.to(device)
        self.images_emb_mask = self.images_emb_mask.to(device)
        self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
        return self


class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    system_prompt = (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    )

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            image_tag,
            num_image_tokens,
            add_special_token,
            sft_format,
            mask_prompt,
            ignore_id,
            **kwargs,
        )

    def new_chat_template(self):
        conv = get_conv_template(self.sft_format)
        conv.set_system_message(self.system_prompt)
        return conv

    def apply_sft_template_for_multi_turn_prompts(
        self,
        conversations: List[Dict[str, str]],
        sft_format: str = "deepseek",
        system_prompt: str = "",
    ):
        """
        Applies the SFT template to conversation.

        An example of conversation:
        conversation = [
            {
                "role": "User",
                "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
                "images": [
                    "./multi-images/attribute_comparison_1.png",
                    "./multi-images/attribute_comparison_2.png"
                ]
            },
            {
                "role": "Assistant",
                "content": ""
            }
        ]

        Args:
            conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
            sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
            system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

        Returns:
            sft_prompt (str): The formatted text.
        """

        conv = get_conv_template(sft_format)
        conv.set_system_message(system_prompt)
        for message in conversations:
            conv.append_message(message["role"], message["content"].strip())
        sft_prompt = conv.get_prompt().strip()

        return sft_prompt

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self):
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def pad_id(self):
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add image tokens, and set the mask as False
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
|
227 |
+
input_ids = torch.cat(input_slices, dim=0)
|
228 |
+
num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
|
229 |
+
|
230 |
+
return input_ids, num_image_tokens
|
231 |
+
|
232 |
+
def process_one(
|
233 |
+
self,
|
234 |
+
prompt: str = None,
|
235 |
+
conversations: List[Dict[str, str]] = None,
|
236 |
+
images: List[Image] = None,
|
237 |
+
**kwargs,
|
238 |
+
):
|
239 |
+
"""
|
240 |
+
|
241 |
+
Args:
|
242 |
+
prompt (str): the formatted prompt;
|
243 |
+
conversations (List[Dict]): conversations with a list of messages;
|
244 |
+
images (List[ImageType]): the list of images;
|
245 |
+
**kwargs:
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
outputs (BaseProcessorOutput): the output of the processor,
|
249 |
+
- input_ids (torch.LongTensor): [N + image tokens]
|
250 |
+
- target_ids (torch.LongTensor): [N + image tokens]
|
251 |
+
- images (torch.FloatTensor): [n_images, 3, H, W]
|
252 |
+
- image_id (int): the id of the image token
|
253 |
+
- num_image_tokens (List[int]): the number of image tokens
|
254 |
+
"""
|
255 |
+
|
256 |
+
assert (
|
257 |
+
prompt is None or conversations is None
|
258 |
+
), "prompt and conversations cannot be used at the same time."
|
259 |
+
|
260 |
+
if prompt is None:
|
261 |
+
# apply sft format
|
262 |
+
sft_format = self.apply_sft_template_for_multi_turn_prompts(
|
263 |
+
conversations=conversations,
|
264 |
+
sft_format=self.sft_format,
|
265 |
+
system_prompt=self.system_prompt,
|
266 |
+
)
|
267 |
+
else:
|
268 |
+
sft_format = prompt
|
269 |
+
|
270 |
+
# tokenize
|
271 |
+
input_ids = self.tokenizer.encode(sft_format)
|
272 |
+
input_ids = torch.LongTensor(input_ids)
|
273 |
+
|
274 |
+
# add image tokens to the input_ids
|
275 |
+
image_token_mask: torch.BoolTensor = input_ids == self.image_id
|
276 |
+
image_indices = image_token_mask.nonzero()
|
277 |
+
input_ids, num_image_tokens = self.add_image_token(
|
278 |
+
image_indices=image_indices,
|
279 |
+
input_ids=input_ids,
|
280 |
+
)
|
281 |
+
|
282 |
+
# load images
|
283 |
+
images_outputs = self.image_processor(images, return_tensors="pt")
|
284 |
+
|
285 |
+
prepare = VLChatProcessorOutput(
|
286 |
+
sft_format=sft_format,
|
287 |
+
input_ids=input_ids,
|
288 |
+
pixel_values=images_outputs.pixel_values,
|
289 |
+
num_image_tokens=num_image_tokens,
|
290 |
+
)
|
291 |
+
|
292 |
+
return prepare
|
293 |
+
|
294 |
+
def __call__(
|
295 |
+
self,
|
296 |
+
*,
|
297 |
+
prompt: str = None,
|
298 |
+
conversations: List[Dict[str, str]] = None,
|
299 |
+
images: List[Image] = None,
|
300 |
+
force_batchify: bool = True,
|
301 |
+
**kwargs,
|
302 |
+
):
|
303 |
+
"""
|
304 |
+
|
305 |
+
Args:
|
306 |
+
prompt (str): the formatted prompt;
|
307 |
+
conversations (List[Dict]): conversations with a list of messages;
|
308 |
+
images (List[ImageType]): the list of images;
|
309 |
+
force_batchify (bool): force batchify the inputs;
|
310 |
+
**kwargs:
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
outputs (BaseProcessorOutput): the output of the processor,
|
314 |
+
- input_ids (torch.LongTensor): [N + image tokens]
|
315 |
+
- images (torch.FloatTensor): [n_images, 3, H, W]
|
316 |
+
- image_id (int): the id of the image token
|
317 |
+
- num_image_tokens (List[int]): the number of image tokens
|
318 |
+
"""
|
319 |
+
|
320 |
+
prepare = self.process_one(
|
321 |
+
prompt=prompt, conversations=conversations, images=images
|
322 |
+
)
|
323 |
+
|
324 |
+
if force_batchify:
|
325 |
+
prepare = self.batchify([prepare])
|
326 |
+
|
327 |
+
return prepare
|
328 |
+
|
329 |
+
def batchify(
|
330 |
+
self, prepare_list: List[VLChatProcessorOutput]
|
331 |
+
) -> BatchedVLChatProcessorOutput:
|
332 |
+
"""
|
333 |
+
Preprocesses the inputs for multimodal inference.
|
334 |
+
|
335 |
+
Args:
|
336 |
+
prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
|
340 |
+
"""
|
341 |
+
|
342 |
+
batch_size = len(prepare_list)
|
343 |
+
sft_format = []
|
344 |
+
n_images = []
|
345 |
+
seq_lens = []
|
346 |
+
for prepare in prepare_list:
|
347 |
+
n_images.append(len(prepare.num_image_tokens))
|
348 |
+
seq_lens.append(len(prepare))
|
349 |
+
|
350 |
+
input_token_max_len = max(seq_lens)
|
351 |
+
max_n_images = max(1, max(n_images))
|
352 |
+
|
353 |
+
batched_input_ids = torch.full(
|
354 |
+
(batch_size, input_token_max_len), self.pad_id
|
355 |
+
).long() # FIXME
|
356 |
+
batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
|
357 |
+
batched_pixel_values = torch.zeros(
|
358 |
+
(batch_size, max_n_images, *self.image_processor.default_shape)
|
359 |
+
).float()
|
360 |
+
batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
|
361 |
+
batched_images_emb_mask = torch.zeros(
|
362 |
+
(batch_size, max_n_images, self.num_image_tokens)
|
363 |
+
).bool()
|
364 |
+
|
365 |
+
for i, prepare in enumerate(prepare_list):
|
366 |
+
input_ids = prepare.input_ids
|
367 |
+
seq_len = len(prepare)
|
368 |
+
n_image = len(prepare.num_image_tokens)
|
369 |
+
# left-padding
|
370 |
+
batched_attention_mask[i, -seq_len:] = 1
|
371 |
+
batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
|
372 |
+
batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
|
373 |
+
|
374 |
+
if n_image > 0:
|
375 |
+
batched_pixel_values[i, :n_image] = prepare.pixel_values
|
376 |
+
for j, n_image_tokens in enumerate(prepare.num_image_tokens):
|
377 |
+
batched_images_emb_mask[i, j, :n_image_tokens] = True
|
378 |
+
|
379 |
+
sft_format.append(prepare.sft_format)
|
380 |
+
|
381 |
+
batched_prepares = BatchedVLChatProcessorOutput(
|
382 |
+
input_ids=batched_input_ids,
|
383 |
+
attention_mask=batched_attention_mask,
|
384 |
+
pixel_values=batched_pixel_values,
|
385 |
+
images_seq_mask=batched_images_seq_mask,
|
386 |
+
images_emb_mask=batched_images_emb_mask,
|
387 |
+
sft_format=sft_format,
|
388 |
+
)
|
389 |
+
|
390 |
+
return batched_prepares
|
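For orientation, here is a minimal usage sketch of the processor defined above. It is a sketch only: the checkpoint name, the load_pil_images helper, and the example image path are assumptions drawn from the rest of this Space (see deepseek_vl/utils/io.py and examples/), so adapt them to your setup.

# Minimal sketch, assuming the deepseek-ai/deepseek-vl-7b-chat checkpoint and the
# load_pil_images helper from deepseek_vl/utils/io.py.
from deepseek_vl.models.processing_vlm import VLChatProcessor
from deepseek_vl.utils.io import load_pil_images

processor = VLChatProcessor.from_pretrained("deepseek-ai/deepseek-vl-7b-chat")

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder> Describe this image.",
        "images": ["./examples/rap.jpeg"],
    },
    {"role": "Assistant", "content": ""},
]

pil_images = load_pil_images(conversation)

# batchify() left-pads, so the result is a BatchedVLChatProcessorOutput holding
# input_ids, attention_mask, pixel_values and the two image masks defined above.
inputs = processor(conversations=conversation, images=pil_images, force_batchify=True)
print(inputs.input_ids.shape, inputs.pixel_values.shape)
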
deepseek_vl/models/projector.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from typing import Tuple, Union

import torch
import torch.nn as nn
from attrdict import AttrDict


class MlpProjector(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        self.layers = modules

    def forward(
        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
    ):
        """
        Args:
            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of
                torch.Tensor, it comes from the hybrid vision encoder and unpacks as (high_res_x, low_res_x);
                otherwise it is the feature from the single vision encoder.

        Returns:
            x (torch.Tensor): [b, s, c]
        """

        if isinstance(x_or_tuple, tuple):
            # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu"
            high_x, low_x = x_or_tuple
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
        else:
            x = x_or_tuple

        return self.layers(x)


if __name__ == "__main__":
    cfg = AttrDict(
        input_dim=1024,
        n_embed=2048,
        depth=2,
        projector_type="low_high_hybrid_split_mlp_gelu",
    )
    inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))

    m = MlpProjector(cfg)
    out = m(inputs)
    print(out.shape)
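For comparison with the hybrid demo above, a small sketch of the single-encoder path: with projector_type="mlp_gelu" the module reduces to a plain Linear -> GELU -> Linear stack that maps [b, s, input_dim] to [b, s, n_embed]. The configuration values below are illustrative assumptions, not the ones used by the released model.

import torch
from attrdict import AttrDict

from deepseek_vl.models.projector import MlpProjector

# Illustrative config (assumed values): depth=2 builds
# Linear(1024 -> 2048) -> GELU -> Linear(2048 -> 2048).
cfg = AttrDict(input_dim=1024, n_embed=2048, depth=2, projector_type="mlp_gelu")
proj = MlpProjector(cfg)

feats = torch.rand(2, 576, 1024)  # [batch, vision tokens, input_dim] from a single encoder
print(proj(feats).shape)          # torch.Size([2, 576, 2048])
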
deepseek_vl/models/sam.py
ADDED
@@ -0,0 +1,593 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import copy
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from functools import partial
|
10 |
+
from typing import List, Optional, Tuple, Type, Union
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
|
17 |
+
class MLPBlock(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
embedding_dim: int,
|
21 |
+
mlp_dim: int,
|
22 |
+
act: Type[nn.Module] = nn.GELU,
|
23 |
+
) -> None:
|
24 |
+
super().__init__()
|
25 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
26 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
27 |
+
self.act = act()
|
28 |
+
|
29 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
30 |
+
return self.lin2(self.act(self.lin1(x)))
|
31 |
+
|
32 |
+
|
33 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
34 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
35 |
+
class LayerNorm2d(nn.Module):
|
36 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
37 |
+
super().__init__()
|
38 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
39 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
40 |
+
self.eps = eps
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
43 |
+
u = x.mean(1, keepdim=True)
|
44 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
45 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
46 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
47 |
+
return x
|
48 |
+
|
49 |
+
|
50 |
+
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
51 |
+
class ImageEncoderViT(nn.Module):
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
img_size: int = 1024,
|
55 |
+
patch_size: int = 16,
|
56 |
+
in_chans: int = 3,
|
57 |
+
embed_dim: int = 768,
|
58 |
+
depth: int = 12,
|
59 |
+
num_heads: int = 12,
|
60 |
+
mlp_ratio: float = 4.0,
|
61 |
+
out_chans: int = 256,
|
62 |
+
qkv_bias: bool = True,
|
63 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
64 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
65 |
+
use_abs_pos: bool = True,
|
66 |
+
use_rel_pos: bool = False,
|
67 |
+
rel_pos_zero_init: bool = True,
|
68 |
+
window_size: int = 0,
|
69 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
70 |
+
downsample_channels: Tuple[int, ...] = (512, 1024),
|
71 |
+
) -> None:
|
72 |
+
"""
|
73 |
+
Args:
|
74 |
+
img_size (int): Input image size.
|
75 |
+
patch_size (int): Patch size.
|
76 |
+
in_chans (int): Number of input image channels.
|
77 |
+
embed_dim (int): Patch embedding dimension.
|
78 |
+
depth (int): Depth of ViT.
|
79 |
+
num_heads (int): Number of attention heads in each ViT block.
|
80 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
81 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
82 |
+
norm_layer (nn.Module): Normalization layer.
|
83 |
+
act_layer (nn.Module): Activation layer.
|
84 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
85 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
86 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
87 |
+
window_size (int): Window size for window attention blocks.
|
88 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
89 |
+
downsample_channels (list): Channels for downsampling layers.
|
90 |
+
"""
|
91 |
+
super().__init__()
|
92 |
+
self.img_size = img_size
|
93 |
+
|
94 |
+
self.patch_embed = PatchEmbed(
|
95 |
+
kernel_size=(patch_size, patch_size),
|
96 |
+
stride=(patch_size, patch_size),
|
97 |
+
in_chans=in_chans,
|
98 |
+
embed_dim=embed_dim,
|
99 |
+
)
|
100 |
+
|
101 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
102 |
+
if use_abs_pos:
|
103 |
+
# Initialize absolute positional embedding with pretrain image size.
|
104 |
+
self.pos_embed = nn.Parameter(
|
105 |
+
torch.zeros(
|
106 |
+
1, img_size // patch_size, img_size // patch_size, embed_dim
|
107 |
+
)
|
108 |
+
)
|
109 |
+
|
110 |
+
self.blocks = nn.ModuleList()
|
111 |
+
for i in range(depth):
|
112 |
+
block = Block(
|
113 |
+
dim=embed_dim,
|
114 |
+
num_heads=num_heads,
|
115 |
+
mlp_ratio=mlp_ratio,
|
116 |
+
qkv_bias=qkv_bias,
|
117 |
+
norm_layer=norm_layer,
|
118 |
+
act_layer=act_layer,
|
119 |
+
use_rel_pos=use_rel_pos,
|
120 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
121 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
122 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
123 |
+
)
|
124 |
+
self.blocks.append(block)
|
125 |
+
|
126 |
+
self.neck = nn.Sequential(
|
127 |
+
nn.Conv2d(
|
128 |
+
embed_dim,
|
129 |
+
out_chans,
|
130 |
+
kernel_size=1,
|
131 |
+
bias=False,
|
132 |
+
),
|
133 |
+
LayerNorm2d(out_chans),
|
134 |
+
nn.Conv2d(
|
135 |
+
out_chans,
|
136 |
+
out_chans,
|
137 |
+
kernel_size=3,
|
138 |
+
padding=1,
|
139 |
+
bias=False,
|
140 |
+
),
|
141 |
+
LayerNorm2d(out_chans),
|
142 |
+
)
|
143 |
+
|
144 |
+
in_channels = out_chans
|
145 |
+
downsamples = []
|
146 |
+
for i in range(len(downsample_channels)):
|
147 |
+
out_channels = downsample_channels[i]
|
148 |
+
downsamples.append(
|
149 |
+
nn.Conv2d(
|
150 |
+
in_channels,
|
151 |
+
out_channels,
|
152 |
+
kernel_size=3,
|
153 |
+
stride=2,
|
154 |
+
padding=1,
|
155 |
+
bias=False,
|
156 |
+
)
|
157 |
+
)
|
158 |
+
in_channels = out_channels
|
159 |
+
self.downsamples = nn.Sequential(*downsamples)
|
160 |
+
|
161 |
+
self.sam_hd = True
|
162 |
+
if self.sam_hd:
|
163 |
+
self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1))
|
164 |
+
# self.neck_hd = nn.Linear(embed_dim, embed_dim)
|
165 |
+
self.neck_hd = copy.deepcopy(self.neck)
|
166 |
+
# self.downsamples_hd = copy.deepcopy(self.downsamples)
|
167 |
+
|
168 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
169 |
+
x = self.patch_embed(x)
|
170 |
+
if self.pos_embed is not None:
|
171 |
+
x = x + self.pos_embed
|
172 |
+
|
173 |
+
global_features = []
|
174 |
+
for i, blk in enumerate(self.blocks):
|
175 |
+
x = blk(x)
|
176 |
+
if self.sam_hd and blk.window_size == 0:
|
177 |
+
global_features.append(x)
|
178 |
+
|
179 |
+
x = self.neck(x.permute(0, 3, 1, 2))
|
180 |
+
x_dtype = x.dtype
|
181 |
+
x = F.interpolate(
|
182 |
+
x.float(), size=(96, 96), mode="bilinear", align_corners=False
|
183 |
+
).to(x_dtype)
|
184 |
+
x = self.downsamples(x)
|
185 |
+
|
186 |
+
if self.sam_hd:
|
187 |
+
first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2))
|
188 |
+
x_dtype = first_global_feature.dtype
|
189 |
+
first_global_feature = F.interpolate(
|
190 |
+
first_global_feature.float(),
|
191 |
+
size=(96, 96),
|
192 |
+
mode="bilinear",
|
193 |
+
align_corners=False,
|
194 |
+
)
|
195 |
+
first_global_feature = self.downsamples(first_global_feature.to(x_dtype))
|
196 |
+
x = x + first_global_feature * self.hd_alpha_downsamples
|
197 |
+
|
198 |
+
return x
|
199 |
+
|
200 |
+
|
201 |
+
class Block(nn.Module):
|
202 |
+
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
203 |
+
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
dim: int,
|
207 |
+
num_heads: int,
|
208 |
+
mlp_ratio: float = 4.0,
|
209 |
+
qkv_bias: bool = True,
|
210 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
211 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
212 |
+
use_rel_pos: bool = False,
|
213 |
+
rel_pos_zero_init: bool = True,
|
214 |
+
window_size: int = 0,
|
215 |
+
input_size: Optional[Tuple[int, int]] = None,
|
216 |
+
) -> None:
|
217 |
+
"""
|
218 |
+
Args:
|
219 |
+
dim (int): Number of input channels.
|
220 |
+
num_heads (int): Number of attention heads in each ViT block.
|
221 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
222 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
223 |
+
norm_layer (nn.Module): Normalization layer.
|
224 |
+
act_layer (nn.Module): Activation layer.
|
225 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
226 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
227 |
+
window_size (int): Window size for window attention blocks. If it equals 0, then
|
228 |
+
use global attention.
|
229 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
230 |
+
positional parameter size.
|
231 |
+
"""
|
232 |
+
super().__init__()
|
233 |
+
self.norm1 = norm_layer(dim)
|
234 |
+
self.attn = Attention(
|
235 |
+
dim,
|
236 |
+
num_heads=num_heads,
|
237 |
+
qkv_bias=qkv_bias,
|
238 |
+
use_rel_pos=use_rel_pos,
|
239 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
240 |
+
input_size=input_size if window_size == 0 else (window_size, window_size),
|
241 |
+
)
|
242 |
+
|
243 |
+
self.norm2 = norm_layer(dim)
|
244 |
+
self.mlp = MLPBlock(
|
245 |
+
embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
|
246 |
+
)
|
247 |
+
|
248 |
+
self.window_size = window_size
|
249 |
+
|
250 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
251 |
+
shortcut = x
|
252 |
+
x = self.norm1(x)
|
253 |
+
# Window partition
|
254 |
+
if self.window_size > 0:
|
255 |
+
H, W = x.shape[1], x.shape[2]
|
256 |
+
x, pad_hw = window_partition(x, self.window_size)
|
257 |
+
|
258 |
+
x = self.attn(x)
|
259 |
+
# Reverse window partition
|
260 |
+
if self.window_size > 0:
|
261 |
+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
262 |
+
|
263 |
+
x = shortcut + x
|
264 |
+
x = x + self.mlp(self.norm2(x))
|
265 |
+
|
266 |
+
return x
|
267 |
+
|
268 |
+
|
269 |
+
class Attention(nn.Module):
|
270 |
+
"""Multi-head Attention block with relative position embeddings."""
|
271 |
+
|
272 |
+
def __init__(
|
273 |
+
self,
|
274 |
+
dim: int,
|
275 |
+
num_heads: int = 8,
|
276 |
+
qkv_bias: bool = True,
|
277 |
+
use_rel_pos: bool = False,
|
278 |
+
rel_pos_zero_init: bool = True,
|
279 |
+
input_size: Optional[Tuple[int, int]] = None,
|
280 |
+
) -> None:
|
281 |
+
"""
|
282 |
+
Args:
|
283 |
+
dim (int): Number of input channels.
|
284 |
+
num_heads (int): Number of attention heads.
|
285 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
286 |
+
rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
287 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
288 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
289 |
+
positional parameter size.
|
290 |
+
"""
|
291 |
+
super().__init__()
|
292 |
+
self.num_heads = num_heads
|
293 |
+
head_dim = dim // num_heads
|
294 |
+
self.scale = head_dim**-0.5
|
295 |
+
|
296 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
297 |
+
self.proj = nn.Linear(dim, dim)
|
298 |
+
|
299 |
+
self.use_rel_pos = use_rel_pos
|
300 |
+
if self.use_rel_pos:
|
301 |
+
assert (
|
302 |
+
input_size is not None
|
303 |
+
), "Input size must be provided if using relative positional encoding."
|
304 |
+
# initialize relative positional embeddings
|
305 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
306 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
307 |
+
|
308 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
309 |
+
B, H, W, _ = x.shape
|
310 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
311 |
+
qkv = (
|
312 |
+
self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
313 |
+
)
|
314 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
315 |
+
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
316 |
+
|
317 |
+
def do_attention(q, k, v):
|
318 |
+
attn = (q * self.scale) @ k.transpose(-2, -1)
|
319 |
+
if self.use_rel_pos:
|
320 |
+
attn = add_decomposed_rel_pos(
|
321 |
+
attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
|
322 |
+
)
|
323 |
+
|
324 |
+
attn = attn.softmax(dim=-1)
|
325 |
+
x = (
|
326 |
+
(attn @ v)
|
327 |
+
.view(B, self.num_heads, H, W, -1)
|
328 |
+
.permute(0, 2, 3, 1, 4)
|
329 |
+
.reshape(B, H, W, -1)
|
330 |
+
)
|
331 |
+
|
332 |
+
return x
|
333 |
+
|
334 |
+
# from haiscale.utils import on_demand_checkpoint
|
335 |
+
# x = on_demand_checkpoint(do_attention, q, k, v)
|
336 |
+
x = do_attention(q, k, v)
|
337 |
+
x = self.proj(x)
|
338 |
+
|
339 |
+
return x
|
340 |
+
|
341 |
+
|
342 |
+
def window_partition(
|
343 |
+
x: torch.Tensor, window_size: int
|
344 |
+
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
345 |
+
"""
|
346 |
+
Partition into non-overlapping windows with padding if needed.
|
347 |
+
Args:
|
348 |
+
x (tensor): input tokens with [B, H, W, C].
|
349 |
+
window_size (int): window size.
|
350 |
+
|
351 |
+
Returns:
|
352 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
353 |
+
(Hp, Wp): padded height and width before partition
|
354 |
+
"""
|
355 |
+
B, H, W, C = x.shape
|
356 |
+
|
357 |
+
pad_h = (window_size - H % window_size) % window_size
|
358 |
+
pad_w = (window_size - W % window_size) % window_size
|
359 |
+
if pad_h > 0 or pad_w > 0:
|
360 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
361 |
+
Hp, Wp = H + pad_h, W + pad_w
|
362 |
+
|
363 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
364 |
+
windows = (
|
365 |
+
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
366 |
+
)
|
367 |
+
return windows, (Hp, Wp)
|
368 |
+
|
369 |
+
|
370 |
+
def window_unpartition(
|
371 |
+
windows: torch.Tensor,
|
372 |
+
window_size: int,
|
373 |
+
pad_hw: Tuple[int, int],
|
374 |
+
hw: Tuple[int, int],
|
375 |
+
) -> torch.Tensor:
|
376 |
+
"""
|
377 |
+
Window unpartition back into the original sequences and remove padding.
|
378 |
+
Args:
|
379 |
+
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
380 |
+
window_size (int): window size.
|
381 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
382 |
+
hw (Tuple): original height and width (H, W) before padding.
|
383 |
+
|
384 |
+
Returns:
|
385 |
+
x: unpartitioned sequences with [B, H, W, C].
|
386 |
+
"""
|
387 |
+
Hp, Wp = pad_hw
|
388 |
+
H, W = hw
|
389 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
390 |
+
x = windows.view(
|
391 |
+
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
392 |
+
)
|
393 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
394 |
+
|
395 |
+
if Hp > H or Wp > W:
|
396 |
+
x = x[:, :H, :W, :].contiguous()
|
397 |
+
return x
|
398 |
+
|
399 |
+
|
400 |
+
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
401 |
+
"""
|
402 |
+
Get relative positional embeddings according to the relative positions of
|
403 |
+
query and key sizes.
|
404 |
+
Args:
|
405 |
+
q_size (int): size of query q.
|
406 |
+
k_size (int): size of key k.
|
407 |
+
rel_pos (Tensor): relative position embeddings (L, C).
|
408 |
+
|
409 |
+
Returns:
|
410 |
+
Extracted positional embeddings according to relative positions.
|
411 |
+
"""
|
412 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
413 |
+
# Interpolate rel pos if needed.
|
414 |
+
if rel_pos.shape[0] != max_rel_dist:
|
415 |
+
# Interpolate rel pos.
|
416 |
+
rel_pos_resized = F.interpolate(
|
417 |
+
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
418 |
+
size=max_rel_dist,
|
419 |
+
mode="linear",
|
420 |
+
)
|
421 |
+
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
422 |
+
else:
|
423 |
+
rel_pos_resized = rel_pos
|
424 |
+
|
425 |
+
# Scale the coords with short length if shapes for q and k are different.
|
426 |
+
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
427 |
+
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
428 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
429 |
+
|
430 |
+
return rel_pos_resized[relative_coords.long()]
|
431 |
+
|
432 |
+
|
433 |
+
def add_decomposed_rel_pos(
|
434 |
+
attn: torch.Tensor,
|
435 |
+
q: torch.Tensor,
|
436 |
+
rel_pos_h: torch.Tensor,
|
437 |
+
rel_pos_w: torch.Tensor,
|
438 |
+
q_size: Tuple[int, int],
|
439 |
+
k_size: Tuple[int, int],
|
440 |
+
) -> torch.Tensor:
|
441 |
+
"""
|
442 |
+
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
443 |
+
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
444 |
+
Args:
|
445 |
+
attn (Tensor): attention map.
|
446 |
+
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
447 |
+
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
448 |
+
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
449 |
+
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
450 |
+
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
451 |
+
|
452 |
+
Returns:
|
453 |
+
attn (Tensor): attention map with added relative positional embeddings.
|
454 |
+
"""
|
455 |
+
q_h, q_w = q_size
|
456 |
+
k_h, k_w = k_size
|
457 |
+
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
458 |
+
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
459 |
+
|
460 |
+
B, _, dim = q.shape
|
461 |
+
r_q = q.reshape(B, q_h, q_w, dim)
|
462 |
+
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
463 |
+
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
464 |
+
|
465 |
+
attn = (
|
466 |
+
attn.view(B, q_h, q_w, k_h, k_w)
|
467 |
+
+ rel_h[:, :, :, :, None]
|
468 |
+
+ rel_w[:, :, :, None, :]
|
469 |
+
).view(B, q_h * q_w, k_h * k_w)
|
470 |
+
|
471 |
+
return attn
|
472 |
+
|
473 |
+
|
474 |
+
class PatchEmbed(nn.Module):
|
475 |
+
"""
|
476 |
+
Image to Patch Embedding.
|
477 |
+
"""
|
478 |
+
|
479 |
+
def __init__(
|
480 |
+
self,
|
481 |
+
kernel_size: Tuple[int, int] = (16, 16),
|
482 |
+
stride: Tuple[int, int] = (16, 16),
|
483 |
+
padding: Tuple[int, int] = (0, 0),
|
484 |
+
in_chans: int = 3,
|
485 |
+
embed_dim: int = 768,
|
486 |
+
) -> None:
|
487 |
+
"""
|
488 |
+
Args:
|
489 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
490 |
+
stride (Tuple): stride of the projection layer.
|
491 |
+
padding (Tuple): padding size of the projection layer.
|
492 |
+
in_chans (int): Number of input image channels.
|
493 |
+
embed_dim (int): Patch embedding dimension.
|
494 |
+
"""
|
495 |
+
super().__init__()
|
496 |
+
|
497 |
+
self.proj = nn.Conv2d(
|
498 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
499 |
+
)
|
500 |
+
|
501 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
502 |
+
x = self.proj(x)
|
503 |
+
# B C H W -> B H W C
|
504 |
+
x = x.permute(0, 2, 3, 1)
|
505 |
+
return x
|
506 |
+
|
507 |
+
|
508 |
+
@dataclass
|
509 |
+
class SAMViTCfg:
|
510 |
+
image_size: Union[Tuple[int, int], int] = 1024
|
511 |
+
width: int = 1024
|
512 |
+
layers: int = 23
|
513 |
+
heads: int = 16
|
514 |
+
patch_size: int = 16
|
515 |
+
window_size: int = 14
|
516 |
+
prompt_embed_dim: int = 256
|
517 |
+
global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23)
|
518 |
+
downsample_channels: Union[List[int], Tuple[int]] = (512, 1024)
|
519 |
+
|
520 |
+
|
521 |
+
SAM_MODEL_CONFIG = {
|
522 |
+
"sam_vit_b": {
|
523 |
+
"width": 768,
|
524 |
+
"layers": 12,
|
525 |
+
"heads": 12,
|
526 |
+
"global_attn_indexes": [2, 5, 8, 11],
|
527 |
+
"downsample_channels": (),
|
528 |
+
},
|
529 |
+
"sam_b_downsample": {
|
530 |
+
"width": 768,
|
531 |
+
"layers": 12,
|
532 |
+
"heads": 12,
|
533 |
+
"global_attn_indexes": [2, 5, 8, 11],
|
534 |
+
"downsample_channels": (512, 1024),
|
535 |
+
},
|
536 |
+
"sam_vit_l": {
|
537 |
+
"width": 1024,
|
538 |
+
"layers": 24,
|
539 |
+
"heads": 16,
|
540 |
+
"global_attn_indexes": [5, 11, 17, 23],
|
541 |
+
"downsample_channels": (),
|
542 |
+
},
|
543 |
+
"sam_vit_h": {
|
544 |
+
"width": 1280,
|
545 |
+
"layers": 32,
|
546 |
+
"heads": 16,
|
547 |
+
"global_attn_indexes": [7, 15, 23, 31],
|
548 |
+
"downsample_channels": (),
|
549 |
+
},
|
550 |
+
}
|
551 |
+
|
552 |
+
|
553 |
+
def create_sam_vit(
|
554 |
+
model_name: str = "sam_b_downsample",
|
555 |
+
image_size: int = 1024,
|
556 |
+
ckpt_path: str = "",
|
557 |
+
**kwargs,
|
558 |
+
):
|
559 |
+
assert (
|
560 |
+
model_name in SAM_MODEL_CONFIG.keys()
|
561 |
+
), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}"
|
562 |
+
|
563 |
+
sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name])
|
564 |
+
image_encoder = ImageEncoderViT(
|
565 |
+
depth=sam_cfg.layers,
|
566 |
+
embed_dim=sam_cfg.width,
|
567 |
+
img_size=image_size,
|
568 |
+
mlp_ratio=4,
|
569 |
+
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
570 |
+
num_heads=sam_cfg.heads,
|
571 |
+
patch_size=sam_cfg.patch_size,
|
572 |
+
qkv_bias=True,
|
573 |
+
use_rel_pos=True,
|
574 |
+
global_attn_indexes=sam_cfg.global_attn_indexes,
|
575 |
+
window_size=14,
|
576 |
+
out_chans=sam_cfg.prompt_embed_dim,
|
577 |
+
downsample_channels=sam_cfg.downsample_channels,
|
578 |
+
)
|
579 |
+
|
580 |
+
if ckpt_path:
|
581 |
+
state_dict = torch.load(ckpt_path)
|
582 |
+
image_encoder.load_state_dict(state_dict, strict=False)
|
583 |
+
print(f"SAM-ViT restores from {ckpt_path}")
|
584 |
+
|
585 |
+
return image_encoder
|
586 |
+
|
587 |
+
|
588 |
+
if __name__ == "__main__":
|
589 |
+
x = torch.zeros(2, 3, 1024, 1024).bfloat16()
|
590 |
+
# x.permute(0, 3, 1, 2)
|
591 |
+
net = create_sam_vit().bfloat16()
|
592 |
+
out = net(x)
|
593 |
+
print(x.shape, out.shape)
|
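To make the windowing helpers above concrete, here is a small round-trip check with illustrative shapes: window_partition pads H and W up to multiples of the window size before splitting, and window_unpartition strips that padding again, so the original tensor is recovered exactly.

import torch

from deepseek_vl.models.sam import window_partition, window_unpartition

# Two 13x13 feature maps with 32 channels; window size 7 pads them to 14x14,
# giving 2 * (14 // 7) * (14 // 7) = 8 windows of shape 7x7x32.
x = torch.randn(2, 13, 13, 32)
windows, pad_hw = window_partition(x, window_size=7)
print(windows.shape, pad_hw)  # torch.Size([8, 7, 7, 32]) (14, 14)

y = window_unpartition(windows, window_size=7, pad_hw=pad_hw, hw=(13, 13))
print(torch.equal(x, y))      # True: the padded region is stripped again
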
deepseek_vl/models/siglip_vit.py
ADDED
@@ -0,0 +1,681 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
|
21 |
+
import math
|
22 |
+
import warnings
|
23 |
+
from dataclasses import dataclass
|
24 |
+
from functools import partial
|
25 |
+
from typing import (
|
26 |
+
Callable,
|
27 |
+
Dict,
|
28 |
+
Final,
|
29 |
+
List,
|
30 |
+
Literal,
|
31 |
+
Optional,
|
32 |
+
Sequence,
|
33 |
+
Set,
|
34 |
+
Tuple,
|
35 |
+
Type,
|
36 |
+
Union,
|
37 |
+
)
|
38 |
+
|
39 |
+
import torch
|
40 |
+
import torch.nn as nn
|
41 |
+
import torch.nn.functional as F
|
42 |
+
from timm.layers import (
|
43 |
+
AttentionPoolLatent,
|
44 |
+
DropPath,
|
45 |
+
LayerType,
|
46 |
+
Mlp,
|
47 |
+
PatchDropout,
|
48 |
+
PatchEmbed,
|
49 |
+
resample_abs_pos_embed,
|
50 |
+
)
|
51 |
+
from timm.models._manipulate import checkpoint_seq, named_apply
|
52 |
+
|
53 |
+
|
54 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
55 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
56 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
57 |
+
def norm_cdf(x):
|
58 |
+
# Computes standard normal cumulative distribution function
|
59 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
60 |
+
|
61 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
62 |
+
warnings.warn(
|
63 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
64 |
+
"The distribution of values may be incorrect.",
|
65 |
+
stacklevel=2,
|
66 |
+
)
|
67 |
+
|
68 |
+
with torch.no_grad():
|
69 |
+
# Values are generated by using a truncated uniform distribution and
|
70 |
+
# then using the inverse CDF for the normal distribution.
|
71 |
+
# Get upper and lower cdf values
|
72 |
+
l = norm_cdf((a - mean) / std) # noqa: E741
|
73 |
+
u = norm_cdf((b - mean) / std)
|
74 |
+
|
75 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
76 |
+
# [2l-1, 2u-1].
|
77 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
78 |
+
|
79 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
80 |
+
# standard normal
|
81 |
+
tensor.erfinv_()
|
82 |
+
|
83 |
+
# Transform to proper mean, std
|
84 |
+
tensor.mul_(std * math.sqrt(2.0))
|
85 |
+
tensor.add_(mean)
|
86 |
+
|
87 |
+
# Clamp to ensure it's in the proper range
|
88 |
+
tensor.clamp_(min=a, max=b)
|
89 |
+
return tensor
|
90 |
+
|
91 |
+
|
92 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
93 |
+
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
|
94 |
+
r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
|
95 |
+
convert the tensor to float32, apply trunc_normal_() in float32, and then convert it back to its original dtype.
|
96 |
+
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
|
97 |
+
from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
98 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
99 |
+
the bounds. The method used for generating the random values works
|
100 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
101 |
+
Args:
|
102 |
+
tensor: an n-dimensional `torch.Tensor`
|
103 |
+
mean: the mean of the normal distribution
|
104 |
+
std: the standard deviation of the normal distribution
|
105 |
+
a: the minimum cutoff value
|
106 |
+
b: the maximum cutoff value
|
107 |
+
Examples:
|
108 |
+
>>> w = torch.empty(3, 5)
|
109 |
+
>>> nn.init.trunc_normal_(w)
|
110 |
+
"""
|
111 |
+
|
112 |
+
with torch.no_grad():
|
113 |
+
dtype = tensor.dtype
|
114 |
+
tensor_fp32 = tensor.float()
|
115 |
+
tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
|
116 |
+
tensor_dtype = tensor_fp32.to(dtype=dtype)
|
117 |
+
tensor.copy_(tensor_dtype)
|
118 |
+
|
119 |
+
|
120 |
+
def init_weights(self):
|
121 |
+
if self.pos_embed is not None:
|
122 |
+
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
|
123 |
+
trunc_normal_(self.latent, std=self.latent_dim**-0.5)
|
124 |
+
|
125 |
+
|
126 |
+
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
|
127 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
128 |
+
if isinstance(module, nn.Linear):
|
129 |
+
trunc_normal_(module.weight, std=0.02)
|
130 |
+
if module.bias is not None:
|
131 |
+
nn.init.zeros_(module.bias)
|
132 |
+
elif hasattr(module, "init_weights"):
|
133 |
+
module.init_weights()
|
134 |
+
|
135 |
+
|
136 |
+
class Attention(nn.Module):
|
137 |
+
fused_attn: Final[bool]
|
138 |
+
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
dim: int,
|
142 |
+
num_heads: int = 8,
|
143 |
+
qkv_bias: bool = False,
|
144 |
+
qk_norm: bool = False,
|
145 |
+
attn_drop: float = 0.0,
|
146 |
+
proj_drop: float = 0.0,
|
147 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
148 |
+
) -> None:
|
149 |
+
super().__init__()
|
150 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
151 |
+
self.num_heads = num_heads
|
152 |
+
self.head_dim = dim // num_heads
|
153 |
+
self.scale = self.head_dim**-0.5
|
154 |
+
# self.fused_attn = use_fused_attn()
|
155 |
+
self.fused_attn = True
|
156 |
+
|
157 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
158 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
159 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
160 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
161 |
+
self.proj = nn.Linear(dim, dim)
|
162 |
+
self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
|
163 |
+
|
164 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
165 |
+
B, N, C = x.shape
|
166 |
+
qkv = (
|
167 |
+
self.qkv(x)
|
168 |
+
.reshape(B, N, 3, self.num_heads, self.head_dim)
|
169 |
+
.permute(2, 0, 3, 1, 4)
|
170 |
+
)
|
171 |
+
q, k, v = qkv.unbind(0)
|
172 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
173 |
+
|
174 |
+
if self.fused_attn:
|
175 |
+
x = F.scaled_dot_product_attention(
|
176 |
+
q,
|
177 |
+
k,
|
178 |
+
v,
|
179 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
q = q * self.scale
|
183 |
+
attn = q @ k.transpose(-2, -1)
|
184 |
+
attn = attn.softmax(dim=-1)
|
185 |
+
attn = self.attn_drop(attn)
|
186 |
+
x = attn @ v
|
187 |
+
|
188 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
189 |
+
x = self.proj(x)
|
190 |
+
x = self.proj_drop(x)
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
class LayerScale(nn.Module):
|
195 |
+
def __init__(
|
196 |
+
self,
|
197 |
+
dim: int,
|
198 |
+
init_values: float = 1e-5,
|
199 |
+
inplace: bool = False,
|
200 |
+
) -> None:
|
201 |
+
super().__init__()
|
202 |
+
self.inplace = inplace
|
203 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
204 |
+
|
205 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
206 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
207 |
+
|
208 |
+
|
209 |
+
class Block(nn.Module):
|
210 |
+
def __init__(
|
211 |
+
self,
|
212 |
+
dim: int,
|
213 |
+
num_heads: int,
|
214 |
+
mlp_ratio: float = 4.0,
|
215 |
+
qkv_bias: bool = False,
|
216 |
+
qk_norm: bool = False,
|
217 |
+
proj_drop: float = 0.0,
|
218 |
+
attn_drop: float = 0.0,
|
219 |
+
init_values: Optional[float] = None,
|
220 |
+
drop_path: float = 0.0,
|
221 |
+
act_layer: nn.Module = nn.GELU,
|
222 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
223 |
+
mlp_layer: nn.Module = Mlp,
|
224 |
+
) -> None:
|
225 |
+
super().__init__()
|
226 |
+
self.norm1 = norm_layer(dim)
|
227 |
+
self.attn = Attention(
|
228 |
+
dim,
|
229 |
+
num_heads=num_heads,
|
230 |
+
qkv_bias=qkv_bias,
|
231 |
+
qk_norm=qk_norm,
|
232 |
+
attn_drop=attn_drop,
|
233 |
+
proj_drop=proj_drop,
|
234 |
+
norm_layer=norm_layer,
|
235 |
+
)
|
236 |
+
self.ls1 = (
|
237 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
238 |
+
)
|
239 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
240 |
+
|
241 |
+
self.norm2 = norm_layer(dim)
|
242 |
+
self.mlp = mlp_layer(
|
243 |
+
in_features=dim,
|
244 |
+
hidden_features=int(dim * mlp_ratio),
|
245 |
+
act_layer=act_layer,
|
246 |
+
drop=proj_drop,
|
247 |
+
)
|
248 |
+
self.ls2 = (
|
249 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
250 |
+
)
|
251 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
252 |
+
|
253 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
254 |
+
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
|
255 |
+
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
256 |
+
return x
|
257 |
+
|
258 |
+
|
259 |
+
class VisionTransformer(nn.Module):
|
260 |
+
"""Vision Transformer
|
261 |
+
|
262 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
263 |
+
- https://arxiv.org/abs/2010.11929
|
264 |
+
"""
|
265 |
+
|
266 |
+
dynamic_img_size: Final[bool]
|
267 |
+
|
268 |
+
def __init__(
|
269 |
+
self,
|
270 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
271 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
272 |
+
in_chans: int = 3,
|
273 |
+
num_classes: int = 1000,
|
274 |
+
global_pool: Literal["", "avg", "token", "map"] = "token",
|
275 |
+
embed_dim: int = 768,
|
276 |
+
depth: int = 12,
|
277 |
+
num_heads: int = 12,
|
278 |
+
mlp_ratio: float = 4.0,
|
279 |
+
qkv_bias: bool = True,
|
280 |
+
qk_norm: bool = False,
|
281 |
+
init_values: Optional[float] = None,
|
282 |
+
class_token: bool = True,
|
283 |
+
no_embed_class: bool = False,
|
284 |
+
reg_tokens: int = 0,
|
285 |
+
pre_norm: bool = False,
|
286 |
+
fc_norm: Optional[bool] = None,
|
287 |
+
dynamic_img_size: bool = False,
|
288 |
+
dynamic_img_pad: bool = False,
|
289 |
+
drop_rate: float = 0.0,
|
290 |
+
pos_drop_rate: float = 0.0,
|
291 |
+
patch_drop_rate: float = 0.0,
|
292 |
+
proj_drop_rate: float = 0.0,
|
293 |
+
attn_drop_rate: float = 0.0,
|
294 |
+
drop_path_rate: float = 0.0,
|
295 |
+
weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
|
296 |
+
embed_layer: Callable = PatchEmbed,
|
297 |
+
norm_layer: Optional[LayerType] = None,
|
298 |
+
act_layer: Optional[LayerType] = None,
|
299 |
+
block_fn: Type[nn.Module] = Block,
|
300 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
301 |
+
ignore_head: bool = False,
|
302 |
+
) -> None:
|
303 |
+
"""
|
304 |
+
Args:
|
305 |
+
img_size: Input image size.
|
306 |
+
patch_size: Patch size.
|
307 |
+
in_chans: Number of image input channels.
|
308 |
+
num_classes: Number of classes for the classification head.
|
309 |
+
global_pool: Type of global pooling for final sequence (default: 'token').
|
310 |
+
embed_dim: Transformer embedding dimension.
|
311 |
+
depth: Depth of transformer.
|
312 |
+
num_heads: Number of attention heads.
|
313 |
+
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
314 |
+
qkv_bias: Enable bias for qkv projections if True.
|
315 |
+
init_values: Layer-scale init values (layer-scale enabled if not None).
|
316 |
+
class_token: Use class token.
|
317 |
+
no_embed_class: Don't include position embeddings for class (or reg) tokens.
|
318 |
+
reg_tokens: Number of register tokens.
|
319 |
+
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
320 |
+
drop_rate: Head dropout rate.
|
321 |
+
pos_drop_rate: Position embedding dropout rate.
|
322 |
+
attn_drop_rate: Attention dropout rate.
|
323 |
+
drop_path_rate: Stochastic depth rate.
|
324 |
+
weight_init: Weight initialization scheme.
|
325 |
+
embed_layer: Patch embedding layer.
|
326 |
+
norm_layer: Normalization layer.
|
327 |
+
act_layer: MLP activation layer.
|
328 |
+
block_fn: Transformer block layer.
|
329 |
+
"""
|
330 |
+
super().__init__()
|
331 |
+
assert global_pool in ("", "avg", "token", "map")
|
332 |
+
assert class_token or global_pool != "token"
|
333 |
+
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
|
334 |
+
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
335 |
+
# act_layer = get_act_layer(act_layer) or nn.GELU
|
336 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
337 |
+
act_layer = nn.GELU
|
338 |
+
|
339 |
+
self.num_classes = num_classes
|
340 |
+
self.global_pool = global_pool
|
341 |
+
self.num_features = self.embed_dim = (
|
342 |
+
embed_dim # num_features for consistency with other models
|
343 |
+
)
|
344 |
+
self.num_prefix_tokens = 1 if class_token else 0
|
345 |
+
self.num_prefix_tokens += reg_tokens
|
346 |
+
self.num_reg_tokens = reg_tokens
|
347 |
+
self.has_class_token = class_token
|
348 |
+
self.no_embed_class = (
|
349 |
+
no_embed_class # don't embed prefix positions (includes reg)
|
350 |
+
)
|
351 |
+
self.dynamic_img_size = dynamic_img_size
|
352 |
+
self.grad_checkpointing = False
|
353 |
+
self.ignore_head = ignore_head
|
354 |
+
|
355 |
+
embed_args = {}
|
356 |
+
if dynamic_img_size:
|
357 |
+
# flatten deferred until after pos embed
|
358 |
+
embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
|
359 |
+
self.patch_embed = embed_layer(
|
360 |
+
img_size=img_size,
|
361 |
+
patch_size=patch_size,
|
362 |
+
in_chans=in_chans,
|
363 |
+
embed_dim=embed_dim,
|
364 |
+
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
365 |
+
dynamic_img_pad=dynamic_img_pad,
|
366 |
+
**embed_args,
|
367 |
+
)
|
368 |
+
num_patches = self.patch_embed.num_patches
|
369 |
+
|
370 |
+
self.cls_token = (
|
371 |
+
nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
372 |
+
)
|
373 |
+
self.reg_token = (
|
374 |
+
nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
|
375 |
+
)
|
376 |
+
embed_len = (
|
377 |
+
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
378 |
+
)
|
379 |
+
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
380 |
+
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
381 |
+
if patch_drop_rate > 0:
|
382 |
+
self.patch_drop = PatchDropout(
|
383 |
+
patch_drop_rate,
|
384 |
+
num_prefix_tokens=self.num_prefix_tokens,
|
385 |
+
)
|
386 |
+
else:
|
387 |
+
self.patch_drop = nn.Identity()
|
388 |
+
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
389 |
+
|
390 |
+
dpr = [
|
391 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
392 |
+
] # stochastic depth decay rule
|
393 |
+
self.blocks = nn.Sequential(
|
394 |
+
*[
|
395 |
+
block_fn(
|
396 |
+
dim=embed_dim,
|
397 |
+
num_heads=num_heads,
|
398 |
+
mlp_ratio=mlp_ratio,
|
399 |
+
qkv_bias=qkv_bias,
|
400 |
+
qk_norm=qk_norm,
|
401 |
+
init_values=init_values,
|
402 |
+
proj_drop=proj_drop_rate,
|
403 |
+
attn_drop=attn_drop_rate,
|
404 |
+
drop_path=dpr[i],
|
405 |
+
norm_layer=norm_layer,
|
406 |
+
act_layer=act_layer,
|
407 |
+
mlp_layer=mlp_layer,
|
408 |
+
)
|
409 |
+
for i in range(depth)
|
410 |
+
]
|
411 |
+
)
|
412 |
+
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
413 |
+
|
414 |
+
# Classifier Head
|
415 |
+
if global_pool == "map":
|
416 |
+
AttentionPoolLatent.init_weights = init_weights
|
417 |
+
self.attn_pool = AttentionPoolLatent(
|
418 |
+
self.embed_dim,
|
419 |
+
num_heads=num_heads,
|
420 |
+
mlp_ratio=mlp_ratio,
|
421 |
+
norm_layer=norm_layer,
|
422 |
+
)
|
423 |
+
else:
|
424 |
+
self.attn_pool = None
|
425 |
+
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
426 |
+
self.head_drop = nn.Dropout(drop_rate)
|
427 |
+
self.head = (
|
428 |
+
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
429 |
+
)
|
430 |
+
|
431 |
+
if weight_init != "skip":
|
432 |
+
self.init_weights(weight_init)
|
433 |
+
|
434 |
+
def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
|
435 |
+
assert mode in ("jax", "jax_nlhb", "moco", "")
|
436 |
+
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
|
437 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
438 |
+
if self.cls_token is not None:
|
439 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
440 |
+
named_apply(init_weights_vit_timm, self)
|
441 |
+
|
442 |
+
@torch.jit.ignore
|
443 |
+
def no_weight_decay(self) -> Set:
|
444 |
+
return {"pos_embed", "cls_token", "dist_token"}
|
445 |
+
|
446 |
+
@torch.jit.ignore
|
447 |
+
def group_matcher(self, coarse: bool = False) -> Dict:
|
448 |
+
return dict(
|
449 |
+
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
|
450 |
+
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
|
451 |
+
)
|
452 |
+
|
453 |
+
@torch.jit.ignore
|
454 |
+
def set_grad_checkpointing(self, enable: bool = True) -> None:
|
455 |
+
self.grad_checkpointing = enable
|
456 |
+
|
457 |
+
@torch.jit.ignore
|
458 |
+
def get_classifier(self) -> nn.Module:
|
459 |
+
return self.head
|
460 |
+
|
461 |
+
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
|
462 |
+
self.num_classes = num_classes
|
463 |
+
if global_pool is not None:
|
464 |
+
assert global_pool in ("", "avg", "token", "map")
|
465 |
+
if global_pool == "map" and self.attn_pool is None:
|
466 |
+
assert (
|
467 |
+
False
|
468 |
+
), "Cannot currently add attention pooling in reset_classifier()."
|
469 |
+
elif global_pool != "map " and self.attn_pool is not None:
|
470 |
+
self.attn_pool = None # remove attention pooling
|
471 |
+
self.global_pool = global_pool
|
472 |
+
self.head = (
|
473 |
+
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
474 |
+
)
|
475 |
+
|
476 |
+
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
|
477 |
+
if self.dynamic_img_size:
|
478 |
+
B, H, W, C = x.shape
|
479 |
+
pos_embed = resample_abs_pos_embed(
|
480 |
+
self.pos_embed,
|
481 |
+
(H, W),
|
482 |
+
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
483 |
+
)
|
484 |
+
x = x.view(B, -1, C)
|
485 |
+
else:
|
486 |
+
pos_embed = self.pos_embed
|
487 |
+
|
488 |
+
to_cat = []
|
489 |
+
if self.cls_token is not None:
|
490 |
+
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
491 |
+
if self.reg_token is not None:
|
492 |
+
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
493 |
+
|
494 |
+
if self.no_embed_class:
|
495 |
+
# deit-3, updated JAX (big vision)
|
496 |
+
# position embedding does not overlap with class token, add then concat
|
497 |
+
x = x + pos_embed
|
498 |
+
if to_cat:
|
499 |
+
x = torch.cat(to_cat + [x], dim=1)
|
500 |
+
else:
|
501 |
+
# original timm, JAX, and deit vit impl
|
502 |
+
# pos_embed has entry for class token, concat then add
|
503 |
+
if to_cat:
|
504 |
+
x = torch.cat(to_cat + [x], dim=1)
|
505 |
+
x = x + pos_embed
|
506 |
+
|
507 |
+
return self.pos_drop(x)
|
508 |
+
|
509 |
+
def _intermediate_layers(
|
510 |
+
self,
|
511 |
+
x: torch.Tensor,
|
512 |
+
n: Union[int, Sequence] = 1,
|
513 |
+
) -> List[torch.Tensor]:
|
514 |
+
outputs, num_blocks = [], len(self.blocks)
|
515 |
+
take_indices = set(
|
516 |
+
range(num_blocks - n, num_blocks) if isinstance(n, int) else n
|
517 |
+
)
|
518 |
+
|
519 |
+
# forward pass
|
520 |
+
x = self.patch_embed(x)
|
521 |
+
x = self._pos_embed(x)
|
522 |
+
x = self.patch_drop(x)
|
523 |
+
x = self.norm_pre(x)
|
524 |
+
for i, blk in enumerate(self.blocks):
|
525 |
+
x = blk(x)
|
526 |
+
if i in take_indices:
|
527 |
+
outputs.append(x)
|
528 |
+
|
529 |
+
return outputs
|
530 |
+
|
531 |
+
def get_intermediate_layers(
|
532 |
+
self,
|
533 |
+
x: torch.Tensor,
|
534 |
+
n: Union[int, Sequence] = 1,
|
535 |
+
reshape: bool = False,
|
536 |
+
return_prefix_tokens: bool = False,
|
537 |
+
norm: bool = False,
|
538 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
539 |
+
"""Intermediate layer accessor (NOTE: This is a WIP experiment).
|
540 |
+
Inspired by DINO / DINOv2 interface
|
541 |
+
"""
|
542 |
+
# take last n blocks if n is an int; if n is a sequence, select blocks by matching indices
|
543 |
+
outputs = self._intermediate_layers(x, n)
|
544 |
+
if norm:
|
545 |
+
outputs = [self.norm(out) for out in outputs]
|
546 |
+
prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
|
547 |
+
outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
|
548 |
+
|
549 |
+
if reshape:
|
550 |
+
grid_size = self.patch_embed.grid_size
|
551 |
+
outputs = [
|
552 |
+
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
|
553 |
+
.permute(0, 3, 1, 2)
|
554 |
+
.contiguous()
|
555 |
+
for out in outputs
|
556 |
+
]
|
557 |
+
|
558 |
+
if return_prefix_tokens:
|
559 |
+
return tuple(zip(outputs, prefix_tokens))
|
560 |
+
return tuple(outputs)
|
561 |
+
|
562 |
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
563 |
+
x = self.patch_embed(x)
|
564 |
+
x = self._pos_embed(x)
|
565 |
+
x = self.patch_drop(x)
|
566 |
+
x = self.norm_pre(x)
|
567 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
568 |
+
x = checkpoint_seq(self.blocks, x)
|
569 |
+
else:
|
570 |
+
x = self.blocks(x)
|
571 |
+
x = self.norm(x)
|
572 |
+
return x
|
573 |
+
|
574 |
+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
575 |
+
if self.attn_pool is not None:
|
576 |
+
x = self.attn_pool(x)
|
577 |
+
elif self.global_pool == "avg":
|
578 |
+
x = x[:, self.num_prefix_tokens :].mean(dim=1)
|
579 |
+
elif self.global_pool:
|
580 |
+
x = x[:, 0] # class token
|
581 |
+
x = self.fc_norm(x)
|
582 |
+
x = self.head_drop(x)
|
583 |
+
return x if pre_logits else self.head(x)
|
584 |
+
|
585 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
586 |
+
x = self.forward_features(x)
|
587 |
+
if not self.ignore_head:
|
588 |
+
x = self.forward_head(x)
|
589 |
+
return x
|
590 |
+
|
591 |
+
|
592 |
+
@dataclass
|
593 |
+
class SigLIPVisionCfg:
|
594 |
+
width: int = 1152
|
595 |
+
layers: Union[Tuple[int, int, int, int], int] = 27
|
596 |
+
heads: int = 16
|
597 |
+
patch_size: int = 14
|
598 |
+
image_size: Union[Tuple[int, int], int] = 336
|
599 |
+
global_pool: str = "map"
|
600 |
+
mlp_ratio: float = 3.7362
|
601 |
+
class_token: bool = False
|
602 |
+
num_classes: int = 0
|
603 |
+
use_checkpoint: bool = False
|
604 |
+
|
605 |
+
|
606 |
+
SigLIP_MODEL_CONFIG = {
|
607 |
+
"siglip_so400m_patch14_384": {
|
608 |
+
"image_size": 336,
|
609 |
+
"patch_size": 14,
|
610 |
+
"width": 1152,
|
611 |
+
"layers": 27,
|
612 |
+
"heads": 16,
|
613 |
+
"mlp_ratio": 3.7362,
|
614 |
+
"global_pool": "map",
|
615 |
+
"use_checkpoint": False,
|
616 |
+
},
|
617 |
+
"siglip_so400m_patch14_224": {
|
618 |
+
"image_size": 224,
|
619 |
+
"patch_size": 14,
|
620 |
+
"width": 1152,
|
621 |
+
"layers": 27,
|
622 |
+
"heads": 16,
|
623 |
+
"mlp_ratio": 3.7362,
|
624 |
+
"global_pool": "map",
|
625 |
+
"use_checkpoint": False,
|
626 |
+
},
|
627 |
+
"siglip_large_patch16_384": {
|
628 |
+
"image_size": 384,
|
629 |
+
"patch_size": 16,
|
630 |
+
"width": 1024,
|
631 |
+
"layers": 24,
|
632 |
+
"heads": 16,
|
633 |
+
"mlp_ratio": 4,
|
634 |
+
"global_pool": "map",
|
635 |
+
"use_checkpoint": False,
|
636 |
+
},
|
637 |
+
}
|
638 |
+
|
639 |
+
|
640 |
+
def create_siglip_vit(
|
641 |
+
model_name: str = "siglip_so400m_patch14_384",
|
642 |
+
image_size: int = 384,
|
643 |
+
select_layer: int = -1,
|
644 |
+
ckpt_path: str = "",
|
645 |
+
**kwargs,
|
646 |
+
):
|
647 |
+
assert (
|
648 |
+
model_name in SigLIP_MODEL_CONFIG.keys()
|
649 |
+
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
|
650 |
+
|
651 |
+
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
|
652 |
+
|
653 |
+
if select_layer <= 0:
|
654 |
+
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
|
655 |
+
else:
|
656 |
+
layers = min(vision_cfg.layers, select_layer)
|
657 |
+
|
658 |
+
model = VisionTransformer(
|
659 |
+
img_size=image_size,
|
660 |
+
patch_size=vision_cfg.patch_size,
|
661 |
+
embed_dim=vision_cfg.width,
|
662 |
+
depth=layers,
|
663 |
+
num_heads=vision_cfg.heads,
|
664 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
665 |
+
class_token=vision_cfg.class_token,
|
666 |
+
global_pool=vision_cfg.global_pool,
|
667 |
+
ignore_head=kwargs.get("ignore_head", True),
|
668 |
+
weight_init=kwargs.get("weight_init", "skip"),
|
669 |
+
num_classes=0,
|
670 |
+
)
|
671 |
+
|
672 |
+
if ckpt_path:
|
673 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
674 |
+
|
675 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
676 |
+
print(
|
677 |
+
f"SigLIP-ViT restores from {ckpt_path},\n"
|
678 |
+
f"\tincompatible_keys:', {incompatible_keys}."
|
679 |
+
)
|
680 |
+
|
681 |
+
return model
|
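As a point of reference for the create_siglip_vit factory above, a minimal usage sketch (not part of the commit; the empty ckpt_path skips weight loading, and the 336-pixel input gives the 24 x 24 patch grid of the siglip_so400m_patch14_384 config entry):

# Hedged usage sketch: build the SigLIP vision tower and encode a dummy batch.
import torch

from deepseek_vl.models.siglip_vit import create_siglip_vit

vision_tower = create_siglip_vit(
    model_name="siglip_so400m_patch14_384",
    image_size=336,   # 336 / patch_size 14 -> 24 x 24 = 576 patches
    ckpt_path="",     # empty string: skip loading a checkpoint
)

dummy = torch.randn(1, 3, 336, 336)
with torch.no_grad():
    # ignore_head defaults to True here, so forward() returns the patch features
    feats = vision_tower(dummy)   # expected shape: (1, 576, 1152)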
deepseek_vl/utils/__init__.py
ADDED
@@ -0,0 +1,18 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
deepseek_vl/utils/conversation.py
ADDED
@@ -0,0 +1,348 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
"""
|
21 |
+
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
22 |
+
"""
|
23 |
+
|
24 |
+
import dataclasses
|
25 |
+
from enum import IntEnum, auto
|
26 |
+
from typing import Dict, List
|
27 |
+
|
28 |
+
|
29 |
+
class SeparatorStyle(IntEnum):
|
30 |
+
"""Separator styles."""
|
31 |
+
|
32 |
+
ADD_COLON_SINGLE = auto()
|
33 |
+
ADD_COLON_TWO = auto()
|
34 |
+
ADD_COLON_SPACE_SINGLE = auto()
|
35 |
+
NO_COLON_SINGLE = auto()
|
36 |
+
NO_COLON_TWO = auto()
|
37 |
+
ADD_NEW_LINE_SINGLE = auto()
|
38 |
+
LLAMA2 = auto()
|
39 |
+
CHATGLM = auto()
|
40 |
+
CHATML = auto()
|
41 |
+
CHATINTERN = auto()
|
42 |
+
DOLLY = auto()
|
43 |
+
RWKV = auto()
|
44 |
+
PHOENIX = auto()
|
45 |
+
ROBIN = auto()
|
46 |
+
DeepSeek = auto()
|
47 |
+
PLAIN = auto()
|
48 |
+
ALIGNMENT = auto()
|
49 |
+
|
50 |
+
|
51 |
+
@dataclasses.dataclass
|
52 |
+
class Conversation:
|
53 |
+
"""A class that manages prompt templates and keeps all conversation history."""
|
54 |
+
|
55 |
+
# The name of this template
|
56 |
+
name: str
|
57 |
+
# The template of the system prompt
|
58 |
+
system_template: str = "{system_message}"
|
59 |
+
# The system message
|
60 |
+
system_message: str = ""
|
61 |
+
# The names of two roles
|
62 |
+
roles: List[str] = (("USER", "ASSISTANT"),)
|
63 |
+
# All messages. Each item is (role, message).
|
64 |
+
messages: List[List[str]] = ()
|
65 |
+
# The number of few shot examples
|
66 |
+
offset: int = 0
|
67 |
+
# The separator style and configurations
|
68 |
+
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
|
69 |
+
sep: str = "\n"
|
70 |
+
sep2: str = None
|
71 |
+
# Stop criteria (the default one is EOS token)
|
72 |
+
stop_str: str = None
|
73 |
+
# Stops generation if meeting any token in this list
|
74 |
+
stop_token_ids: List[int] = None
|
75 |
+
|
76 |
+
def get_prompt(self) -> str:
|
77 |
+
"""Get the prompt for generation."""
|
78 |
+
system_prompt = self.system_template.format(system_message=self.system_message)
|
79 |
+
|
80 |
+
if self.sep_style == SeparatorStyle.DeepSeek:
|
81 |
+
seps = [self.sep, self.sep2]
|
82 |
+
if system_prompt == "" or system_prompt is None:
|
83 |
+
ret = ""
|
84 |
+
else:
|
85 |
+
ret = system_prompt + seps[0]
|
86 |
+
for i, (role, message) in enumerate(self.messages):
|
87 |
+
if message:
|
88 |
+
ret += role + ": " + message + seps[i % 2]
|
89 |
+
else:
|
90 |
+
ret += role + ":"
|
91 |
+
return ret
|
92 |
+
elif self.sep_style == SeparatorStyle.LLAMA2:
|
93 |
+
seps = [self.sep, self.sep2]
|
94 |
+
if self.system_message:
|
95 |
+
ret = system_prompt
|
96 |
+
else:
|
97 |
+
ret = "[INST] "
|
98 |
+
for i, (role, message) in enumerate(self.messages):
|
99 |
+
tag = self.roles[i % 2]
|
100 |
+
if message:
|
101 |
+
if type(message) is tuple: # multimodal message
|
102 |
+
message, _ = message
|
103 |
+
if i == 0:
|
104 |
+
ret += message + " "
|
105 |
+
else:
|
106 |
+
ret += tag + " " + message + seps[i % 2]
|
107 |
+
else:
|
108 |
+
ret += tag
|
109 |
+
return ret
|
110 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
111 |
+
seps = [self.sep, self.sep2]
|
112 |
+
ret = ""
|
113 |
+
for i, (role, message) in enumerate(self.messages):
|
114 |
+
if message:
|
115 |
+
if type(message) is tuple:
|
116 |
+
message, _, _ = message
|
117 |
+
if i % 2 == 0:
|
118 |
+
ret += message + seps[i % 2]
|
119 |
+
else:
|
120 |
+
ret += message + seps[i % 2]
|
121 |
+
else:
|
122 |
+
ret += ""
|
123 |
+
return ret
|
124 |
+
elif self.sep_style == SeparatorStyle.ALIGNMENT:
|
125 |
+
seps = [self.sep, self.sep2]
|
126 |
+
ret = ""
|
127 |
+
for i, (role, message) in enumerate(self.messages):
|
128 |
+
if message:
|
129 |
+
if type(message) is tuple:
|
130 |
+
message, _, _ = message
|
131 |
+
if i % 2 == 0:
|
132 |
+
ret += "<image>\n" + seps[i % 2]
|
133 |
+
else:
|
134 |
+
ret += message + seps[i % 2]
|
135 |
+
else:
|
136 |
+
ret += ""
|
137 |
+
return ret
|
138 |
+
else:
|
139 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
140 |
+
|
141 |
+
def get_prompt_for_current_round(self, content=None):
|
142 |
+
"""Get current round formatted question prompt during sft training"""
|
143 |
+
if self.sep_style == SeparatorStyle.PLAIN:
|
144 |
+
formatted_question = "<image>\n"
|
145 |
+
elif self.sep_style == SeparatorStyle.DeepSeek:
|
146 |
+
formatted_question = (
|
147 |
+
f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
|
148 |
+
)
|
149 |
+
else:
|
150 |
+
raise ValueError(f"Unsupported sep_style: {self.sep_style}")
|
151 |
+
return formatted_question
|
152 |
+
|
153 |
+
def set_system_message(self, system_message: str):
|
154 |
+
"""Set the system message."""
|
155 |
+
self.system_message = system_message
|
156 |
+
|
157 |
+
def append_message(self, role: str, message: str):
|
158 |
+
"""Append a new message."""
|
159 |
+
self.messages.append([role, message])
|
160 |
+
|
161 |
+
def reset_message(self):
|
162 |
+
"""Reset a new message."""
|
163 |
+
self.messages = []
|
164 |
+
|
165 |
+
def update_last_message(self, message: str):
|
166 |
+
"""Update the last output.
|
167 |
+
|
168 |
+
The last message is typically set to be None when constructing the prompt,
|
169 |
+
so we need to update it in-place after getting the response from a model.
|
170 |
+
"""
|
171 |
+
self.messages[-1][1] = message
|
172 |
+
|
173 |
+
def to_gradio_chatbot(self):
|
174 |
+
"""Convert the conversation to gradio chatbot format."""
|
175 |
+
ret = []
|
176 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
177 |
+
if i % 2 == 0:
|
178 |
+
ret.append([msg, None])
|
179 |
+
else:
|
180 |
+
ret[-1][-1] = msg
|
181 |
+
return ret
|
182 |
+
|
183 |
+
def to_openai_api_messages(self):
|
184 |
+
"""Convert the conversation to OpenAI chat completion format."""
|
185 |
+
system_prompt = self.system_template.format(system_message=self.system_message)
|
186 |
+
ret = [{"role": "system", "content": system_prompt}]
|
187 |
+
|
188 |
+
for i, (_, msg) in enumerate(self.messages[self.offset :]):
|
189 |
+
if i % 2 == 0:
|
190 |
+
ret.append({"role": "user", "content": msg})
|
191 |
+
else:
|
192 |
+
if msg is not None:
|
193 |
+
ret.append({"role": "assistant", "content": msg})
|
194 |
+
return ret
|
195 |
+
|
196 |
+
def copy(self):
|
197 |
+
return Conversation(
|
198 |
+
name=self.name,
|
199 |
+
system_template=self.system_template,
|
200 |
+
system_message=self.system_message,
|
201 |
+
roles=self.roles,
|
202 |
+
messages=[[x, y] for x, y in self.messages],
|
203 |
+
offset=self.offset,
|
204 |
+
sep_style=self.sep_style,
|
205 |
+
sep=self.sep,
|
206 |
+
sep2=self.sep2,
|
207 |
+
stop_str=self.stop_str,
|
208 |
+
stop_token_ids=self.stop_token_ids,
|
209 |
+
)
|
210 |
+
|
211 |
+
def dict(self):
|
212 |
+
return {
|
213 |
+
"template_name": self.name,
|
214 |
+
"system_message": self.system_message,
|
215 |
+
"roles": self.roles,
|
216 |
+
"messages": self.messages,
|
217 |
+
"offset": self.offset,
|
218 |
+
}
|
219 |
+
|
220 |
+
|
221 |
+
# A global registry for all conversation templates
|
222 |
+
conv_templates: Dict[str, Conversation] = {}
|
223 |
+
|
224 |
+
|
225 |
+
def register_conv_template(template: Conversation, override: bool = False):
|
226 |
+
"""Register a new conversation template."""
|
227 |
+
if not override:
|
228 |
+
assert (
|
229 |
+
template.name not in conv_templates
|
230 |
+
), f"{template.name} has been registered."
|
231 |
+
|
232 |
+
conv_templates[template.name] = template
|
233 |
+
|
234 |
+
|
235 |
+
def get_conv_template(name: str) -> Conversation:
|
236 |
+
"""Get a conversation template."""
|
237 |
+
return conv_templates[name].copy()
|
238 |
+
|
239 |
+
|
240 |
+
# llava_llama2 template
|
241 |
+
register_conv_template(
|
242 |
+
Conversation(
|
243 |
+
name="llava_llama2",
|
244 |
+
system_message="You are a helpful language and vision assistant. "
|
245 |
+
"You are able to understand the visual content that the user provides, "
|
246 |
+
"and assist the user with a variety of tasks using natural language.",
|
247 |
+
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
|
248 |
+
roles=("[INST]", "[/INST]"),
|
249 |
+
messages=(),
|
250 |
+
offset=0,
|
251 |
+
sep_style=SeparatorStyle.LLAMA2,
|
252 |
+
sep=" ",
|
253 |
+
sep2=" </s><s>",
|
254 |
+
stop_token_ids=[2],
|
255 |
+
)
|
256 |
+
)
|
257 |
+
|
258 |
+
# llama2 template
|
259 |
+
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
|
260 |
+
register_conv_template(
|
261 |
+
Conversation(
|
262 |
+
name="llama-2",
|
263 |
+
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
|
264 |
+
roles=("[INST]", "[/INST]"),
|
265 |
+
messages=(),
|
266 |
+
offset=0,
|
267 |
+
sep_style=SeparatorStyle.LLAMA2,
|
268 |
+
sep=" ",
|
269 |
+
sep2=" </s><s>",
|
270 |
+
stop_token_ids=[2],
|
271 |
+
)
|
272 |
+
)
|
273 |
+
|
274 |
+
|
275 |
+
# deepseek template
|
276 |
+
register_conv_template(
|
277 |
+
Conversation(
|
278 |
+
name="deepseek",
|
279 |
+
system_template="{system_message}",
|
280 |
+
# system_message="You are a helpful assistant. Please answer truthfully and write out your "
|
281 |
+
# "thinking step by step to be sure you get the right answer.",
|
282 |
+
system_message="",
|
283 |
+
roles=("User", "Assistant"),
|
284 |
+
messages=(),
|
285 |
+
offset=0,
|
286 |
+
sep_style=SeparatorStyle.DeepSeek,
|
287 |
+
sep="\n\n",
|
288 |
+
sep2="<|end▁of▁sentence|>",
|
289 |
+
stop_token_ids=[100001],
|
290 |
+
stop_str=["User:", "<|end▁of▁sentence|>"],
|
291 |
+
)
|
292 |
+
)
|
293 |
+
|
294 |
+
register_conv_template(
|
295 |
+
Conversation(
|
296 |
+
name="plain",
|
297 |
+
system_template="",
|
298 |
+
system_message="",
|
299 |
+
roles=("", ""),
|
300 |
+
messages=(),
|
301 |
+
offset=0,
|
302 |
+
sep_style=SeparatorStyle.PLAIN,
|
303 |
+
sep="",
|
304 |
+
sep2="",
|
305 |
+
stop_token_ids=[2],
|
306 |
+
stop_str=["</s>"],
|
307 |
+
)
|
308 |
+
)
|
309 |
+
|
310 |
+
|
311 |
+
register_conv_template(
|
312 |
+
Conversation(
|
313 |
+
name="alignment",
|
314 |
+
system_template="",
|
315 |
+
system_message="",
|
316 |
+
roles=("", ""),
|
317 |
+
messages=(),
|
318 |
+
offset=0,
|
319 |
+
sep_style=SeparatorStyle.ALIGNMENT,
|
320 |
+
sep="",
|
321 |
+
sep2="",
|
322 |
+
stop_token_ids=[2],
|
323 |
+
stop_str=["</s>"],
|
324 |
+
)
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
if __name__ == "__main__":
|
329 |
+
# print("Llama-2 template:")
|
330 |
+
# conv = get_conv_template("llama-2")
|
331 |
+
# conv.set_system_message("You are a helpful, respectful and honest assistant.")
|
332 |
+
# conv.append_message(conv.roles[0], "Hello!")
|
333 |
+
# conv.append_message(conv.roles[1], "Hi!")
|
334 |
+
# conv.append_message(conv.roles[0], "How are you?")
|
335 |
+
# conv.append_message(conv.roles[1], None)
|
336 |
+
# print(conv.get_prompt())
|
337 |
+
|
338 |
+
# print("\n")
|
339 |
+
|
340 |
+
print("deepseek template:")
|
341 |
+
conv = get_conv_template("deepseek")
|
342 |
+
conv.append_message(conv.roles[0], "Hello!")
|
343 |
+
conv.append_message(conv.roles[1], "Hi! This is Tony.")
|
344 |
+
conv.append_message(conv.roles[0], "Who are you?")
|
345 |
+
conv.append_message(conv.roles[1], "I am a helpful assistant.")
|
346 |
+
conv.append_message(conv.roles[0], "How are you?")
|
347 |
+
conv.append_message(conv.roles[1], None)
|
348 |
+
print(conv.get_prompt())
|
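For orientation, the "deepseek" template registered above joins turns with sep="\n\n" and terminates completed assistant turns with sep2; an illustrative sketch of the rendered prompt (mirrors the __main__ demo, not part of the commit):

# Illustrative sketch: render one round with the "deepseek" template.
from deepseek_vl.utils.conversation import get_conv_template

conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "<image_placeholder>Describe this image.")
conv.append_message(conv.roles[1], None)
print(repr(conv.get_prompt()))
# -> 'User: <image_placeholder>Describe this image.\n\nAssistant:'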
deepseek_vl/utils/io.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
import json
|
21 |
+
from typing import Dict, List
|
22 |
+
|
23 |
+
import PIL.Image
|
24 |
+
import torch
|
25 |
+
from transformers import AutoModelForCausalLM
|
26 |
+
|
27 |
+
from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
|
28 |
+
|
29 |
+
|
30 |
+
def load_pretrained_model(model_path: str):
|
31 |
+
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
|
32 |
+
tokenizer = vl_chat_processor.tokenizer
|
33 |
+
|
34 |
+
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
|
35 |
+
model_path, trust_remote_code=True
|
36 |
+
)
|
37 |
+
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
|
38 |
+
|
39 |
+
return tokenizer, vl_chat_processor, vl_gpt
|
40 |
+
|
41 |
+
|
42 |
+
def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
|
43 |
+
"""
|
44 |
+
|
45 |
+
Args:
|
46 |
+
conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
|
47 |
+
[
|
48 |
+
{
|
49 |
+
"role": "User",
|
50 |
+
"content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
|
51 |
+
"images": ["./examples/table_datasets.png"]
|
52 |
+
},
|
53 |
+
{"role": "Assistant", "content": ""},
|
54 |
+
]
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
pil_images (List[PIL.Image.Image]): the list of PIL images.
|
58 |
+
|
59 |
+
"""
|
60 |
+
|
61 |
+
pil_images = []
|
62 |
+
|
63 |
+
for message in conversations:
|
64 |
+
if "images" not in message:
|
65 |
+
continue
|
66 |
+
|
67 |
+
for image_path in message["images"]:
|
68 |
+
pil_img = PIL.Image.open(image_path)
|
69 |
+
pil_img = pil_img.convert("RGB")
|
70 |
+
pil_images.append(pil_img)
|
71 |
+
|
72 |
+
return pil_images
|
73 |
+
|
74 |
+
|
75 |
+
def load_json(filepath):
|
76 |
+
with open(filepath, "r") as f:
|
77 |
+
data = json.load(f)
|
78 |
+
return data
|
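A short sketch of load_pil_images on the conversation format shown in its docstring (the chart image is one of the bundled examples; not part of the commit):

# Hedged sketch: collect the RGB PIL images referenced by a conversation.
from deepseek_vl.utils.io import load_pil_images

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>\nWhat does this chart show?",
        "images": ["./examples/chart.png"],
    },
    {"role": "Assistant", "content": ""},
]

pil_images = load_pil_images(conversation)  # -> [<PIL.Image.Image mode=RGB ...>]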
examples/app.png
ADDED
examples/chart.png
ADDED
examples/mirror.png
ADDED
examples/pipeline.png
ADDED
examples/puzzle.png
ADDED
examples/rap.jpeg
ADDED
inference.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
from threading import Thread
|
21 |
+
from typing import List
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import transformers
|
25 |
+
from transformers import (
|
26 |
+
AutoModelForCausalLM,
|
27 |
+
StoppingCriteria,
|
28 |
+
StoppingCriteriaList,
|
29 |
+
TextIteratorStreamer,
|
30 |
+
)
|
31 |
+
|
32 |
+
from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
|
33 |
+
from deepseek_vl.utils.conversation import Conversation
|
34 |
+
|
35 |
+
|
36 |
+
def load_model(model_path):
|
37 |
+
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
|
38 |
+
tokenizer = vl_chat_processor.tokenizer
|
39 |
+
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
|
40 |
+
model_path, trust_remote_code=True
|
41 |
+
)
|
42 |
+
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
|
43 |
+
return tokenizer, vl_gpt, vl_chat_processor
|
44 |
+
|
45 |
+
|
46 |
+
def convert_conversation_to_prompts(conversation: Conversation):
|
47 |
+
prompts = []
|
48 |
+
messages = conversation.messages
|
49 |
+
|
50 |
+
for i in range(0, len(messages), 2):
|
51 |
+
prompt = {
|
52 |
+
"role": messages[i][0],
|
53 |
+
"content": (
|
54 |
+
messages[i][1][0]
|
55 |
+
if isinstance(messages[i][1], tuple)
|
56 |
+
else messages[i][1]
|
57 |
+
),
|
58 |
+
"images": [messages[i][1][1]] if isinstance(messages[i][1], tuple) else [],
|
59 |
+
}
|
60 |
+
response = {"role": messages[i + 1][0], "content": messages[i + 1][1]}
|
61 |
+
prompts.extend([prompt, response])
|
62 |
+
|
63 |
+
return prompts
|
64 |
+
|
65 |
+
|
66 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
67 |
+
def __init__(self, stops=[], encounters=1):
|
68 |
+
super().__init__()
|
69 |
+
self.stops = [stop.to("cuda") for stop in stops]
|
70 |
+
|
71 |
+
def __call__(
|
72 |
+
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
73 |
+
):
|
74 |
+
for stop in self.stops:
|
75 |
+
if input_ids.shape[-1] < len(stop):
|
76 |
+
continue
|
77 |
+
if torch.all((stop == input_ids[0][-len(stop) :])).item():
|
78 |
+
return True
|
79 |
+
|
80 |
+
return False
|
81 |
+
|
82 |
+
|
83 |
+
@torch.inference_mode()
|
84 |
+
def deepseek_generate(
|
85 |
+
prompts: list,
|
86 |
+
vl_gpt: torch.nn.Module,
|
87 |
+
vl_chat_processor,
|
88 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
89 |
+
stop_words: list,
|
90 |
+
max_length: int = 256,
|
91 |
+
temperature: float = 1.0,
|
92 |
+
top_p: float = 1.0,
|
93 |
+
repetition_penalty=1.1,
|
94 |
+
):
|
95 |
+
prompts = prompts
|
96 |
+
pil_images = list()
|
97 |
+
for message in prompts:
|
98 |
+
if "images" not in message:
|
99 |
+
continue
|
100 |
+
for pil_img in message["images"]:
|
101 |
+
pil_images.append(pil_img)
|
102 |
+
|
103 |
+
prepare_inputs = vl_chat_processor(
|
104 |
+
conversations=prompts, images=pil_images, force_batchify=True
|
105 |
+
).to(vl_gpt.device)
|
106 |
+
|
107 |
+
return generate(
|
108 |
+
vl_gpt,
|
109 |
+
tokenizer,
|
110 |
+
prepare_inputs,
|
111 |
+
max_length,
|
112 |
+
temperature,
|
113 |
+
repetition_penalty,
|
114 |
+
top_p,
|
115 |
+
stop_words,
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
@torch.inference_mode()
|
120 |
+
def generate(
|
121 |
+
vl_gpt,
|
122 |
+
tokenizer,
|
123 |
+
prepare_inputs,
|
124 |
+
max_gen_len: int = 256,
|
125 |
+
temperature: float = 0,
|
126 |
+
repetition_penalty=1.1,
|
127 |
+
top_p: float = 0.95,
|
128 |
+
stop_words: List[str] = [],
|
129 |
+
):
|
130 |
+
"""Stream the text output from the multimodality model with prompt and image inputs."""
|
131 |
+
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
|
132 |
+
|
133 |
+
streamer = TextIteratorStreamer(tokenizer)
|
134 |
+
|
135 |
+
stop_words_ids = [
|
136 |
+
torch.tensor(tokenizer.encode(stop_word)) for stop_word in stop_words
|
137 |
+
]
|
138 |
+
stopping_criteria = StoppingCriteriaList(
|
139 |
+
[StoppingCriteriaSub(stops=stop_words_ids)]
|
140 |
+
)
|
141 |
+
|
142 |
+
generation_config = dict(
|
143 |
+
inputs_embeds=inputs_embeds,
|
144 |
+
attention_mask=prepare_inputs.attention_mask,
|
145 |
+
pad_token_id=tokenizer.eos_token_id,
|
146 |
+
bos_token_id=tokenizer.bos_token_id,
|
147 |
+
eos_token_id=tokenizer.eos_token_id,
|
148 |
+
max_new_tokens=max_gen_len,
|
149 |
+
do_sample=True,
|
150 |
+
use_cache=True,
|
151 |
+
streamer=streamer,
|
152 |
+
stopping_criteria=stopping_criteria,
|
153 |
+
)
|
154 |
+
|
155 |
+
if temperature > 0:
|
156 |
+
generation_config.update(
|
157 |
+
{
|
158 |
+
"do_sample": True,
|
159 |
+
"top_p": top_p,
|
160 |
+
"temperature": temperature,
|
161 |
+
"repetition_penalty": repetition_penalty,
|
162 |
+
}
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
generation_config["do_sample"] = False
|
166 |
+
|
167 |
+
thread = Thread(target=vl_gpt.language_model.generate, kwargs=generation_config)
|
168 |
+
thread.start()
|
169 |
+
|
170 |
+
yield from streamer
|
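inference.py defines the streaming pieces but no driver. A minimal driver sketch under stated assumptions (the deepseek-ai/deepseek-vl-7b-chat model id and the example image path are placeholders, and load_model requires a CUDA device):

# Hedged end-to-end sketch: stream one answer about an example image.
import PIL.Image

from deepseek_vl.utils.conversation import get_conv_template
from inference import convert_conversation_to_prompts, deepseek_generate, load_model

tokenizer, vl_gpt, vl_chat_processor = load_model("deepseek-ai/deepseek-vl-7b-chat")

conv = get_conv_template("deepseek")
image = PIL.Image.open("./examples/rap.jpeg").convert("RGB")
conv.append_message(conv.roles[0], ("<image_placeholder>What is happening here?", image))
conv.append_message(conv.roles[1], None)

answer = ""
for chunk in deepseek_generate(
    prompts=convert_conversation_to_prompts(conv),
    vl_gpt=vl_gpt,
    vl_chat_processor=vl_chat_processor,
    tokenizer=tokenizer,
    stop_words=conv.stop_str,
    max_length=256,
    temperature=0.7,
    top_p=0.95,
):
    answer += chunk
print(answer)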
requirements.txt
ADDED
@@ -0,0 +1,19 @@
1 |
+
torch==2.0.1
|
2 |
+
transformers>=4.38.2
|
3 |
+
timm>=0.9.16
|
4 |
+
accelerate
|
5 |
+
sentencepiece
|
6 |
+
attrdict
|
7 |
+
einops
|
8 |
+
|
9 |
+
# for gradio demo
|
10 |
+
gradio==3.48.0
|
11 |
+
gradio-client==0.6.1
|
12 |
+
mdtex2html==1.3.0
|
13 |
+
pypinyin==0.50.0
|
14 |
+
tiktoken==0.5.2
|
15 |
+
tqdm==4.64.0
|
16 |
+
colorama==0.4.5
|
17 |
+
Pygments==2.12.0
|
18 |
+
markdown==3.4.1
|
19 |
+
SentencePiece==0.1.96
|