File size: 3,279 Bytes
89f33a5
 
 
6f2ede7
851f757
060b8a4
3d1de8f
a56c04c
43e2f2f
 
4fce686
3d1de8f
060b8a4
 
a940b7a
a56c04c
399775a
 
4fce686
 
060b8a4
399775a
 
 
b2e8189
 
 
 
 
 
 
399775a
 
 
 
a56c04c
 
2c131cf
a56c04c
3d1de8f
 
a56c04c
3d1de8f
 
060b8a4
3d1de8f
b34ac00
a56c04c
3d1de8f
 
 
a56c04c
 
 
 
060b8a4
3d1de8f
399775a
3d1de8f
060b8a4
3d1de8f
a56c04c
 
060b8a4
 
399775a
060b8a4
a56c04c
060b8a4
 
 
a56c04c
 
3d1de8f
a56c04c
 
060b8a4
399775a
2c131cf
a56c04c
 
 
a940b7a
a56c04c
399775a
3d1de8f
89f33a5
060b8a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import spaces
import time

# Checkpoint served by this demo; tokenizer and weights are loaded once, at import time.
model_name = "sarvamai/sarvam-m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Spinner frames cycled through as the title of the in-progress "thinking" panel.
indicators = [
    "Thinking ⠋",
    "Thinking ⠙",
    "Thinking ⠹",
    "Thinking ⠸",
    "Thinking ⠼",
    "Thinking ⠴",
    "Thinking ⠦",
    "Thinking ⠧",
    "Thinking ⠇",
    "Thinking ⠏",
]

def _history_for_model(chat_history):
    """Return the chat messages that should be sent back to the model.

    Assistant messages whose metadata title starts with "Thought" are the
    collapsed reasoning panels shown in the UI; they are dropped so earlier
    reasoning is not re-fed as conversation context. All other messages are
    kept, in order.
    """
    kept = []
    for message in chat_history:
        if message["role"] == "assistant":
            metadata = message.get("metadata")
            # Gradio may hand back metadata as None, or with a None title, so
            # normalize to a string before calling str methods on it.
            title = (metadata.get("title") or "") if isinstance(metadata, dict) else ""
            if title.startswith("Thought"):
                continue
        kept.append(message)
    return kept


@spaces.GPU(duration=120)
def generate_response(prompt, chat_history):
    """Stream a model reply into a Gradio ``type="messages"`` chat history.

    The user's prompt is appended, then the model's output is streamed into
    two assistant messages:
    - a reasoning message carrying a ``metadata`` title (an animated spinner
      while thinking, then "Thought for N seconds");
    - a plain answer message holding everything after the ``</think>`` tag.

    Args:
        prompt: Text the user just submitted. Blank prompts are ignored.
        chat_history: The chatbot's current message list (dicts with "role",
            "content" and optionally "metadata"). It is extended in place.

    Yields:
        ``(chat_history, "")`` after every update. The empty string clears the
        input textbox.
    """
    if chat_history is None:
        chat_history = []

    # Don't spend a GPU slot generating a reply to an empty message.
    if not (prompt or "").strip():
        yield chat_history, ""
        return

    chat_history.append({"role": "user", "content": prompt})
    yield chat_history, ""

    print(chat_history)

    text = tokenizer.apply_chat_template(
        _history_for_model(chat_history), tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the stream.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=model_inputs.input_ids,
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=8192,
        streamer=streamer,
    )

    # generate() blocks until finished, so it runs on a worker thread while this
    # generator drains the streamer and pushes partial text to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    reasoning_content = ""
    content = ""
    reasoning_done = False
    start_time = time.time()
    indicator_index = 0

    # Placeholder for the collapsible reasoning panel; its title is animated below.
    chat_history.append({"role": "assistant", "content": reasoning_content, "metadata": {"title": "Thinking..."}})

    for new_text in streamer:
        if not reasoning_done and "</think>" in new_text:
            # Split the chunk at the closing tag. Text before it finishes the
            # reasoning, text after it starts the visible answer, and the tag
            # itself is dropped so it never appears in the reply.
            reasoning_tail, _, answer_head = new_text.partition("</think>")
            reasoning_content += reasoning_tail
            chat_history[-1]["content"] = reasoning_content
            thought_duration = time.time() - start_time
            chat_history[-1]["metadata"] = {"title": f"Thought for {thought_duration:.2f} seconds"}
            reasoning_done = True
            content += answer_head
            chat_history.append({"role": "assistant", "content": content})
        elif not reasoning_done:
            # Still reasoning: advance the spinner and grow the reasoning panel.
            indicator_index = (indicator_index + 1) % len(indicators)
            chat_history[-1]["metadata"] = {"title": indicators[indicator_index]}
            reasoning_content += new_text
            chat_history[-1]["content"] = reasoning_content
        else:
            # Reasoning finished: everything else belongs to the answer, even a
            # literal "</think>" the model might mention in its reply.
            content += new_text
            chat_history[-1]["content"] = content

        yield chat_history, ""

    # The streamer is exhausted once generate() ends; reap the worker thread.
    thread.join()

# Chat UI: a messages-style chatbot plus an input box whose submit event
# streams replies from generate_response.
with gr.Blocks() as demo:
    gr.Markdown("# Sarvam M Demo")
    chatbot = gr.Chatbot(type="messages", height=500)
    msg = gr.Textbox(label="Your Message")
    # Inputs map to (prompt, chat_history); outputs refresh the chat and clear the box.
    msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg],
    )

if __name__ == "__main__":
    # mcp_server=True asks Gradio to also serve the app over its MCP server interface.
    demo.launch(mcp_server=True)