File size: 5,675 Bytes
28d41ca
 
45c5d09
 
28d41ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
# from langchain.vectorstores import Chroma
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
import sys
import os

# 添加当前目录到 Python 路径
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from model_to_llm import model_to_llm
from get_vectordb import get_vectordb
import re

class Chat_QA_chain_self:
    """"
    带历史记录的问答链  
    - model:调用的模型名称
    - temperature:温度系数,控制生成的随机性
    - top_k:返回检索的前k个相似文档
    - chat_history:历史记录,输入一个列表,默认是一个空列表
    - history_len:控制保留的最近 history_len 次对话
    - file_path:建库文件所在路径
    - persist_path:向量数据库持久化路径
    - appid:星火
    - api_key:星火、百度文心、OpenAI、智谱都需要传递的参数
    - Spark_api_secret:星火秘钥
    - Wenxin_secret_key:文心秘钥
    - embeddings:使用的embedding模型
    - embedding_key:使用的embedding模型的秘钥(智谱或者OpenAI)  
    """
    def __init__(self,model:str, temperature:float=0.0, top_k:int=4, chat_history:list=[], file_path:str=None, persist_path:str=None, appid:str=None, api_key:str=None, Spark_api_secret:str=None,Wenxin_secret_key:str=None, embedding = "openai",embedding_key:str=None):
        print("开始初始化 Chat_QA_chain_self...")
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        self.chat_history = chat_history
        self.file_path = file_path
        self.persist_path = persist_path
        self.appid = appid
        self.api_key = api_key
        self.Spark_api_secret = Spark_api_secret
        self.Wenxin_secret_key = Wenxin_secret_key
        self.embedding = embedding
        self.embedding_key = embedding_key

        print("正在初始化向量数据库...")
        try:
            self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding,self.embedding_key)
            print("向量数据库初始化完成")
        except Exception as e:
            print(f"向量数据库初始化失败: {str(e)}")
            raise
        print("Chat_QA_chain_self 初始化完成")
    
    def clear_history(self):
        "清空历史记录"
        return self.chat_history.clear()

    
    def change_history_length(self, history_len:int=None):
        """
        返回指定数量的历史记录,不传参数则返回全部
        """
        if history_len is None:
            return self.chat_history
        n = len(self.chat_history)
        return self.chat_history[n-history_len:]

 
    def answer(self, question:str=None, temperature=None, top_k=4, chat_history=None):
        """"
        核心方法,调用问答链
        arguments: 
        - question:用户提问
        - chat_history:对话历史记录
        """
        print("Chat_QA_chain_self.answer 开始执行...")
        
        if len(question) == 0:
            return "", chat_history or self.chat_history
        
        if temperature is None:
            temperature = self.temperature
        
        if chat_history is None:
            chat_history = self.chat_history
            
        print("正在初始化 LLM...")
        llm = model_to_llm(self.model, temperature, self.appid, self.api_key, self.Spark_api_secret,self.Wenxin_secret_key)
        print("LLM 初始化完成")

        print("正在初始化检索器...")
        retriever = self.vectordb.as_retriever(
            search_type="mmr",  # 改为MMR算法
            search_kwargs={
                'k': top_k,
                'lambda_mult': 0.5  # 控制多样性
            }
        )
        print("检索器初始化完成")

        print("正在创建问答链...")
        # 新增:显式配置提示词模板
        from langchain.prompts import PromptTemplate
        
        template = """基于以下历史对话和相关文档回答用户问题。如果历史对话中已包含相关信息,优先参考历史对话内容:
        
        历史对话:
        {chat_history}
        
        相关文档:
        {context}
        
        问题:{question}
        """
        QA_PROMPT = PromptTemplate(
            template=template,
            input_variables=["chat_history", "context", "question"]
        )
        
        qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": QA_PROMPT}  # 应用自定义模板
        )
        print("问答链创建完成")

        print("正在发送问题到模型...")
        print("当前传递的chat_history内容:", chat_history)
        # 新增:打印检索结果调试信息
        docs = retriever.get_relevant_documents(question)
        print(f"检索到的文档内容:{[doc.page_content[:50]+'...' for doc in docs]}")  # 显示文档前50字符
        result = qa({"question": question, "chat_history": chat_history})
        print("收到模型回复")
        answer = result['answer']
        answer = re.sub(r"\\n", '<br/>', answer)
        chat_history.append((question, answer))

        print("Chat_QA_chain_self.answer 执行完成")
        return chat_history  #返回本次回答和更新后的历史记录