Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,8 @@ from langchain.chains import RetrievalQA
|
|
10 |
from langchain.prompts import PromptTemplate
|
11 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
12 |
|
|
|
|
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
|
15 |
# ─── 配置 ─────────────────────────────────────────────────────
|
@@ -17,8 +19,13 @@ VECTOR_STORE_DIR = "./vector_store"
|
|
17 |
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
|
18 |
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
# ─── 1. 加载 LLM ────────────────────────────────────────────────
|
21 |
-
|
22 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
23 |
model = AutoModelForCausalLM.from_pretrained(
|
24 |
MODEL_NAME,
|
@@ -33,22 +40,22 @@ gen_pipe = pipeline(
|
|
33 |
temperature=0.5,
|
34 |
top_p=0.9,
|
35 |
do_sample=True,
|
|
|
36 |
)
|
37 |
llm = HuggingFacePipeline(pipeline=gen_pipe)
|
38 |
-
|
39 |
|
40 |
# ─── 2. 加载向量库 ─────────────────────────────────────────────
|
41 |
-
|
42 |
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
|
43 |
vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
|
44 |
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
|
45 |
-
|
46 |
|
47 |
# ─── 3. 自定义 Prompt ─────────────────────────────────────────
|
48 |
prompt_template = PromptTemplate.from_template(
|
49 |
"""你是一位专业的数学助教,请根据以下参考资料回答用户的问题。
|
50 |
如果资料中没有相关内容,请直接回答“我不知道”或“资料中未提及”,不要编造答案。
|
51 |
-
|
52 |
参考资料:
|
53 |
{context}
|
54 |
|
@@ -67,19 +74,17 @@ qa_chain = RetrievalQA.from_chain_type(
|
|
67 |
chain_type_kwargs={"prompt": prompt_template},
|
68 |
return_source_documents=True,
|
69 |
)
|
70 |
-
|
71 |
|
72 |
# ─── 5. 业务函数 ───────────────────────────────────────────────
|
73 |
def qa_fn(query: str):
|
74 |
if not query.strip():
|
75 |
return "❌ 请输入问题内容。"
|
76 |
-
# 执行检索与问答
|
77 |
result = qa_chain({"query": query})
|
78 |
answer = result["result"].strip()
|
79 |
sources = result.get("source_documents", [])
|
80 |
if not sources:
|
81 |
return "📌 回答:未在知识库中找到相关内容,请尝试更换问题或补充教材。"
|
82 |
-
# 拼接参考片段
|
83 |
sources_text = "\n\n".join(
|
84 |
[f"【片段 {i+1}】\n{doc.page_content}" for i, doc in enumerate(sources)]
|
85 |
)
|
@@ -103,3 +108,4 @@ if __name__ == "__main__":
|
|
103 |
|
104 |
|
105 |
|
|
|
|
10 |
from langchain.prompts import PromptTemplate
|
11 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
12 |
|
13 |
+
from build_index import main as build_index_if_needed # 确保提交了 build_index.py
|
14 |
+
|
15 |
logging.basicConfig(level=logging.INFO)
|
16 |
|
17 |
# ─── 配置 ─────────────────────────────────────────────────────
|
|
|
19 |
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
|
20 |
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
21 |
|
22 |
+
# 如果向量库不存在,自动构建
|
23 |
+
if not os.path.exists(VECTOR_STORE_DIR) or not os.listdir(VECTOR_STORE_DIR):
|
24 |
+
logging.info("向量库不存在,启动自动构建……")
|
25 |
+
build_index_if_needed()
|
26 |
+
|
27 |
# ─── 1. 加载 LLM ────────────────────────────────────────────────
|
28 |
+
logging.info("🔧 加载生成模型…")
|
29 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
30 |
model = AutoModelForCausalLM.from_pretrained(
|
31 |
MODEL_NAME,
|
|
|
40 |
temperature=0.5,
|
41 |
top_p=0.9,
|
42 |
do_sample=True,
|
43 |
+
trust_remote_code=True,
|
44 |
)
|
45 |
llm = HuggingFacePipeline(pipeline=gen_pipe)
|
46 |
+
logging.info("✅ 生成模型加载成功。")
|
47 |
|
48 |
# ─── 2. 加载向量库 ─────────────────────────────────────────────
|
49 |
+
logging.info("📚 加载向量库…")
|
50 |
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
|
51 |
vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
|
52 |
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
|
53 |
+
logging.info("✅ 向量库加载成功。")
|
54 |
|
55 |
# ─── 3. 自定义 Prompt ─────────────────────────────────────────
|
56 |
prompt_template = PromptTemplate.from_template(
|
57 |
"""你是一位专业的数学助教,请根据以下参考资料回答用户的问题。
|
58 |
如果资料中没有相关内容,请直接回答“我不知道”或“资料中未提及”,不要编造答案。
|
|
|
59 |
参考资料:
|
60 |
{context}
|
61 |
|
|
|
74 |
chain_type_kwargs={"prompt": prompt_template},
|
75 |
return_source_documents=True,
|
76 |
)
|
77 |
+
logging.info("✅ RAG 问答链构建成功。")
|
78 |
|
79 |
# ─── 5. 业务函数 ───────────────────────────────────────────────
|
80 |
def qa_fn(query: str):
|
81 |
if not query.strip():
|
82 |
return "❌ 请输入问题内容。"
|
|
|
83 |
result = qa_chain({"query": query})
|
84 |
answer = result["result"].strip()
|
85 |
sources = result.get("source_documents", [])
|
86 |
if not sources:
|
87 |
return "📌 回答:未在知识库中找到相关内容,请尝试更换问题或补充教材。"
|
|
|
88 |
sources_text = "\n\n".join(
|
89 |
[f"【片段 {i+1}】\n{doc.page_content}" for i, doc in enumerate(sources)]
|
90 |
)
|
|
|
108 |
|
109 |
|
110 |
|
111 |
+
|