雷娃 commited on
Commit
73a2adf
·
1 Parent(s): 9d70a29

specify the GPU device and support streaming output

Browse files
Files changed (1) hide show
  1. app.py +76 -43
app.py CHANGED
@@ -1,55 +1,88 @@
1
- # app.py
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
4
  import torch
5
 
6
- # load model and tokenizer
7
  model_name = "inclusionAI/Ling-lite-1.5"
 
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
  torch_dtype="auto",
12
- device_map="auto",
13
  trust_remote_code=True
14
  ).eval()
15
 
16
- # define chat function
17
- def chat(user_input, max_new_tokens=512):
18
- # chat history
19
- messages = [
20
- {"role": "system", "content": "You are Ling, an assistant created by inclusionAI"},
21
- {"role": "user", "content": user_input}
22
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
24
-
25
- # encode the input prompt
26
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
27
-
28
- # generate response
29
- with torch.no_grad():
30
- outputs = model.generate(
31
- **inputs,
32
- max_new_tokens=max_new_tokens,
33
- pad_token_id=tokenizer.eos_token_id
34
- )
35
- response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
36
- return response
37
-
38
- # Construct Gradio Interface
39
- interface = gr.Interface(
40
- fn=chat,
41
- inputs=[
42
- gr.Textbox(lines=5, label="输入你的问题"),
43
- gr.Slider(minimum=100, maximum=1024, step=50, label="生成长度")
44
- ],
45
- outputs=gr.Textbox(label="模型回复"),
46
- title="Ling-lite-1.5 MoE 模型 Demo",
47
- description="基于 [inclusionAI/Ling-lite-1.5](https://huggingface.co/inclusionAI/Ling-lite-1.5) 的对话式文本生成演示。",
48
- examples=[
49
- ["介绍大型语言模型的基本概念", 512],
50
- ["如何解决数学问题中的长上下文依赖?", 768]
51
- ]
52
- )
53
-
54
- # launch Gradion Service
55
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
 
2
  import gradio as gr
3
  import torch
4
 
5
+ # 加载模型和 Tokenizer
6
  model_name = "inclusionAI/Ling-lite-1.5"
7
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
8
+
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_name,
12
  torch_dtype="auto",
13
+ device_map=device,
14
  trust_remote_code=True
15
  ).eval()
16
 
17
+ # 自定义 Streamer 以兼容 Gradio 异步回调
18
+ class GradioStreamer(TextStreamer):
19
+ def __init__(self, tokenizer, chatbot, skip_prompt: bool = True, skip_special_tokens: bool = True):
20
+ super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
21
+ self.chatbot = chatbot
22
+ self.current_text = ""
23
+
24
+ def put(self, value):
25
+ # 解码 token 并追加到当前文本
26
+ self.current_text += super().decode(value)
27
+ # 更新 Chatbot 最后一条消息
28
+ self.chatbot[-1][1] = self.current_text
29
+ yield self.chatbot
30
+
31
+ def end(self):
32
+ # 结束时也触发一次更新
33
+ yield self.chatbot
34
+
35
+
36
+ # 定义异步聊天函数
37
+ async def chat_stream(message, chat_history, max_new_tokens=512):
38
+ # 构造系统提示 + 历史记录 + 当前问题
39
+ messages = [{"role": "system", "content": "You are Ling, an assistant created by inclusionAI"}]
40
+ for user, bot in chat_history:
41
+ messages.append({"role": "user", "content": user})
42
+ messages.append({"role": "assistant", "content": bot})
43
+ messages.append({"role": "user", "content": message})
44
+
45
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
46
+
47
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
48
+
49
+ # 初始化 Chatbot 历史并创建流式对象
50
+ chat_history.append([message, ""])
51
+ streamer = GradioStreamer(tokenizer, chat_history)
52
+
53
+ # 异步生成(注意:transformers 的 generate 目前还不是 async,但我们可以模拟)
54
+ generation_kwargs = {
55
+ "input_ids": inputs["input_ids"],
56
+ "attention_mask": inputs["attention_mask"],
57
+ "streamer": streamer,
58
+ "max_new_tokens": max_new_tokens,
59
+ "pad_token_id": tokenizer.eos_token_id,
60
+ }
61
+
62
+ # 在后台线程中运行模型生成
63
+ import threading
64
+ thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
65
+ thread.start()
66
+
67
+ # 逐步返回 chat_history
68
+ while thread.is_alive():
69
+ yield chat_history
70
+ await asyncio.sleep(0.01)
71
+
72
+ # 返回最终结果
73
+ yield chat_history
74
+
75
+
76
+ # 构建 Gradio 界面
77
+ with gr.Blocks(title="Ling-lite-1.5 MoE 模型 Demo") as demo:
78
+ chatbot = gr.Chatbot(bubble_full_width=False, label="Ling 聊天机器人")
79
+ interface = gr.ChatInterface(
80
+ fn=chat_stream,
81
+ additional_inputs=[
82
+ gr.Slider(minimum=100, maximum=1024, step=50, label="生成长度", value=512),
83
+ ],
84
+ chatbot=chatbot
85
+ )
86
+
87
+ # 启动服务
88
+ demo.launch()