File size: 3,511 Bytes
87e8c3e
 
8b2a8d4
 
87e8c3e
8642e97
 
87e8c3e
8642e97
8b2a8d4
 
 
 
 
 
 
 
 
 
87e8c3e
c5d6f8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f72dd0
c5d6f8a
 
 
 
 
 
 
 
 
 
 
3f72dd0
d0dc3ee
8642e97
c5d6f8a
8642e97
 
 
 
87e8c3e
8642e97
 
c5d6f8a
 
8642e97
 
 
8b2a8d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87e8c3e
 
d91c6b8
 
5780d63
 
87e8c3e
8642e97
 
 
 
87e8c3e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import threading

import gradio as gr
import torch
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face Hub repo id of the model to serve -- set this to your repo.
HF_MODEL_ID = "rieon/DeepCoder-14B-Preview-Suger"

# Alternative remote backend (used by the commented-out respond() below):
# client = InferenceClient(model=HF_MODEL_ID)

_HAS_CUDA = torch.cuda.is_available()

# Kept for code that reads this name. The model's actual placement is decided
# by device_map="auto" below, so input tensors should follow ``model.device``.
device = "cuda" if _HAS_CUDA else "cpu"

tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_ID,
    device_map="auto",  # spreads the weights across all available GPUs
    # float16 is only reliable/fast on GPU; many CPU kernels lack Half support,
    # so fall back to float32 when running without CUDA.
    torch_dtype=torch.float16 if _HAS_CUDA else torch.float32,
)
model.eval()  # inference only: disable dropout etc.

# def respond(
#     message: str,
#     history: list[dict],       # [{"role":"user"/"assistant","content":…}, …]
#     system_message: str,
#     max_tokens: int,
#     temperature: float,
#     top_p: float,
# ):
#     # 1️⃣ Build one raw-text prompt from system + chat history + new user turn
#     prompt = system_message.strip() + "\n"
#     for msg in history:
#         role = msg["role"]
#         content = msg["content"]
#         if role == "user":
#             prompt += f"User: {content}\n"
#         elif role == "assistant":
#             prompt += f"Assistant: {content}\n"
#     prompt += f"User: {message}\nAssistant:"

#     # 2️⃣ Stream tokens from the text-generation endpoint
#     generated = ""
#     for chunk in client.text_generation(
#         prompt,                     # first positional arg
#         max_new_tokens=max_tokens,
#         temperature=temperature,
#         top_p=top_p,
#         stream=True,
#     ):
#         generated += chunk.generated_text
#         yield generated

def _build_prompt(message: str, history: list[dict], system_message: str) -> str:
    """Flatten system message, prior turns and the new user turn into one prompt.

    Produces the plain-text format ``<system>\\nUser: ...\\nAssistant: ...\\n...
    User: <message>\\nAssistant:`` so the model continues as the assistant.
    """
    speakers = {"user": "User", "assistant": "Assistant"}
    lines = [system_message.strip()]
    for turn in history or []:
        speaker = speakers.get(turn.get("role"))
        content = turn.get("content")
        # Gradio "messages" history may carry non-text payloads (files,
        # components) in ``content``; only plain text fits this prompt format.
        if speaker is None or not isinstance(content, str):
            continue
        lines.append(f"{speaker}: {content}")
    lines.append(f"User: {message}\nAssistant:")
    return "\n".join(lines)


def respond(
    message: str,
    history: list[dict],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a reply to *message* from the locally loaded model.

    Args:
        message: The new user turn.
        history: Prior turns in Gradio "messages" format, i.e. dicts with
            ``"role"`` ("user"/"assistant") and ``"content"`` keys.
        system_message: Text placed at the top of the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated reply text after each streamed chunk.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised
        once the stream has been drained.
    """
    prompt = _build_prompt(message, history, system_message)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # device_map="auto" decides where the model lives; inputs must follow it.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),  # sliders may hand back floats
        # Without do_sample=True, generate() decodes greedily and silently
        # ignores temperature/top_p.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    errors: list[Exception] = []

    def _generate() -> None:
        # Runs in a worker thread so this generator can consume the streamer
        # while tokens are still being produced.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # generate() never sent end-of-stream; signal it so the consumer
            # loop below does not block forever.
            streamer.end()

    worker = threading.Thread(target=_generate, daemon=True)
    worker.start()

    output = ""
    for chunk in streamer:
        output += chunk
        yield output

    worker.join()
    if errors:
        raise errors[0]


# Chat UI. The four ``additional_inputs`` are passed positionally to respond()
# after (message, history): system_message, max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    fn=respond,
    type="messages",  # history arrives as [{"role": ..., "content": ...}, ...]
    title="DeepCoder with Suger",
    # Text-only chat: this interface has no file-upload input.
    description="Ask coding questions and get streamed answers from DeepCoder.",
    additional_inputs=[
        gr.Textbox(value="You are a helpful coding assistant.", label="System message"),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()