Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -4,234 +4,91 @@ import torch
 from threading import Thread
 import spaces
 import time
-
-
-# For the advanced UI components
-import modelscope_studio.components.antd as antd
-import modelscope_studio.components.antdx as antdx
-import modelscope_studio.components.base as ms
-import modelscope_studio.components.pro as pro
-from modelscope_studio.components.pro.chatbot import (ChatbotBotConfig,
-                                                      ChatbotPromptsConfig,
-                                                      ChatbotUserConfig,
-                                                      ChatbotWelcomeConfig)
-
-# --- 1. Load the Hugging Face Model and Tokenizer ---
 model_name = "sarvamai/sarvam-m"
-print(f"Loading model: {model_name}...")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.float16,
-    device_map="auto"
-)
-print("Model loaded successfully.")
-
-
-# --- 2. Helper and Event Handler Functions ---
-
-def format_history_for_sarvam(history: list) -> list:
-    messages = []
-    if not history:
-        return messages
-    for item in history:
-        role = item.get("role")
-        content = item.get("content")
-        if role == "user":
-            messages.append({"role": "user", "content": content})
-        elif role == "assistant":
-            final_content = ""
-            if isinstance(content, list):
-                for part in content:
-                    if part.get("type") == "text":
-                        final_content = part.get("content", "")
-                        break
-            elif isinstance(content, str):
-                final_content = content
-            if final_content:
-                messages.append({"role": "assistant", "content": final_content})
-    return messages
 
 @spaces.GPU
-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    }
-
-    try:
-        history_messages = format_history_for_sarvam(chatbot_value)
-        prompt_text = tokenizer.apply_chat_template(
-            history_messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
-        )
-        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = dict(
-            input_ids=model_inputs.input_ids,
-            max_new_tokens=8192,
-            do_sample=True,
-            temperature=0.7,
-            streamer=streamer,
-        )
-        thread = Thread(target=model.generate, kwargs=generation_kwargs)
-        thread.start()
-
-        start_time = time.time()
-        message_content = chatbot_value[-1]["content"]
-        message_content.append({
-            "copyable": False, "editable": False, "type": "tool", "content": "",
-            "options": {"title": "Thinking...", "status": "pending"}
-        })
-        message_content.append({"type": "text", "content": ""})
-        chatbot_value[-1]["loading"] = False
-        full_response = ""
-        thinking_content = ""
-        main_content = ""
-        thinking_done = False
-
-        for new_text in streamer:
-            full_response += new_text
-            if not thinking_done and "</think>" in full_response:
-                thinking_done = True
-                try:
-                    parts = full_response.split("</think>", 1)
-                    thinking_content = parts[0].split("<think>", 1)[1]
-                    main_content = parts[1]
-                    thought_cost_time = "{:.2f}".format(time.time() - start_time)
-                    message_content[0]["content"] = thinking_content.strip()
-                    message_content[0]["options"]["title"] = f"End of Thought ({thought_cost_time}s)"
-                    message_content[0]["options"]["status"] = "done"
-                except IndexError:
-                    main_content = full_response
-            elif not thinking_done:
-                if full_response.lstrip().startswith("<think>"):
-                    thinking_content = full_response.lstrip()[len("<think>"):]
-                    message_content[0]["content"] = thinking_content.strip()
-            else:
-                main_content = full_response.split("</think>", 1)[1]
-
-            message_content[1]["content"] = main_content.lstrip("\n")
-
-            # <-- 3. APPLY FIX HERE
-            yield {chatbot: gr.update(value=copy.deepcopy(chatbot_value))}
-
-        chatbot_value[-1]["footer"] = "{:.2f}s".format(time.time() - start_time)
-        chatbot_value[-1]["status"] = "done"
-
-        # <-- 4. APPLY FIX HERE
-        yield {
-            clear_btn: gr.update(disabled=False),
-            sender: gr.update(loading=False),
-            chatbot: gr.update(value=copy.deepcopy(chatbot_value)),
-        }
-
-    except Exception as e:
-        print(f"An error occurred: {e}")
-        chatbot_value[-1]["loading"] = False
-        chatbot_value[-1]["status"] = "done"
-        chatbot_value[-1]["content"] = f"Failed to respond due to an error: {e}"
-
-        # <-- 5. APPLY FIX HERE
-        yield {
-            clear_btn: gr.update(disabled=False),
-            sender: gr.update(loading=False),
-            chatbot: gr.update(value=copy.deepcopy(chatbot_value)),
-        }
-
-def prompt_select(e: gr.EventData):
-    return gr.update(value=e._data["payload"][0]["value"]["description"])
-
-def clear():
-    return gr.update(value=None)
-
-def retry(chatbot_value: list, e: gr.EventData):
-    index = e._data["payload"][0]["index"]
-    chatbot_value = chatbot_value[:index-1]
-    yield {
-        sender: gr.update(loading=True),
-        chatbot: gr.update(value=chatbot_value),
-        clear_btn: gr.update(disabled=True)
-    }
-    for chunk in submit(None, chatbot_value):
-        yield chunk
-
-def cancel(chatbot_value: list):
-    if chatbot_value and chatbot_value[-1].get("status") == "pending":
-        chatbot_value[-1]["loading"] = False
-        chatbot_value[-1]["status"] = "done"
-        chatbot_value[-1]["footer"] = "Chat completion paused"
-    return {
-        chatbot: gr.update(value=chatbot_value),
-        sender: gr.update(loading=False),
-        clear_btn: gr.update(disabled=False)
-    }
-
-# --- 3. Build the Gradio UI ---
-with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo, ms.Application(), antdx.XProvider():
-    with antd.Flex(vertical=True, gap="middle"):
-        chatbot = pro.Chatbot(
-            height=650,
-            welcome_config=ChatbotWelcomeConfig(
-                variant="borderless",
-                icon="https://cdn-avatars.huggingface.co/v1/production/uploads/60270a7c32856987162c641a/umd13GCWVijwTDGZzw3q-.png",
-                title=f"Hello, I'm {model_name.split('/')[-1]}",
-                description="I can show you my thinking process. How can I help you today?",
-                prompts=ChatbotPromptsConfig(
-                    items=[
-                        {"label": "Explain a concept", "children": [{"description": "Explain what a Large Language Model is in simple terms."}]},
-                        {"label": "Help me write", "children": [{"description": "Write a short, futuristic story about AI companions."}]},
-                        {"label": "Creative Ideas", "children": [{"description": "Give me three creative names for a new coffee shop."}]},
-                        {"label": "Code generation", "children": [{"description": "Write a python function to find the factorial of a number."}]}
-                    ]
-                )
-            ),
-            user_config=ChatbotUserConfig(avatar="https://api.dicebear.com/7.x/miniavs/svg?seed=gradio"),
-            bot_config=ChatbotBotConfig(
-                header=model_name,
-                avatar="https://cdn-avatars.huggingface.co/v1/production/uploads/60270a7c32856987162c641a/umd13GCWVijwTDGZzw3q-.png",
-                actions=["copy", "retry"]
-            ),
-        )
-        with antdx.Sender() as sender:
-            with ms.Slot("prefix"):
-                with antd.Button(value=None, color="default", variant="text") as clear_btn:
-                    with ms.Slot("icon"):
-                        antd.Icon("ClearOutlined")
-
-    clear_btn.click(fn=clear, outputs=[chatbot])
-    submit_event = sender.submit(
-        fn=submit,
-        inputs=[sender, chatbot],
-        outputs=[sender, chatbot, clear_btn]
-    )
-    sender.cancel(
-        fn=cancel,
-        inputs=[chatbot],
-        outputs=[chatbot, sender, clear_btn],
-        cancels=[submit_event],
-        queue=False
-    )
-    chatbot.retry(
-        fn=retry,
-        inputs=[chatbot],
-        outputs=[sender, chatbot, clear_btn]
-    )
-    chatbot.welcome_prompt_select(
-        fn=prompt_select,
-        outputs=[sender]
     )
 
 if __name__ == "__main__":
-    demo.

 from threading import Thread
 import spaces
 import time
+
+# Load the model and tokenizer
 model_name = "sarvamai/sarvam-m"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
 
 @spaces.GPU
+def generate_response(prompt, chat_history):
+    messages = [{"role": "user", "content": prompt}]
+    text = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=True)
+
+    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+    # Use TextIteratorStreamer for streaming
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Conduct text generation with streaming
+    generation_kwargs = dict(
+        input_ids=model_inputs.input_ids,
+        max_new_tokens=8192,
+        do_sample=True,
+        temperature=0.7,
+        streamer=streamer,
     )
 
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    # Initialize variables to capture reasoning content and main content
+    reasoning_content = ""
+    content = ""
+    start_time = time.time()
+
+    # First yield: Show thinking has started
+    yield chat_history + [(None, "Thinking...")], ""
+
+    for new_text in streamer:
+        if "</think>" in new_text:
+            parts = new_text.split("</think>")
+            reasoning_content = parts[0].rstrip("\n")
+            content = parts[-1].lstrip("\n").rstrip("</s>")
+
+            # Calculate thinking time
+            thinking_time = time.time() - start_time
+
+            # Yield the thinking process
+            yield chat_history + [
+                (None, f"Thinking..."),
+                (None, f"Thinking completed. Thought for {thinking_time:.1f} seconds."),
+                (None, f"Thought process:\n{reasoning_content}")
+            ], ""
+        else:
+            content += new_text
+            # Yield the content as it's being generated
+            yield chat_history + [
+                (None, f"Thinking..."),
+                (None, f"Thinking completed. Thought for {time.time() - start_time:.1f} seconds."),
+                (None, f"Thought process:\n{reasoning_content}"),
+                (None, content)
+            ], ""
+
+    # Final yield with complete response
+    yield chat_history + [
+        (None, f"Thinking..."),
+        (None, f"Thinking completed. Thought for {time.time() - start_time:.1f} seconds."),
+        (None, f"Thought process:\n{reasoning_content}"),
+        (prompt, f"{reasoning_content}\n{content}" if reasoning_content else content)
+    ], ""
+
+# Create the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Sarvam M Demo")
+    chatbot = gr.Chatbot(show_copy_button=True)
+    msg = gr.Textbox(label="Your Message")
+
+    def respond(message, chat_history):
+        # Start with the user message
+        chat_history.append((message, None))
+        yield chat_history, ""
+
+        # Then stream the assistant's response
+        for updated_history, _ in generate_response(message, chat_history):
+            yield updated_history, ""
+
+    msg.submit(respond, [msg, chatbot], [chatbot, msg])
+
 if __name__ == "__main__":
+    demo.launch()
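
For reference, the pattern both versions of app.py build on is the same: a background call to model.generate() feeds a TextIteratorStreamer, and the streamed text is split at "</think>" to separate the reasoning trace from the visible answer. Below is a minimal, standalone sketch of that pattern; it is not part of the commit, the helper name stream_answer and the smaller max_new_tokens are illustrative choices, and it assumes the sarvam-m chat template accepts the enable_thinking flag exactly as it is used above.

# Illustrative sketch only (not from the commit): stream tokens on a worker thread
# and split the reasoning trace from the final answer.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "sarvamai/sarvam-m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

def stream_answer(prompt: str):
    """Yield (reasoning_so_far, answer_so_far) while the model generates."""
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # model.generate() blocks, so it runs on a worker thread while this thread
    # consumes the streamer; the iterator ends when generation finishes.
    Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=1024,  # shorter than the Space's 8192, for illustration
            do_sample=True,
            temperature=0.7,
            streamer=streamer,
        ),
    ).start()

    full = ""
    for chunk in streamer:
        full += chunk
        # Text before "</think>" is the reasoning trace; text after it is the answer.
        reasoning, _, answer = full.partition("</think>")
        yield reasoning.replace("<think>", "").strip(), answer.lstrip("\n")

# Usage: consume the generator and keep only the last (complete) pair.
# *_, (reasoning, answer) = stream_answer("Why is the sky blue?")
# print(answer)

The removed version routes these chunks into a modelscope_studio pro.Chatbot message dict, while the added version rebuilds tuple-style gr.Chatbot history on every chunk; the generation side is identical in both.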