|
import gradio as gr |
|
from dotenv import load_dotenv |
|
import sys |
|
import os |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self |
|
|
|
|
|
# Load variables from a local .env file into the process environment so the
# os.getenv("ali_api_key") lookup used when building the chain can see them.
load_dotenv()

# Module-level conversation list; it is handed to the QA chain when
# ChatSystem is constructed.
chat_history: list = []
|
|
|
class ChatSystem:
    """Own the question-answering chain that backs the chat UI.

    All settings are keyword-only and default to the values this app has
    always used, so ``ChatSystem()`` behaves exactly as before while other
    callers can override individual options.
    """

    def __init__(
        self,
        *,
        model="qwen-max",
        temperature=0.7,
        top_k=4,
        persist_path="./vector_db/chroma_sanguo",
        file_path="./knowledge_db/sanguo_characters",
        api_key=None,
        embedding="m3e",
        embedding_key=None,
        history=None,
    ):
        """Build the underlying ``Chat_QA_chain_self``.

        Args:
            model: Chat model name passed to the chain.
            temperature: Sampling temperature passed to the chain.
            top_k: ``top_k`` value passed to the chain (presumably the number
                of retrieved documents — confirm against Chat_QA_chain_self).
            persist_path: Vector store directory passed to the chain.
            file_path: Knowledge-base directory passed to the chain.
            api_key: Model API key; when ``None`` it is read from the
                ``ali_api_key`` environment variable at construction time.
            embedding: Embedding backend name passed to the chain.
            embedding_key: Embedding API key passed to the chain.
            history: Conversation list for the chain; when ``None`` the
                module-level ``chat_history`` list is used, as before.
        """
        self.chat_chain = Chat_QA_chain_self(
            model=model,
            temperature=temperature,
            top_k=top_k,
            # Falls back to the shared module-level list to keep the
            # original default behaviour.
            chat_history=chat_history if history is None else history,
            persist_path=persist_path,
            file_path=file_path,
            api_key=os.getenv("ali_api_key") if api_key is None else api_key,
            embedding=embedding,
            embedding_key=embedding_key,
        )
|
|
|
# One shared instance created at import time; handle_message uses this
# object's chain for every request.
chat_system: ChatSystem = ChatSystem()
|
|
|
def handle_message(message, chat_history):
    """Send one user message to the QA chain and return the new UI state.

    Args:
        message: Text typed by the user in the input box.
        chat_history: Current value of the chatbot component, a sequence of
            ``(user_msg, ai_msg)`` pairs. The user side may be ``None``, as
            in the welcome entry the chatbot starts with. ``None`` is treated
            as an empty history.

    Returns:
        A ``(textbox_value, chatbot_value)`` tuple. The first item is always
        ``""`` so the input box is cleared. The second item is the unchanged
        history for blank input; otherwise it is the chain's updated history
        converted to a list of two-item lists.
    """
    # Blank or whitespace-only input is not worth a model call.
    if not message or not message.strip():
        return "", chat_history

    # Hand the chain (user, ai) tuples, replacing a missing user turn with "".
    formatted_chat_history = [
        ("" if user_msg is None else user_msg, ai_msg)
        for user_msg, ai_msg in (chat_history or [])
    ]

    # NOTE(review): answer() is assumed to return an iterable of
    # (user, ai) pairs — confirm against Chat_QA_chain_self.
    updated_history = chat_system.chat_chain.answer(
        question=message,
        chat_history=formatted_chat_history,
    )

    # Convert each pair to a list, matching the list-of-lists format used
    # for the chatbot's initial value.
    return "", [list(item) for item in updated_history]
|
|
|
# Gradio layout: header row (logo + title), chat window, input box, and a
# row with send / clear buttons.
with gr.Blocks() as simple_demo:

    # Header: small logo next to the page title.
    with gr.Row(equal_height=True, justify="center"):
        gr.Image(
            value="./figures/sanguo.png",
            # Gradio's `scale` is an integer (a float such as 0.2 triggers a
            # warning); 0 keeps the logo from growing to fill the row, and
            # min_width matches the fixed 80px width.
            scale=0,
            min_width=80,
            width=80,
            height=80,
            show_label=False,
            show_download_button=False,
            container=False,
        )
        gr.Markdown("## 三国大乱斗系统")

    # Chat window, pre-filled with a bot-only welcome message (user side None).
    chatbot = gr.Chatbot(
        height=600,
        value=[[None, "欢迎使用三国大乱斗系统!我可以为您提供以下功能:\n1. 角色抽取:随机抽取三国人物卡并展示完整信息\n2. 对战规程介绍:包括回合制规则、技能使用说明等\n3. 实时对战:支持玩家与AI的回合制对战模拟\n请输入您的问题开始交互~"]],
    )

    msg = gr.Textbox(label="请输入您的问题")

    with gr.Row():
        submit_btn = gr.Button("发送")
        clear_btn = gr.Button("清空对话")

    # The send button and pressing Enter in the textbox trigger the same
    # handler with the same inputs and outputs.
    submit_btn.click(
        handle_message,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
    )
    msg.submit(
        handle_message,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
    )

    # Resets only the displayed conversation.
    # NOTE(review): any history kept inside the QA chain itself is not
    # cleared here — confirm whether Chat_QA_chain_self needs a reset call.
    clear_btn.click(lambda: [], None, chatbot)
|
|
|
if __name__ == "__main__":
    # Listen on every network interface at port 7860, with no public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    simple_demo.launch(**launch_options)