from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
# from langchain.vectorstores import Chroma  # deprecated import path
from langchain_community.vectorstores import Chroma
import sys
import os

# Add the current directory to the Python path
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from model_to_llm import model_to_llm
from get_vectordb import get_vectordb
import re


class Chat_QA_chain_self:
    """
    Q&A chain with chat history.
    - model: name of the model to call
    - temperature: temperature coefficient controlling the randomness of generation
    - top_k: number of most similar retrieved documents to return
    - chat_history: chat history, passed as a list; defaults to an empty list
    - history_len: number of most recent exchanges to keep
    - file_path: path to the files used to build the knowledge base
    - persist_path: persistence path of the vector database
    - appid: Spark app ID
    - api_key: required by Spark, Baidu Wenxin, OpenAI, and Zhipu alike
    - Spark_api_secret: Spark API secret
    - Wenxin_secret_key: Wenxin secret key
    - embedding: embedding model to use
    - embedding_key: key for the embedding model (Zhipu or OpenAI)
    """

    def __init__(self, model: str, temperature: float = 0.0, top_k: int = 4, chat_history: list = None, file_path: str = None, persist_path: str = None, appid: str = None, api_key: str = None, Spark_api_secret: str = None, Wenxin_secret_key: str = None, embedding="openai", embedding_key: str = None):
        print("Initializing Chat_QA_chain_self...")
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        # Avoid the mutable-default-argument pitfall: give each instance its own list
        self.chat_history = chat_history if chat_history is not None else []
        self.file_path = file_path
        self.persist_path = persist_path
        self.appid = appid
        self.api_key = api_key
        self.Spark_api_secret = Spark_api_secret
        self.Wenxin_secret_key = Wenxin_secret_key
        self.embedding = embedding
        self.embedding_key = embedding_key

        print("Initializing the vector database...")
        try:
            self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding, self.embedding_key)
            print("Vector database initialized")
        except Exception as e:
            print(f"Vector database initialization failed: {str(e)}")
            raise
        print("Chat_QA_chain_self initialized")

    def clear_history(self):
        """Clear the chat history."""
        self.chat_history.clear()

    def change_history_length(self, history_len: int = None):
        """
        Return the most recent history_len exchanges; if no argument is given, return the full history.
        """
        if history_len is None:
            return self.chat_history
        n = len(self.chat_history)
        return self.chat_history[n - history_len:]

    def answer(self, question: str = None, temperature: float = None, top_k: int = None, chat_history: list = None):
        """
        Core method: invokes the Q&A chain.
        arguments:
        - question: the user's question
        - chat_history: the chat history
        """
        print("Chat_QA_chain_self.answer started...")
        if not question:
            return chat_history if chat_history is not None else self.chat_history
        if temperature is None:
            temperature = self.temperature
        if top_k is None:
            top_k = self.top_k
        if chat_history is None:
            chat_history = self.chat_history
        print("Initializing the LLM...")
        llm = model_to_llm(self.model, temperature, self.appid, self.api_key, self.Spark_api_secret, self.Wenxin_secret_key)
        print("LLM initialized")

        print("Initializing the retriever...")
        retriever = self.vectordb.as_retriever(
            search_type="mmr",  # switch to the MMR algorithm
            search_kwargs={
                'k': top_k,
                'lambda_mult': 0.5  # controls diversity
            }
        )
        print("Retriever initialized")
        print("Creating the Q&A chain...")
        # Explicitly configure the prompt template (PromptTemplate is imported at the top of the file)
        template = """Answer the user's question based on the chat history and relevant documents below. If the chat history already contains the relevant information, prefer the chat history:
Chat history:
{chat_history}
Relevant documents:
{context}
Question: {question}
"""
        QA_PROMPT = PromptTemplate(
            template=template,
            input_variables=["chat_history", "context", "question"]
        )
        qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": QA_PROMPT}  # apply the custom template
        )
        print("Q&A chain created")
        print("Sending the question to the model...")
        print("chat_history being passed:", chat_history)
        # Debug output: show what the retriever returns for this question
        docs = retriever.get_relevant_documents(question)
        print(f"Retrieved documents: {[doc.page_content[:50] + '...' for doc in docs]}")  # first 50 characters of each document
        result = qa({"question": question, "chat_history": chat_history})
        print("Model reply received")
        answer = result['answer']
        # Convert newlines to HTML line breaks for display
        answer = re.sub(r"\n", '<br/>', answer)
        chat_history.append((question, answer))
        print("Chat_QA_chain_self.answer finished")
        return chat_history  # return the updated history, which now includes this round's Q&A
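

# --- Minimal usage sketch ---
# Illustrative only: "gpt-3.5-turbo" assumes model_to_llm supports an OpenAI chat
# model, and the file/persist paths below are placeholders for your own knowledge
# base and Chroma directory. Replace the keys with real credentials before running.
if __name__ == "__main__":
    chain = Chat_QA_chain_self(
        model="gpt-3.5-turbo",
        temperature=0.0,
        top_k=4,
        file_path="./knowledge_db",         # placeholder: folder of source documents
        persist_path="./vector_db/chroma",  # placeholder: Chroma persistence directory
        api_key="sk-...",                   # replace with a real API key
        embedding="openai",
        embedding_key="sk-...",             # replace with a real embedding key
    )
    history = chain.answer(question="What does the knowledge base cover?")
    for q, a in history:
        print(f"Q: {q}\nA: {a}")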