# SANGUO/qa_chain/Chat_QA_chain_self.py
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
import sys
import os
import re

# Add the current directory to the Python path so the sibling modules resolve.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from model_to_llm import model_to_llm
from get_vectordb import get_vectordb


class Chat_QA_chain_self:
    """
    Q&A chain with chat history.
    - model: name of the model to call
    - temperature: sampling temperature; controls the randomness of generation
    - top_k: number of most similar documents to retrieve
    - chat_history: chat history, passed as a list (defaults to an empty list)
    - history_len: number of most recent exchanges to keep
    - file_path: path of the files used to build the vector store
    - persist_path: persistence path of the vector database
    - appid: required by Spark
    - api_key: required by Spark, Wenxin, OpenAI, and Zhipu
    - Spark_api_secret: Spark API secret
    - Wenxin_secret_key: Wenxin secret key
    - embedding: embedding model to use
    - embedding_key: API key of the embedding model (Zhipu or OpenAI)
    """

    def __init__(self, model: str, temperature: float = 0.0, top_k: int = 4, chat_history: list = None,
                 file_path: str = None, persist_path: str = None, appid: str = None, api_key: str = None,
                 Spark_api_secret: str = None, Wenxin_secret_key: str = None,
                 embedding: str = "openai", embedding_key: str = None):
        print("Initializing Chat_QA_chain_self...")
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        # Avoid the mutable-default-argument pitfall: give each instance its own list.
        self.chat_history = chat_history if chat_history is not None else []
        self.file_path = file_path
        self.persist_path = persist_path
        self.appid = appid
        self.api_key = api_key
        self.Spark_api_secret = Spark_api_secret
        self.Wenxin_secret_key = Wenxin_secret_key
        self.embedding = embedding
        self.embedding_key = embedding_key

        print("Initializing the vector database...")
        try:
            self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding, self.embedding_key)
            print("Vector database initialized")
        except Exception as e:
            print(f"Vector database initialization failed: {str(e)}")
            raise
        print("Chat_QA_chain_self initialization complete")

    def clear_history(self):
        """Clear the chat history."""
        self.chat_history.clear()

    def change_history_length(self, history_len: int = None):
        """
        Return the most recent history_len exchanges; return the full history if no argument is given.
        """
        if history_len is None:
            return self.chat_history
        # Clamp the slice start at 0 so that history_len > len(chat_history)
        # does not wrap around to a negative index.
        n = len(self.chat_history)
        return self.chat_history[max(0, n - history_len):]

    def answer(self, question: str = None, temperature=None, top_k=4, chat_history=None):
        """
        Core method: run the Q&A chain.
        arguments:
        - question: the user's question
        - temperature: sampling temperature (defaults to the instance setting)
        - top_k: number of documents to retrieve
        - chat_history: chat history (defaults to the instance history)
        """
        print("Chat_QA_chain_self.answer started...")
        if not question:
            # No question: return the history unchanged.
            return chat_history if chat_history is not None else self.chat_history
        if temperature is None:
            temperature = self.temperature
        if chat_history is None:
            chat_history = self.chat_history

        print("Initializing the LLM...")
        llm = model_to_llm(self.model, temperature, self.appid, self.api_key,
                           self.Spark_api_secret, self.Wenxin_secret_key)
        print("LLM initialized")
print("正在初始化检索器...")
retriever = self.vectordb.as_retriever(
search_type="mmr", # 改为MMR算法
search_kwargs={
'k': top_k,
'lambda_mult': 0.5 # 控制多样性
}
)
print("检索器初始化完成")
print("正在创建问答链...")
# 新增:显式配置提示词模板
from langchain.prompts import PromptTemplate
template = """基于以下历史对话和相关文档回答用户问题。如果历史对话中已包含相关信息,优先参考历史对话内容:
历史对话:
{chat_history}
相关文档:
{context}
问题:{question}
"""
QA_PROMPT = PromptTemplate(
template=template,
input_variables=["chat_history", "context", "question"]
)
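
        # Note (assumption about LangChain internals, for the versions this
        # code appears to target): ConversationalRetrievalChain condenses the
        # question, then passes the formatted chat history along with the
        # condensed question into the combine-docs chain, which is why the
        # prompt above can reference {chat_history} in addition to {context}
        # and {question}.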
        qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": QA_PROMPT}  # apply the custom template
        )
        print("Q&A chain created")
print("正在发送问题到模型...")
print("当前传递的chat_history内容:", chat_history)
# 新增:打印检索结果调试信息
docs = retriever.get_relevant_documents(question)
print(f"检索到的文档内容:{[doc.page_content[:50]+'...' for doc in docs]}") # 显示文档前50字符
result = qa({"question": question, "chat_history": chat_history})
print("收到模型回复")
        answer = result['answer']
        answer = re.sub(r"\n", '<br/>', answer)  # replace newlines with <br/> for HTML display
        chat_history.append((question, answer))
        print("Chat_QA_chain_self.answer finished")
        return chat_history  # return the updated chat history
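

# A minimal usage sketch, assuming the source documents live in ./data, the
# vector store persists to ./vectordb, and an OpenAI key is available in the
# environment; the model and embedding names below are illustrative and should
# match whatever model_to_llm and get_vectordb actually support in this repo.
if __name__ == "__main__":
    chain = Chat_QA_chain_self(
        model="gpt-3.5-turbo",                          # hypothetical model name
        temperature=0.0,
        file_path="./data",                             # hypothetical corpus path
        persist_path="./vectordb",                      # hypothetical persistence path
        api_key=os.environ.get("OPENAI_API_KEY"),
        embedding="openai",
        embedding_key=os.environ.get("OPENAI_API_KEY"),
    )
    history = chain.answer("Who is Zhuge Liang?")
    for q, a in history:
        print(f"Q: {q}\nA: {a}")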