ljy5946 committed on
Commit
4ded835
·
verified ·
1 Parent(s): b70af83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -18
app.py CHANGED
@@ -1,30 +1,33 @@
 
 
1
  import os
 
2
  import gradio as gr
3
  import torch
4
- import logging
 
5
 
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_community.vectorstores import Chroma
8
  from langchain_community.llms import HuggingFacePipeline
9
  from langchain.chains import RetrievalQA
10
  from langchain.prompts import PromptTemplate
11
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
12
 
13
- from build_index import main as build_index_if_needed # 确保提交了 build_index.py
14
 
15
  logging.basicConfig(level=logging.INFO)
16
 
17
  # ─── 配置 ─────────────────────────────────────────────────────
18
- VECTOR_STORE_DIR = "./vector_store"
19
- MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
20
- EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
21
 
22
- # 如果向量库不存在,自动构建
23
  if not os.path.exists(VECTOR_STORE_DIR) or not os.listdir(VECTOR_STORE_DIR):
24
  logging.info("向量库不存在,启动自动构建……")
25
  build_index_if_needed()
26
 
27
- # ─── 1. 加载 LLM ────────────────────────────────────────────────
28
  logging.info("🔧 加载生成模型…")
29
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
30
  model = AutoModelForCausalLM.from_pretrained(
@@ -48,8 +51,8 @@ logging.info("✅ 生成模型加载成功。")
48
  # ─── 2. 加载向量库 ─────────────────────────────────────────────
49
  logging.info("📚 加载向量库…")
50
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
51
- vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
52
- retriever = vectordb.as_retriever(search_kwargs={"k": 3})
53
  logging.info("✅ 向量库加载成功。")
54
 
55
  # ─── 3. 自定义 Prompt ─────────────────────────────────────────
@@ -66,25 +69,34 @@ prompt_template = PromptTemplate.from_template(
66
  """
67
  )
68
 
69
- # ─── 4. 构建 RAG 问答链 ───────────────────────────────────────
70
  qa_chain = RetrievalQA.from_chain_type(
71
  llm=llm,
72
- chain_type="stuff",
73
  retriever=retriever,
74
  chain_type_kwargs={"prompt": prompt_template},
75
  return_source_documents=True,
76
  )
77
- logging.info("✅ RAG 问答链构建成功。")
78
 
79
  # ─── 5. 业务函数 ───────────────────────────────────────────────
80
  def qa_fn(query: str):
81
- if not query.strip():
82
  return "❌ 请输入问题内容。"
83
- result = qa_chain({"query": query})
84
- answer = result["result"].strip()
 
 
 
 
 
85
  sources = result.get("source_documents", [])
 
 
86
  if not sources:
87
- return "📌 回答:未在知识库中找到相关内容,请尝试更换问题或补充教材。"
 
 
88
  sources_text = "\n\n".join(
89
  [f"【片段 {i+1}】\n{doc.page_content}" for i, doc in enumerate(sources)]
90
  )
@@ -99,7 +111,7 @@ with gr.Blocks(title="智能学习助手") as demo:
99
  gr.Button("提问").click(fn=qa_fn, inputs=query, outputs=answer)
100
  gr.Markdown(
101
  "---\n"
102
- "模型:UER/GPT2-Chinese-ClueCorpus + Sentence-Transformers RAG \n"
103
  "由 Hugging Face Spaces 提供算力支持"
104
  )
105
 
@@ -109,3 +121,4 @@ if __name__ == "__main__":
109
 
110
 
111
 
 
 
1
+ # app.py
2
+
3
  import os
4
+ import logging
5
  import gradio as gr
6
  import torch
7
+
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
 
10
  from langchain_community.embeddings import HuggingFaceEmbeddings
11
  from langchain_community.vectorstores import Chroma
12
  from langchain_community.llms import HuggingFacePipeline
13
  from langchain.chains import RetrievalQA
14
  from langchain.prompts import PromptTemplate
 
15
 
16
+ from build_index import main as build_index_if_needed # 需确保 build_index.py 在同目录
17
 
18
  logging.basicConfig(level=logging.INFO)
19
 
20
  # ─── 配置 ─────────────────────────────────────────────────────
21
+ VECTOR_STORE_DIR = "./vector_store"
22
+ MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
23
+ EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
24
 
25
+ # 容器启动时自动构建向量库(如果还没提交 vector_store)
26
  if not os.path.exists(VECTOR_STORE_DIR) or not os.listdir(VECTOR_STORE_DIR):
27
  logging.info("向量库不存在,启动自动构建……")
28
  build_index_if_needed()
29
 
30
+ # ─── 1. 加载生成模型 ──────────────────────────────────────────────
31
  logging.info("🔧 加载生成模型…")
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
33
  model = AutoModelForCausalLM.from_pretrained(
 
51
  # ─── 2. 加载向量库 ─────────────────────────────────────────────
52
  logging.info("📚 加载向量库…")
53
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
54
+ vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
55
+ retriever = vectordb.as_retriever(search_kwargs={"k": 3})
56
  logging.info("✅ 向量库加载成功。")
57
 
58
  # ─── 3. 自定义 Prompt ─────────────────────────────────────────
 
69
  """
70
  )
71
 
72
# ─── 4. Build the RAG question-answering chain (map_reduce) ────────
# map_reduce runs the LLM once per retrieved document and then merges the
# partial answers, which keeps each individual prompt from getting too long.
#
# NOTE: the map_reduce loader does not accept a `prompt` keyword (that key is
# only valid for chain_type="stuff"); passing it is rejected by
# MapReduceDocumentsChain as an unexpected field. The per-document prompt is
# supplied as `question_prompt` instead, and the combine step keeps the
# library default.
# Assumes prompt_template uses the {context} and {question} variables
# — TODO confirm against the template definition.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="map_reduce",
    retriever=retriever,
    chain_type_kwargs={"question_prompt": prompt_template},
    return_source_documents=True,
)
logging.info("✅ RAG 问答链(map_reduce)构建成功。")
81
 
82
  # ─── 5. 业务函数 ───────────────────────────────────────────────
83
  def qa_fn(query: str):
84
+ if not query or not query.strip():
85
  return "❌ 请输入问题内容。"
86
+ try:
87
+ result = qa_chain({"query": query})
88
+ except Exception as e:
89
+ logging.error(f"问答链运行出错:{e}")
90
+ return "抱歉,问答过程中出现错误,请稍后重试。"
91
+
92
+ answer = result.get("result", "").strip()
93
  sources = result.get("source_documents", [])
94
+ if not answer:
95
+ return "📌 回答:未生成答案,请稍后再试。"
96
  if not sources:
97
+ return f"📌 回答:{answer}\n\n(未检索到参考片段)"
98
+
99
+ # 拼接参考片段
100
  sources_text = "\n\n".join(
101
  [f"【片段 {i+1}】\n{doc.page_content}" for i, doc in enumerate(sources)]
102
  )
 
111
  gr.Button("提问").click(fn=qa_fn, inputs=query, outputs=answer)
112
  gr.Markdown(
113
  "---\n"
114
+ "模型:UER/GPT2-Chinese-ClueCorpus + Sentence-Transformers RAG (map_reduce) \n"
115
  "由 Hugging Face Spaces 提供算力支持"
116
  )
117
 
 
121
 
122
 
123
 
124
+