import sys
import os

# Make the project root importable when this script is run from its subdirectory.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import re

import gradio as gr
from dotenv import load_dotenv, find_dotenv

from llm.call_llm import get_completion
from database.create_db import create_db_info
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self
from qa_chain.QA_chain_self import QA_chain_self

# Load environment variables (e.g. API keys) from a local .env file.
_ = load_dotenv(find_dotenv()) |
|
LLM_MODEL_DICT = {
    "openai": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-32k"],
    "wenxin": ["ERNIE-Bot", "ERNIE-Bot-4", "ERNIE-Bot-turbo"],
    "xinhuo": ["Spark-X1"],
    "zhipuai": ["chatglm_pro", "chatglm_std", "chatglm_lite"],
}

LLM_MODEL_LIST = sum(LLM_MODEL_DICT.values(), [])
INIT_LLM = "chatglm_std"
EMBEDDING_MODEL_LIST = ['zhipuai', 'openai', 'm3e']
INIT_EMBEDDING_MODEL = "m3e"
DEFAULT_DB_PATH = "./knowledge_db"
DEFAULT_PERSIST_PATH = "./vector_db/chroma"
AIGC_AVATAR_PATH = "./figures/aigc_avatar.png"
DATAWHALE_AVATAR_PATH = "./figures/datawhale_avatar.png"
AIGC_LOGO_PATH = "./figures/aigc_logo.png"
DATAWHALE_LOGO_PATH = "./figures/datawhale_logo.png"


def get_model_by_platform(platform):
    """Return the model list for a platform, or an empty list for unknown platforms."""
    return LLM_MODEL_DICT.get(platform, [])
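
# Illustrative usage (values come from LLM_MODEL_DICT above):
#   get_model_by_platform("wenxin")   # -> ["ERNIE-Bot", "ERNIE-Bot-4", "ERNIE-Bot-turbo"]
#   get_model_by_platform("unknown")  # -> []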


class Model_center:
    """
    Stores the question-answering chain objects.

    - chat_qa_chain_self: QA chains with chat history, keyed by (model, embedding).
    - qa_chain_self: QA chains without chat history, keyed by (model, embedding).
    """

    def __init__(self):
        self.chat_qa_chain_self = {}
        self.qa_chain_self = {}

    def chat_qa_chain_self_answer(self, question: str, chat_history: list = None, model: str = "openai", embedding: str = "openai", temperature: float = 0.0, top_k: int = 4, history_len: int = 3, file_path: str = DEFAULT_DB_PATH, persist_path: str = DEFAULT_PERSIST_PATH):
        """
        Answer a question using the QA chain that keeps chat history.
        """
        chat_history = chat_history or []
        if question is None or len(question) < 1:
            return "", chat_history
        try:
            # Create and cache one chain per (model, embedding) pair on first use.
            if (model, embedding) not in self.chat_qa_chain_self:
                self.chat_qa_chain_self[(model, embedding)] = Chat_QA_chain_self(
                    model=model, temperature=temperature, top_k=top_k, chat_history=chat_history,
                    file_path=file_path, persist_path=persist_path, embedding=embedding)
            chain = self.chat_qa_chain_self[(model, embedding)]
            return "", chain.answer(question=question, temperature=temperature, top_k=top_k)
        except Exception as e:
            # Show the error message in the input textbox instead of crashing the UI.
            return str(e), chat_history

    def qa_chain_self_answer(self, question: str, chat_history: list = None, model: str = "openai", embedding: str = "openai", temperature: float = 0.0, top_k: int = 4, file_path: str = DEFAULT_DB_PATH, persist_path: str = DEFAULT_PERSIST_PATH):
        """
        Answer a question using the QA chain without chat history.
        """
        chat_history = chat_history or []
        if question is None or len(question) < 1:
            return "", chat_history
        try:
            # Create and cache one chain per (model, embedding) pair on first use.
            if (model, embedding) not in self.qa_chain_self:
                self.qa_chain_self[(model, embedding)] = QA_chain_self(
                    model=model, temperature=temperature, top_k=top_k,
                    file_path=file_path, persist_path=persist_path, embedding=embedding)
            chain = self.qa_chain_self[(model, embedding)]
            chat_history.append((question, chain.answer(question, temperature, top_k)))
            return "", chat_history
        except Exception as e:
            return str(e), chat_history

    def clear_history(self):
        """Clear the stored history of every cached chat QA chain."""
        for chain in self.chat_qa_chain_self.values():
            chain.clear_history()
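
# A minimal usage sketch (assumes valid API keys and an existing vector store;
# the question text is illustrative only):
#   center = Model_center()
#   _, history = center.qa_chain_self_answer("What is RAG?", [],
#                                            model=INIT_LLM, embedding=INIT_EMBEDDING_MODEL)
#   # A second call with the same (model, embedding) pair reuses the cached chain.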


def format_chat_prompt(message, chat_history):
    """
    Format the chat prompt.

    Args:
        message: The current user message.
        chat_history: The chat history.

    Returns:
        prompt: The formatted prompt.
    """
    prompt = ""
    # Replay every past turn as a "User:/Assistant:" exchange.
    for turn in chat_history:
        user_message, bot_message = turn
        prompt = f"{prompt}\nUser: {user_message}\nAssistant: {bot_message}"
    # Append the current message and leave the Assistant slot open for the model.
    prompt = f"{prompt}\nUser: {message}\nAssistant:"
    return prompt
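
# For example, derived directly from the format above:
#   format_chat_prompt("How are you?", [("Hi", "Hello!")])
# produces:
#   "\nUser: Hi\nAssistant: Hello!\nUser: How are you?\nAssistant:"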


def respond(message, chat_history, llm, history_len=3, temperature=0.1, max_tokens=2048):
    """
    Generate the bot's reply.

    Args:
        message: The current user message.
        chat_history: The chat history.

    Returns:
        "": An empty string, meaning nothing needs to be shown in the input box
            (it could be replaced with an actual bot reply).
        chat_history: The updated chat history.
    """
    if message is None or len(message) < 1:
        return "", chat_history
    try:
        # Keep only the most recent `history_len` turns of the conversation.
        chat_history = chat_history[-history_len:] if history_len > 0 else []
        formatted_prompt = format_chat_prompt(message, chat_history)
        bot_message = get_completion(
            formatted_prompt, llm, temperature=temperature, max_tokens=max_tokens)
        # Render newlines as HTML line breaks for the chatbot display.
        bot_message = re.sub(r"\n", '<br/>', bot_message)
        chat_history.append((message, bot_message))
        return "", chat_history
    except Exception as e:
        return str(e), chat_history
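
# Illustrative call (assumes INIT_LLM is configured with a valid API key):
#   respond("Hello", [], llm=INIT_LLM)
#   # -> ("", [("Hello", "<model reply, newlines rendered as <br/>>")])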


model_center = Model_center()

block = gr.Blocks()
with block as demo:
    with gr.Row(equal_height=True):
        gr.Image(value=AIGC_LOGO_PATH, scale=1, min_width=10, show_label=False,
                 show_download_button=False, container=False)
        with gr.Column(scale=2):
            gr.Markdown("""<h1><center>动手学大模型应用开发</center></h1>
                        <center>LLM-UNIVERSE</center>
                        """)
        gr.Image(value=DATAWHALE_LOGO_PATH, scale=1, min_width=10, show_label=False,
                 show_download_button=False, container=False)

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=400, show_copy_button=True, show_share_button=True,
                                 avatar_images=(AIGC_AVATAR_PATH, DATAWHALE_AVATAR_PATH))
            msg = gr.Textbox(label="Prompt/Question")

            with gr.Row():
                # Three chat modes: retrieval QA with history, retrieval QA
                # without history, and plain LLM chat.
                db_with_his_btn = gr.Button("Chat db with history")
                db_wo_his_btn = gr.Button("Chat db without history")
                llm_btn = gr.Button("Chat with llm")
            with gr.Row():
                clear = gr.ClearButton(components=[chatbot], value="Clear console")

        with gr.Column(scale=1):
            file = gr.File(label='Select a knowledge base directory',
                           file_count='directory',
                           file_types=['.txt', '.md', '.docx', '.pdf'])
            with gr.Row():
                init_db = gr.Button("Vectorize knowledge base files")
            model_argument = gr.Accordion("Parameter settings", open=False)
            with model_argument:
                temperature = gr.Slider(0, 1, value=0.01, step=0.01,
                                        label="llm temperature", interactive=True)
                top_k = gr.Slider(1, 10, value=3, step=1,
                                  label="vector db search top k", interactive=True)
                history_len = gr.Slider(0, 5, value=3, step=1,
                                        label="history length", interactive=True)

            model_select = gr.Accordion("Model selection")
            with model_select:
                llm = gr.Dropdown(LLM_MODEL_LIST, label="large language model",
                                  value=INIT_LLM, interactive=True)
                embeddings = gr.Dropdown(EMBEDDING_MODEL_LIST, label="Embedding model",
                                         value=INIT_EMBEDDING_MODEL)

    # Build (or rebuild) the vector store from the uploaded files.
    init_db.click(create_db_info, inputs=[file, embeddings], outputs=[msg])

    # Retrieval QA over the knowledge base, with chat history.
    db_with_his_btn.click(model_center.chat_qa_chain_self_answer,
                          inputs=[msg, chatbot, llm, embeddings, temperature, top_k, history_len],
                          outputs=[msg, chatbot])
    # Retrieval QA over the knowledge base, without chat history.
    db_wo_his_btn.click(model_center.qa_chain_self_answer,
                        inputs=[msg, chatbot, llm, embeddings, temperature, top_k],
                        outputs=[msg, chatbot])
    # Plain LLM chat, also triggered by pressing Enter in the textbox.
    llm_btn.click(respond,
                  inputs=[msg, chatbot, llm, history_len, temperature],
                  outputs=[msg, chatbot], show_progress="minimal")
    msg.submit(respond,
               inputs=[msg, chatbot, llm, history_len, temperature],
               outputs=[msg, chatbot], show_progress="hidden")
    # Clear the history stored inside the cached chains as well as the chat window.
    clear.click(model_center.clear_history)

    gr.Markdown("""Reminders:<br>
    1. Upload your own knowledge files first; otherwise the project's built-in knowledge base is used.
    2. Initializing the database may take a while; please be patient.
    3. If an error occurs during use, it will be shown in the text input box; don't panic. <br>
    """)


if __name__ == "__main__":
    # Close any stray Gradio instances before launching a fresh one.
    gr.close_all()
    demo.launch()