File size: 3,220 Bytes
54c0660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10088ba
54c0660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10088ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
from dotenv import load_dotenv
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # 添加项目根目录
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self

# 加载环境变量
load_dotenv()

chat_history:list=[]  # 初始化聊天历史记录

class ChatSystem:
    def __init__(self):
        self.chat_chain = Chat_QA_chain_self(
        model="qwen-max",  # 改为阿里通义千问模型
        temperature=0.7,
        top_k=4,
        chat_history=chat_history,
        persist_path="./vector_db/chroma_sanguo",
        file_path="./knowledge_db/sanguo_characters",
        api_key=os.getenv("ali_api_key"),  # 使用阿里API Key
        embedding="m3e",
        embedding_key=None
    )

chat_system = ChatSystem()

def handle_message(message, chat_history):
    if not message:
        return "", chat_history
    
    # 处理None值:将用户消息中的None转换为空字符串
    formatted_chat_history = []
    for item in chat_history:
        user_msg, ai_msg = item
        formatted_user_msg = user_msg if user_msg is not None else ""  # 关键修改
        formatted_chat_history.append((formatted_user_msg, ai_msg))
    
    # 调用问答链获取回复
    updated_history = chat_system.chat_chain.answer(
        question=message, 
        chat_history=formatted_chat_history
    )
    
    # 转换回列表的列表供Gradio显示
    return "", [list(item) for item in updated_history]

with gr.Blocks() as simple_demo:
    # 使用Row布局水平排列图标和标题(可改为Column垂直排列)
    with gr.Row(equal_height=True):  # equal_height=True保证行内组件高度一致
        # 加载图标(调整参数保持1:1比例)
        gr.Image(
            value="./figures/sanguo.png",  # 图标路径
            scale=0.2,  # 控制组件在布局中的占比(1表示与相邻组件等宽)
            # min_width=60,  # 设置最小宽度(根据图标原始尺寸调整)
            # min_height=60,  # 设置最小高度(与min_width相同,保持1:1)
            width=80,  # 固定宽度
            height=80,  # 固定高度
            show_label=False,  # 隐藏标签
            show_download_button=False,  # 隐藏下载按钮
            container=False  # 去除图标外边框
        )
        gr.Markdown("## 三国大乱斗系统")  # 原标题
    
    # 初始化带系统介绍的聊天框
    chatbot = gr.Chatbot(
        height=600,
        value=[[None, "欢迎使用三国大乱斗系统!我可以为您提供以下功能:\n1. 角色抽取:随机抽取三国人物卡并展示完整信息\n2. 对战规程介绍:包括回合制规则、技能使用说明等\n3. 实时对战:支持玩家与AI的回合制对战模拟\n请输入您的问题开始交互~"]]
    )
    msg = gr.Textbox(label="请输入您的问题")
    
    with gr.Row():
        submit_btn = gr.Button("发送")
        clear_btn = gr.Button("清空对话")

    submit_btn.click(
        handle_message,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot]
    )
    clear_btn.click(lambda: [], None, chatbot)

if __name__ == "__main__":
    simple_demo.launch()