KingNish commited on
Commit
a56c04c
·
verified ·
1 Parent(s): 43e2f2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -124
app.py CHANGED
@@ -1,40 +1,28 @@
1
  import gradio as gr
2
- from gradio import ChatMessage
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  import torch
5
  from threading import Thread
6
  import spaces
7
- import time
8
 
9
- # --- Model and Tokenizer Setup ---
10
- print("Loading model and tokenizer...")
11
  model_name = "sarvamai/sarvam-m"
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
- model = AutoModelForCausalLM.from_pretrained(
14
- model_name,
15
- torch_dtype=torch.bfloat16, # bfloat16 is often better for inference
16
- device_map="auto"
17
- )
18
- print("Model and tokenizer loaded.")
19
 
20
- # --- Core Generation Logic ---
21
  @spaces.GPU
22
- def generate_response(history: list[ChatMessage]):
23
- # 1. Format the conversation history for the model
24
- # The model expects a list of dictionaries, e.g., [{"role": "user", "content": "Hello"}]
25
- # We convert our ChatMessage history to this format.
26
- query = [msg.model_dump() for msg in history]
27
- # Remove metadata as the model doesn't use it
28
- for msg in query:
29
- msg.pop('metadata', None)
30
-
31
- prompt_text = tokenizer.apply_chat_template(query, tokenize=False, add_generation_prompt=True, enable_thinking=True)
32
- model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
33
 
34
- # 2. Set up the streamer
35
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
36
 
37
- # 3. Start generation in a separate thread
38
  generation_kwargs = dict(
39
  input_ids=model_inputs.input_ids,
40
  max_new_tokens=8192,
@@ -42,116 +30,39 @@ def generate_response(history: list[ChatMessage]):
42
  temperature=0.7,
43
  streamer=streamer,
44
  )
 
45
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
46
  thread.start()
47
 
48
- # 4. Stream and process the output to create structured ChatMessages
49
- in_thought_block = False
50
- thought_content = ""
51
- response_content = ""
52
-
53
- # Add placeholder messages to the history that we will update
54
- # One for thoughts, one for the final answer.
55
- history.append(ChatMessage(role="assistant", content="", metadata={"title": "🤔 Thinking..."}))
56
- history.append(ChatMessage(role="assistant", content="" ))
57
- yield history
58
 
59
- start_time = time.time()
60
 
61
  for new_text in streamer:
62
- # Check if the model is starting to think
63
- if "<think>" in new_text and not in_thought_block:
64
- in_thought_block = True
65
- # Any text after the tag in this chunk is part of the thought
66
- thought_content += new_text.split("<think>", 1)[-1]
67
- continue # Move to next token
68
-
69
- # Check if the model has finished thinking
70
- if "</think>" in new_text and in_thought_block:
71
- in_thought_block = False
72
- duration = time.time() - start_time
73
- # Update the thought message with the full thought and completion status
74
- parts = new_text.split("</think>", 1)
75
- thought_content += parts[0]
76
- history[-2].content = thought_content.strip() # The first placeholder message
77
- history[-2].metadata = {"title": f"✅ Thinking Completed in {duration:.2f}s"}
78
 
79
- # Any text after the tag is part of the final response
80
- response_content += parts[1]
81
- history[-1].content = response_content.lstrip() # The second placeholder
82
- yield history
83
- continue
84
-
85
- # Accumulate content based on whether we are in a thought block or not
86
- if in_thought_block:
87
- thought_content += new_text
88
- # Update the thinking message in real-time
89
- history[-2].content = thought_content.strip()
90
  else:
91
- response_content += new_text
92
- # Update the final answer message in real-time
93
- history[-1].content = response_content.lstrip()
94
 
95
- yield history
96
-
97
- # Final cleanup: if the thought bubble is empty, remove it.
98
- if not history[-2].content.strip():
99
- history.pop(-2)
100
- yield history
101
 
 
 
 
 
 
102
 
103
- # --- Gradio Interface ---
104
- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
105
- gr.Markdown(
106
- """
107
- # 🧠 Sarvam AI Chatbot with Thinking Process
108
- This chatbot uses the `sarvamai/sarvam-m` model.
109
- It will show its "thoughts" in a separate, collapsible box before giving the final answer.
110
- """
111
- )
112
-
113
- chatbot = gr.Chatbot(
114
- [],
115
- elem_id="chatbot",
116
- bubble_full_width=False,
117
- height=600,
118
- avatar_images=(None, "https://huggingface.co/sarvamai/sarvam-m/resolve/main/Sarvam.AI.logo.jpeg"),
119
- show_copy_button=True,
120
- type="messages" # Crucial for using ChatMessage objects
121
- )
122
-
123
- with gr.Row():
124
- txt = gr.Textbox(
125
- scale=4,
126
- show_label=False,
127
- placeholder="Enter your message and press enter...",
128
- container=False,
129
- )
130
- btn = gr.Button("Submit", scale=1)
131
-
132
- # Function to handle user submission
133
- def user(user_message, history):
134
- # Create a user message and add it to the history
135
- history.append(ChatMessage(role="user", content=user_message))
136
- return "", history
137
-
138
- # Chain the events: user submission -> update history -> generate response
139
- txt.submit(user, [txt, chatbot], [txt, chatbot], queue=False).then(
140
- generate_response, chatbot, chatbot
141
- )
142
- btn.click(user, [txt, chatbot], [txt, chatbot], queue=False).then(
143
- generate_response, chatbot, chatbot
144
- )
145
-
146
- gr.Examples(
147
- [
148
- "Write a short story about a robot who discovers music.",
149
- "Explain the concept of black holes to a 5-year-old.",
150
- "Plan a 3-day itinerary for a trip to Paris.",
151
- ],
152
- inputs=txt,
153
- label="Example Prompts"
154
- )
155
 
156
  if __name__ == "__main__":
157
- demo.launch(debug=True)
 
1
  import gradio as gr
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  import torch
4
  from threading import Thread
5
  import spaces
 
6
 
7
+ # Load the model and tokenizer
 
8
  model_name = "sarvamai/sarvam-m"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
 
 
 
 
 
11
 
 
12
  @spaces.GPU
13
+ def generate_response(prompt, chat_history):
14
+
15
+ chat_history.append(dict(role="user", content=prompt ))
16
+
17
+ messages = [{"role": "user", "content": prompt}]
18
+ text = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=True)
19
+
20
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
 
 
21
 
22
+ # Use TextIteratorStreamer for streaming
23
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
24
 
25
+ # Conduct text generation with streaming
26
  generation_kwargs = dict(
27
  input_ids=model_inputs.input_ids,
28
  max_new_tokens=8192,
 
30
  temperature=0.7,
31
  streamer=streamer,
32
  )
33
+
34
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
35
  thread.start()
36
 
37
+ # Initialize variables to capture reasoning content and main content
38
+ reasoning_content = ""
39
+ content = ""
40
+ reasoning_done = False
 
 
 
 
 
 
41
 
42
+ chat_history.append(dict(role="assistant", content=reasoning_content, metadata={"title": "Thinking..."}) )
43
 
44
  for new_text in streamer:
45
+ if "</think>" in new_text:
46
+ chat_history[-1]["metadata"] = {"title": "Thinking Completed"}
47
+ reasoning_done = True
48
+ chat_history.append(dict(role="assistant", content=content))
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ if not reasoning_done:
51
+ reasoning_content += new_text
52
+ chat_history[-1]["content"] = reasoning_content
 
 
 
 
 
 
 
 
53
  else:
54
+ content += new_text
55
+ chat_history[-1]["content"] = content
 
56
 
57
+ yield chat_history
 
 
 
 
 
58
 
59
+ # Create the Gradio interface
60
+ with gr.Blocks() as demo:
61
+ gr.Markdown("# Sarvam M Demo")
62
+ chatbot = gr.Chatbot(height=600)
63
+ msg = gr.Textbox(label="Your Message")
64
 
65
+ msg.submit(respond, [msg, chatbot], [chatbot])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  if __name__ == "__main__":
68
+ demo.launch()