KingNish commited on
Commit
3d1de8f
·
verified ·
1 Parent(s): fcc1fac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -224
app.py CHANGED
@@ -4,234 +4,91 @@ import torch
4
  from threading import Thread
5
  import spaces
6
  import time
7
- import copy
8
-
9
- # For the advanced UI components
10
- import modelscope_studio.components.antd as antd
11
- import modelscope_studio.components.antdx as antdx
12
- import modelscope_studio.components.base as ms
13
- import modelscope_studio.components.pro as pro
14
- from modelscope_studio.components.pro.chatbot import (ChatbotBotConfig,
15
- ChatbotPromptsConfig,
16
- ChatbotUserConfig,
17
- ChatbotWelcomeConfig)
18
-
19
- # --- 1. Load the Hugging Face Model and Tokenizer ---
20
  model_name = "sarvamai/sarvam-m"
21
- print(f"Loading model: {model_name}...")
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
- model = AutoModelForCausalLM.from_pretrained(
24
- model_name,
25
- torch_dtype=torch.float16,
26
- device_map="auto"
27
- )
28
- print("Model loaded successfully.")
29
-
30
-
31
- # --- 2. Helper and Event Handler Functions ---
32
-
33
- def format_history_for_sarvam(history: list) -> list:
34
- messages = []
35
- if not history:
36
- return messages
37
- for item in history:
38
- role = item.get("role")
39
- content = item.get("content")
40
- if role == "user":
41
- messages.append({"role": "user", "content": content})
42
- elif role == "assistant":
43
- final_content = ""
44
- if isinstance(content, list):
45
- for part in content:
46
- if part.get("type") == "text":
47
- final_content = part.get("content", "")
48
- break
49
- elif isinstance(content, str):
50
- final_content = content
51
- if final_content:
52
- messages.append({"role": "assistant", "content": final_content})
53
- return messages
54
 
55
  @spaces.GPU
56
- def submit(sender_value: str, chatbot_value: list):
57
- if sender_value:
58
- chatbot_value.append({"role": "user", "content": sender_value})
59
-
60
- chatbot_value.append({
61
- "role": "assistant",
62
- "content": [],
63
- "loading": True,
64
- "status": "pending"
65
- })
66
-
67
- # <-- 2. APPLY FIX HERE
68
- yield {
69
- sender: gr.update(value=None, loading=True),
70
- clear_btn: gr.update(disabled=True),
71
- chatbot: gr.update(value=copy.deepcopy(chatbot_value))
72
- }
73
-
74
- try:
75
- history_messages = format_history_for_sarvam(chatbot_value)
76
- prompt_text = tokenizer.apply_chat_template(
77
- history_messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
78
- )
79
- model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
80
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
81
- generation_kwargs = dict(
82
- input_ids=model_inputs.input_ids,
83
- max_new_tokens=8192,
84
- do_sample=True,
85
- temperature=0.7,
86
- streamer=streamer,
87
- )
88
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
89
- thread.start()
90
-
91
- start_time = time.time()
92
- message_content = chatbot_value[-1]["content"]
93
- message_content.append({
94
- "copyable": False, "editable": False, "type": "tool", "content": "",
95
- "options": {"title": "Thinking...", "status": "pending"}
96
- })
97
- message_content.append({"type": "text", "content": ""})
98
- chatbot_value[-1]["loading"] = False
99
- full_response = ""
100
- thinking_content = ""
101
- main_content = ""
102
- thinking_done = False
103
-
104
- for new_text in streamer:
105
- full_response += new_text
106
- if not thinking_done and "</think>" in full_response:
107
- thinking_done = True
108
- try:
109
- parts = full_response.split("</think>", 1)
110
- thinking_content = parts[0].split("<think>", 1)[1]
111
- main_content = parts[1]
112
- thought_cost_time = "{:.2f}".format(time.time() - start_time)
113
- message_content[0]["content"] = thinking_content.strip()
114
- message_content[0]["options"]["title"] = f"End of Thought ({thought_cost_time}s)"
115
- message_content[0]["options"]["status"] = "done"
116
- except IndexError:
117
- main_content = full_response
118
- elif not thinking_done:
119
- if full_response.lstrip().startswith("<think>"):
120
- thinking_content = full_response.lstrip()[len("<think>"):]
121
- message_content[0]["content"] = thinking_content.strip()
122
- else:
123
- main_content = full_response.split("</think>", 1)[1]
124
-
125
- message_content[1]["content"] = main_content.lstrip("\n")
126
-
127
- # <-- 3. APPLY FIX HERE
128
- yield {chatbot: gr.update(value=copy.deepcopy(chatbot_value))}
129
-
130
- chatbot_value[-1]["footer"] = "{:.2f}s".format(time.time() - start_time)
131
- chatbot_value[-1]["status"] = "done"
132
-
133
- # <-- 4. APPLY FIX HERE
134
- yield {
135
- clear_btn: gr.update(disabled=False),
136
- sender: gr.update(loading=False),
137
- chatbot: gr.update(value=copy.deepcopy(chatbot_value)),
138
- }
139
-
140
- except Exception as e:
141
- print(f"An error occurred: {e}")
142
- chatbot_value[-1]["loading"] = False
143
- chatbot_value[-1]["status"] = "done"
144
- chatbot_value[-1]["content"] = f"Failed to respond due to an error: {e}"
145
-
146
- # <-- 5. APPLY FIX HERE
147
- yield {
148
- clear_btn: gr.update(disabled=False),
149
- sender: gr.update(loading=False),
150
- chatbot: gr.update(value=copy.deepcopy(chatbot_value)),
151
- }
152
-
153
- def prompt_select(e: gr.EventData):
154
- return gr.update(value=e._data["payload"][0]["value"]["description"])
155
-
156
- def clear():
157
- return gr.update(value=None)
158
-
159
- def retry(chatbot_value: list, e: gr.EventData):
160
- index = e._data["payload"][0]["index"]
161
- chatbot_value = chatbot_value[:index-1]
162
- yield {
163
- sender: gr.update(loading=True),
164
- chatbot: gr.update(value=chatbot_value),
165
- clear_btn: gr.update(disabled=True)
166
- }
167
- for chunk in submit(None, chatbot_value):
168
- yield chunk
169
-
170
- def cancel(chatbot_value: list):
171
- if chatbot_value and chatbot_value[-1].get("status") == "pending":
172
- chatbot_value[-1]["loading"] = False
173
- chatbot_value[-1]["status"] = "done"
174
- chatbot_value[-1]["footer"] = "Chat completion paused"
175
- return {
176
- chatbot: gr.update(value=chatbot_value),
177
- sender: gr.update(loading=False),
178
- clear_btn: gr.update(disabled=False)
179
- }
180
-
181
- # --- 3. Build the Gradio UI ---
182
- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo, ms.Application(), antdx.XProvider():
183
- with antd.Flex(vertical=True, gap="middle"):
184
- chatbot = pro.Chatbot(
185
- height=650,
186
- welcome_config=ChatbotWelcomeConfig(
187
- variant="borderless",
188
- icon="https://cdn-avatars.huggingface.co/v1/production/uploads/60270a7c32856987162c641a/umd13GCWVijwTDGZzw3q-.png",
189
- title=f"Hello, I'm {model_name.split('/')[-1]}",
190
- description="I can show you my thinking process. How can I help you today?",
191
- prompts=ChatbotPromptsConfig(
192
- items=[
193
- {"label": "Explain a concept", "children": [{"description": "Explain what a Large Language Model is in simple terms."}]},
194
- {"label": "Help me write", "children": [{"description": "Write a short, futuristic story about AI companions."}]},
195
- {"label": "Creative Ideas", "children": [{"description": "Give me three creative names for a new coffee shop."}]},
196
- {"label": "Code generation", "children": [{"description": "Write a python function to find the factorial of a number."}]}
197
- ]
198
- )
199
- ),
200
- user_config=ChatbotUserConfig(avatar="https://api.dicebear.com/7.x/miniavs/svg?seed=gradio"),
201
- bot_config=ChatbotBotConfig(
202
- header=model_name,
203
- avatar="https://cdn-avatars.huggingface.co/v1/production/uploads/60270a7c32856987162c641a/umd13GCWVijwTDGZzw3q-.png",
204
- actions=["copy", "retry"]
205
- ),
206
- )
207
- with antdx.Sender() as sender:
208
- with ms.Slot("prefix"):
209
- with antd.Button(value=None, color="default", variant="text") as clear_btn:
210
- with ms.Slot("icon"):
211
- antd.Icon("ClearOutlined")
212
-
213
- clear_btn.click(fn=clear, outputs=[chatbot])
214
- submit_event = sender.submit(
215
- fn=submit,
216
- inputs=[sender, chatbot],
217
- outputs=[sender, chatbot, clear_btn]
218
- )
219
- sender.cancel(
220
- fn=cancel,
221
- inputs=[chatbot],
222
- outputs=[chatbot, sender, clear_btn],
223
- cancels=[submit_event],
224
- queue=False
225
- )
226
- chatbot.retry(
227
- fn=retry,
228
- inputs=[chatbot],
229
- outputs=[sender, chatbot, clear_btn]
230
- )
231
- chatbot.welcome_prompt_select(
232
- fn=prompt_select,
233
- outputs=[sender]
234
  )
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  if __name__ == "__main__":
237
- demo.queue().launch(debug=True)
 
4
  from threading import Thread
5
  import spaces
6
  import time
7
+
8
+ # Load the model and tokenizer
 
 
 
 
 
 
 
 
 
 
 
9
  model_name = "sarvamai/sarvam-m"
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  @spaces.GPU
14
+ def generate_response(prompt, chat_history):
15
+ messages = [{"role": "user", "content": prompt}]
16
+ text = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=True)
17
+
18
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
19
+
20
+ # Use TextIteratorStreamer for streaming
21
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
22
+
23
+ # Conduct text generation with streaming
24
+ generation_kwargs = dict(
25
+ input_ids=model_inputs.input_ids,
26
+ max_new_tokens=8192,
27
+ do_sample=True,
28
+ temperature=0.7,
29
+ streamer=streamer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
31
 
32
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
33
+ thread.start()
34
+
35
+ # Initialize variables to capture reasoning content and main content
36
+ reasoning_content = ""
37
+ content = ""
38
+ start_time = time.time()
39
+
40
+ # First yield: Show thinking has started
41
+ yield chat_history + [(None, "Thinking...")], ""
42
+
43
+ for new_text in streamer:
44
+ if "</think>" in new_text:
45
+ parts = new_text.split("</think>")
46
+ reasoning_content = parts[0].rstrip("\n")
47
+ content = parts[-1].lstrip("\n").rstrip("</s>")
48
+
49
+ # Calculate thinking time
50
+ thinking_time = time.time() - start_time
51
+
52
+ # Yield the thinking process
53
+ yield chat_history + [
54
+ (None, f"Thinking..."),
55
+ (None, f"Thinking completed. Thought for {thinking_time:.1f} seconds."),
56
+ (None, f"Thought process:\n{reasoning_content}")
57
+ ], ""
58
+ else:
59
+ content += new_text
60
+ # Yield the content as it's being generated
61
+ yield chat_history + [
62
+ (None, f"Thinking..."),
63
+ (None, f"Thinking completed. Thought for {time.time() - start_time:.1f} seconds."),
64
+ (None, f"Thought process:\n{reasoning_content}"),
65
+ (None, content)
66
+ ], ""
67
+
68
+ # Final yield with complete response
69
+ yield chat_history + [
70
+ (None, f"Thinking..."),
71
+ (None, f"Thinking completed. Thought for {time.time() - start_time:.1f} seconds."),
72
+ (None, f"Thought process:\n{reasoning_content}"),
73
+ (prompt, f"{reasoning_content}\n{content}" if reasoning_content else content)
74
+ ], ""
75
+
76
+ # Create the Gradio interface
77
+ with gr.Blocks() as demo:
78
+ gr.Markdown("# Sarvam M Demo")
79
+ chatbot = gr.Chatbot(show_copy_button=True)
80
+ msg = gr.Textbox(label="Your Message")
81
+
82
+ def respond(message, chat_history):
83
+ # Start with the user message
84
+ chat_history.append((message, None))
85
+ yield chat_history, ""
86
+
87
+ # Then stream the assistant's response
88
+ for updated_history, _ in generate_response(message, chat_history):
89
+ yield updated_history, ""
90
+
91
+ msg.submit(respond, [msg, chatbot], [chatbot, msg])
92
+
93
  if __name__ == "__main__":
94
+ demo.launch()