app.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 添加项目根目录
|
6 |
+
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self
|
7 |
+
|
8 |
+
# 加载环境变量
|
9 |
+
load_dotenv()
|
10 |
+
|
11 |
+
chat_history:list=[] # 初始化聊天历史记录
|
12 |
+
|
13 |
+
class ChatSystem:
|
14 |
+
def __init__(self):
|
15 |
+
self.chat_chain = Chat_QA_chain_self(
|
16 |
+
model="qwen-max", # 改为阿里通义千问模型
|
17 |
+
temperature=0.7,
|
18 |
+
top_k=4,
|
19 |
+
chat_history=chat_history,
|
20 |
+
persist_path="./vector_db/chroma_sanguo",
|
21 |
+
file_path="./knowledge_db/sanguo_characters",
|
22 |
+
api_key=os.getenv("ali_api_key"), # 使用阿里API Key
|
23 |
+
embedding="m3e",
|
24 |
+
embedding_key=None
|
25 |
+
)
|
26 |
+
|
27 |
+
chat_system = ChatSystem()
|
28 |
+
|
29 |
+
def handle_message(message, chat_history):
|
30 |
+
if not message:
|
31 |
+
return "", chat_history
|
32 |
+
|
33 |
+
# 处理None值:将用户消息中的None转换为空字符串
|
34 |
+
formatted_chat_history = []
|
35 |
+
for item in chat_history:
|
36 |
+
user_msg, ai_msg = item
|
37 |
+
formatted_user_msg = user_msg if user_msg is not None else "" # 关键修改
|
38 |
+
formatted_chat_history.append((formatted_user_msg, ai_msg))
|
39 |
+
|
40 |
+
# 调用问答链获取回复
|
41 |
+
updated_history = chat_system.chat_chain.answer(
|
42 |
+
question=message,
|
43 |
+
chat_history=formatted_chat_history
|
44 |
+
)
|
45 |
+
|
46 |
+
# 转换回列表的列表供Gradio显示
|
47 |
+
return "", [list(item) for item in updated_history]
|
48 |
+
|
49 |
+
with gr.Blocks() as simple_demo:
|
50 |
+
# 使用Row布局水平排列图标和标题(可改为Column垂直排列)
|
51 |
+
with gr.Row(equal_height=True,justify="center"): # equal_height=True保证行内组件高度一致
|
52 |
+
# 加载图标(调整参数保持1:1比例)
|
53 |
+
gr.Image(
|
54 |
+
value="./figures/sanguo.png", # 图标路径
|
55 |
+
scale=0.2, # 控制组件在布局中的占比(1表示与相邻组件等宽)
|
56 |
+
# min_width=60, # 设置最小宽度(根据图标原始尺寸调整)
|
57 |
+
# min_height=60, # 设置最小高度(与min_width相同,保持1:1)
|
58 |
+
width=80, # 固定宽度
|
59 |
+
height=80, # 固定高度
|
60 |
+
show_label=False, # 隐藏标签
|
61 |
+
show_download_button=False, # 隐藏下载按钮
|
62 |
+
container=False # 去除图标外边框
|
63 |
+
)
|
64 |
+
gr.Markdown("## 三国大乱斗系统") # 原标题
|
65 |
+
|
66 |
+
# 初始化带系统介绍的聊天框
|
67 |
+
chatbot = gr.Chatbot(
|
68 |
+
height=600,
|
69 |
+
value=[[None, "欢迎使用三国大乱斗系统!我可以为您提供以下功能:\n1. 角色抽取:随机抽取三国人物卡并展示完整信息\n2. 对战规程介绍:包括回合制规则、技能使用说明等\n3. 实时对战:支持玩家与AI的回合制对战模拟\n请输入您的问题开始交互~"]]
|
70 |
+
)
|
71 |
+
msg = gr.Textbox(label="请输入您的问题")
|
72 |
+
|
73 |
+
with gr.Row():
|
74 |
+
submit_btn = gr.Button("发送")
|
75 |
+
clear_btn = gr.Button("清空对话")
|
76 |
+
|
77 |
+
submit_btn.click(
|
78 |
+
handle_message,
|
79 |
+
inputs=[msg, chatbot],
|
80 |
+
outputs=[msg, chatbot]
|
81 |
+
)
|
82 |
+
clear_btn.click(lambda: [], None, chatbot)
|
83 |
+
|
84 |
+
if __name__ == "__main__":
|
85 |
+
# 关键修改:显式设置启动参数(适配Hugging Face环境)
|
86 |
+
simple_demo.launch(
|
87 |
+
server_name="0.0.0.0", # 允许外部访问
|
88 |
+
server_port=7860, # 使用Gradio默认端口
|
89 |
+
share=False # 关闭公共分享链接(Hugging Face会自动生成访问地址)
|
90 |
+
)
|