Jenny991 commited on
Commit
ea0fe2c
·
verified ·
1 Parent(s): 4938516

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -14
app.py CHANGED
@@ -1,25 +1,45 @@
1
  from transformers import pipeline
2
  import gradio as gr
3
 
4
- generator = pipeline("text-generation", model="ckiplab/gpt2-base-chinese",
5
- tokenizer="ckiplab/gpt2-base-chinese")
 
 
 
 
 
6
 
 
7
  def chat_fn(message, history):
8
  history = history or []
9
- input_text = "\n".join(history + [f"你: {message}", "AI:"])
10
- output = generator(input_text, max_new_tokens=80, pad_token_id=0)[0]["generated_text"]
11
- response = output.split("AI:")[-1].strip().split("你:")[0].strip()
12
- history.append(f"你: {message}")
13
- history.append(f"AI: {response}")
14
- messages = [(history[i], history[i+1]) for i in range(0, len(history)-1, 2)]
15
- return messages, history
16
 
 
 
 
 
 
 
 
 
 
 
 
17
  with gr.Blocks() as demo:
18
- chatbot = gr.Chatbot(label="中文聊天機器人", type="tuples")
19
- state = gr.State([])
20
- textbox = gr.Textbox(placeholder="請輸入訊息")
 
 
 
 
21
 
22
- textbox.submit(chat_fn, [textbox, state], [chatbot, state])
23
- textbox.submit(lambda: "", None, textbox)
24
 
25
  demo.launch()
 
1
  from transformers import pipeline
2
  import gradio as gr
3
 
4
+ # 使用中文 GPT2 對話模型(支援 CPU)
5
+ generator = pipeline(
6
+ "text-generation",
7
+ model="thu-coai/CDial-GPT2_LCCC-base",
8
+ tokenizer="thu-coai/CDial-GPT2_LCCC-base",
9
+ device=-1 # 使用 CPU
10
+ )
11
 
12
+ # 對話處理函式
13
  def chat_fn(message, history):
14
  history = history or []
15
+
16
+ # 將所有歷史訊息合併為 prompt
17
+ prompt = ""
18
+ for user_msg, bot_msg in history:
19
+ prompt += f"你說:{user_msg}\nAI說:{bot_msg}\n"
20
+ prompt += f"你說:{message}\nAI說:"
 
21
 
22
+ # 生成新回應
23
+ output = generator(prompt, max_new_tokens=80, pad_token_id=0)[0]["generated_text"]
24
+
25
+ # 從模型輸出中擷取 AI 回覆
26
+ response = output.split("AI說:")[-1].split("你說:")[-1].strip()
27
+
28
+ # 更新歷史
29
+ history.append((message, response))
30
+ return history, history
31
+
32
+ # 建立 Gradio 介面
33
  with gr.Blocks() as demo:
34
+ gr.Markdown("## 🧠 中文聊天機器人(記住上下文)")
35
+
36
+ chatbot = gr.Chatbot(label="GPT2 中文對話")
37
+ msg = gr.Textbox(show_label=False, placeholder="請輸入訊息,Enter 送出")
38
+ clear = gr.Button("🧹 清除對話")
39
+
40
+ state = gr.State([]) # 儲存對話歷史
41
 
42
+ msg.submit(chat_fn, inputs=[msg, state], outputs=[chatbot, state])
43
+ clear.click(lambda: ([], []), outputs=[chatbot, state])
44
 
45
  demo.launch()