Jenny991 commited on
Commit
8b0b1ab
·
verified ·
1 Parent(s): 28e30fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -1,10 +1,24 @@
1
  from transformers import pipeline
2
  import gradio as gr
3
 
4
- generator = pipeline("text-generation", model="distilgpt2")
 
5
 
6
- def chat_fn(message, history):
7
- output = generator(message, max_new_tokens=50)[0]["generated_text"]
8
- return output, history
9
 
10
- gr.Interface(fn=chat_fn, inputs="text", outputs="text").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import pipeline
2
  import gradio as gr
3
 
4
+ # 中文 GPT2 模型
5
+ generator = pipeline("text-generation", model="IDEA-CCNL/Wenzhong2.0-GPT2-110M")
6
 
7
+ def chat_fn(user_input, history):
8
+ history = history or []
 
9
 
10
+ # 構造中文 prompt
11
+ prompt = ""
12
+ for user_msg, bot_msg in history:
13
+ prompt += f"用戶:{user_msg}\n機器人:{bot_msg}\n"
14
+ prompt += f"用戶:{user_input}\n機器人:"
15
+
16
+ # 生成中文回答
17
+ output = generator(prompt, max_new_tokens=100, pad_token_id=50256)[0]["generated_text"]
18
+ reply = output.split("機器人:")[-1].strip()
19
+
20
+ # 更新歷史
21
+ history.append((user_input, reply))
22
+ return history, history
23
+
24
+ gr.ChatInterface(chat_fn).launch()