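# Gradio chat demo for sarvamai/sarvam-m on a ZeroGPU Space. The model's <think>
# reasoning is streamed into a "Thinking..." bubble, then the final answer is
# streamed as a separate assistant message.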
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import spaces
import time

# Load the model and tokenizer
model_name = "sarvamai/sarvam-m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Spinner frames cycled in the chat bubble title while the model is thinking
indicators = ["Thinking ⠋", "Thinking ⠙", "Thinking ⠹", "Thinking ⠸", "Thinking ⠼", "Thinking ⠴", "Thinking ⠦", "Thinking ⠧", "Thinking ⠇", "Thinking ⠏"]
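
# On ZeroGPU Spaces, @spaces.GPU attaches a GPU to each call of this function,
# here for up to 120 seconds per request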
@spaces.GPU(duration=120)
def generate_response(prompt, chat_history):
    # Show the user's message immediately, then stream the model's reply
    chat_history.append({"role": "user", "content": prompt})
    yield chat_history, ""
    print(chat_history)

    # Drop earlier thought-process messages from the history so only user turns
    # and final answers are sent back to the model
    processed_chat_history = []
    for message in chat_history:
        metadata = message.get("metadata", {})
        is_thought = (
            message["role"] == "assistant"
            and isinstance(metadata, dict)
            and metadata.get("title", "").startswith("Thought")
        )
        if not is_thought:
            processed_chat_history.append(message)

    text = tokenizer.apply_chat_template(processed_chat_history, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # TextIteratorStreamer decodes generated tokens into text chunks that can be
    # iterated over from this thread; skip_prompt keeps the input out of the output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=8192,
        streamer=streamer,
    )
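
    # model.generate blocks until generation finishes, so run it on a background
    # thread and consume the streamer incrementally from this one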
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate the reasoning content and the final answer separately
    reasoning_content = ""
    content = ""
    reasoning_done = False
    start_time = time.time()

    # Placeholder assistant message that holds the streamed reasoning
    chat_history.append({"role": "assistant", "content": reasoning_content, "metadata": {"title": "Thinking..."}})
    indicator_index = 0
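
    # Stream the output: text before "</think>" is reasoning shown under the
    # spinner title; text after it becomes the visible assistant answer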
    for new_text in streamer:
        if "</think>" in new_text:
            # Reasoning finished: split the chunk at the closing tag so the tag
            # itself never leaks into the visible answer
            reasoning_done = True
            before, _, after = new_text.partition("</think>")
            reasoning_content += before
            chat_history[-1]["content"] = reasoning_content
            thought_duration = time.time() - start_time
            chat_history[-1]["metadata"] = {"title": f"Thought for {thought_duration:.2f} seconds"}
            chat_history.append({"role": "assistant", "content": content})
            new_text = after
        if not reasoning_done:
            # Advance the spinner and append the new reasoning text
            indicator_index = (indicator_index + 1) % len(indicators)
            chat_history[-1]["metadata"] = {"title": indicators[indicator_index]}
            reasoning_content += new_text
            chat_history[-1]["content"] = reasoning_content
        else:
            content += new_text
            chat_history[-1]["content"] = content
        yield chat_history, ""

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Sarvam M Demo")
    chatbot = gr.Chatbot(height=500, type="messages")
    msg = gr.Textbox(label="Your Message")
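    # generate_response is a generator: each yielded (chat_history, "") pair
    # pushes the updated transcript to the Chatbot and clears the Textbox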
    msg.submit(generate_response, [msg, chatbot], [chatbot, msg])

if __name__ == "__main__":
    # mcp_server=True also exposes the app as an MCP server endpoint
    demo.launch(mcp_server=True)