|
import os
import re
import sys

from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain

# Make sibling modules importable when this file is run directly.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from model_to_llm import model_to_llm
from get_vectordb import get_vectordb
|
|
class Chat_QA_chain_self:
    """
    Question-answering chain with chat history.

    - model: name of the model to call
    - temperature: sampling temperature, controls generation randomness
    - top_k: number of most similar documents to return from retrieval
    - chat_history: chat history, passed as a list; defaults to an empty list
    - history_len: number of most recent exchanges to keep (see change_history_length)
    - file_path: path of the files used to build the vector store
    - persist_path: persistence path of the vector database
    - appid: required by Spark
    - api_key: required by Spark, Baidu Wenxin, OpenAI, and Zhipu
    - Spark_api_secret: Spark API secret
    - Wenxin_secret_key: Wenxin secret key
    - embedding: embedding model to use
    - embedding_key: API key of the embedding model (Zhipu or OpenAI)
    """
|
    def __init__(self, model: str, temperature: float = 0.0, top_k: int = 4, chat_history: list = None, file_path: str = None, persist_path: str = None, appid: str = None, api_key: str = None, Spark_api_secret: str = None, Wenxin_secret_key: str = None, embedding: str = "openai", embedding_key: str = None):
        print("Initializing Chat_QA_chain_self...")
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        # Avoid a mutable default argument: a shared [] default would leak
        # history across every instance created without an explicit chat_history.
        self.chat_history = chat_history if chat_history is not None else []
        self.file_path = file_path
        self.persist_path = persist_path
        self.appid = appid
        self.api_key = api_key
        self.Spark_api_secret = Spark_api_secret
        self.Wenxin_secret_key = Wenxin_secret_key
        self.embedding = embedding
        self.embedding_key = embedding_key

        print("Initializing vector database...")
        try:
            self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding, self.embedding_key)
            print("Vector database initialized")
        except Exception as e:
            print(f"Vector database initialization failed: {e}")
            raise
        print("Chat_QA_chain_self initialized")
|
    def clear_history(self):
        """Clear the chat history in place."""
        self.chat_history.clear()
|
    def change_history_length(self, history_len: int = None):
        """
        Return the most recent history_len exchanges; with no argument, return the full history.
        """
        if history_len is None:
            return self.chat_history
        # Clamp the start index so that requesting more entries than exist
        # returns the full history instead of slicing from a negative index.
        n = len(self.chat_history)
        return self.chat_history[max(n - history_len, 0):]
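    # Example (hypothetical history): with chat_history == [("q1", "a1"),
    # ("q2", "a2"), ("q3", "a3")], change_history_length(2) returns
    # [("q2", "a2"), ("q3", "a3")] and change_history_length() returns all three.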
|
    def answer(self, question: str = None, temperature=None, top_k=None, chat_history=None):
        """
        Core method: invoke the question-answering chain.

        Arguments:
        - question: the user's question
        - chat_history: chat history for the conversation
        """
        print("Chat_QA_chain_self.answer starting...")

        # Guard against both None and the empty string, and keep the return
        # type consistent with the normal path (a history list).
        if not question:
            return chat_history if chat_history is not None else self.chat_history

        if temperature is None:
            temperature = self.temperature

        if top_k is None:
            top_k = self.top_k

        if chat_history is None:
            chat_history = self.chat_history
|
print("正在初始化 LLM...") |
|
llm = model_to_llm(self.model, temperature, self.appid, self.api_key, self.Spark_api_secret,self.Wenxin_secret_key) |
|
print("LLM 初始化完成") |
|
|
|
print("正在初始化检索器...") |
|
retriever = self.vectordb.as_retriever( |
|
search_type="mmr", |
|
search_kwargs={ |
|
'k': top_k, |
|
'lambda_mult': 0.5 |
|
} |
|
) |
|
print("检索器初始化完成") |
|
|
|
print("正在创建问答链...") |
|
|
|
from langchain.prompts import PromptTemplate |
|
|
|
template = """基于以下历史对话和相关文档回答用户问题。如果历史对话中已包含相关信息,优先参考历史对话内容: |
|
|
|
历史对话: |
|
{chat_history} |
|
|
|
相关文档: |
|
{context} |
|
|
|
问题:{question} |
|
""" |
|
QA_PROMPT = PromptTemplate( |
|
template=template, |
|
input_variables=["chat_history", "context", "question"] |
|
) |
|
|
|
qa = ConversationalRetrievalChain.from_llm( |
|
llm=llm, |
|
retriever=retriever, |
|
combine_docs_chain_kwargs={"prompt": QA_PROMPT} |
|
) |
|
print("问答链创建完成") |
|
|
|
print("正在发送问题到模型...") |
|
print("当前传递的chat_history内容:", chat_history) |
|
|
|
docs = retriever.get_relevant_documents(question) |
|
print(f"检索到的文档内容:{[doc.page_content[:50]+'...' for doc in docs]}") |
|
result = qa({"question": question, "chat_history": chat_history}) |
|
print("收到模型回复") |
|
answer = result['answer'] |
|
answer = re.sub(r"\\n", '<br/>', answer) |
|
chat_history.append((question, answer)) |
|
|
|
print("Chat_QA_chain_self.answer 执行完成") |
|
return chat_history |
|
|
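# A minimal usage sketch, not part of the original module. It assumes an
# OpenAI key in the OPENAI_API_KEY environment variable plus hypothetical
# knowledge-base and persistence paths; adjust model, paths, and keys to
# whatever model_to_llm and get_vectordb support in your setup.
if __name__ == "__main__":
    chain = Chat_QA_chain_self(
        model="gpt-3.5-turbo",            # any model name model_to_llm accepts
        temperature=0.0,
        top_k=4,
        file_path="./data/knowledge.md",  # hypothetical source document
        persist_path="./vector_db",       # hypothetical vector-store directory
        api_key=os.environ.get("OPENAI_API_KEY"),
        embedding="openai",
        embedding_key=os.environ.get("OPENAI_API_KEY"),
    )
    history = chain.answer(question="What topics does the knowledge base cover?")
    for q, a in history:
        print(f"Q: {q}\nA: {a}")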