ljy5946 committed
Commit e7e03e0 · verified · 1 Parent(s): cfa0432

Update app.py

Files changed (1)
app.py +88 -25
app.py CHANGED
@@ -1,37 +1,94 @@
 import os
 import gradio as gr
-from langchain.vectorstores import Chroma
-from langchain.embeddings import HuggingFaceEmbeddings
+import torch  # needed for the dtype/device checks when loading the generation model
+import logging
+
+# New-style LangChain imports
+from langchain_chroma import Chroma
+from langchain_huggingface import HuggingFaceEmbeddings
+# Note: as of 0.2.x, HuggingFacePipeline should be imported from langchain_huggingface;
+# if that causes problems, importing from langchain_community.llms also works.
+from langchain_huggingface.llms import HuggingFacePipeline  # or: from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import RetrievalQA
-from transformers import pipeline
-from langchain.llms import HuggingFacePipeline
+
+# Transformers
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+logging.basicConfig(level=logging.INFO)
 
 # Paths
+# Make sure these paths match your pre-downloaded model and vector-store folder names.
 VECTOR_STORE_DIR = "./vector_store"
-MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
+# Point the model names at the locally pre-downloaded paths.
+MODEL_NAME = "./hf_models_cache/models--uer--gpt2-chinese-cluecorpussmall"
+EMBEDDING_MODEL_NAME = "./hf_models_cache/models--sentence-transformers--paraphrase-multilingual-mpnet-base-v2"
+
 
-# Set up the LLM and retriever
+# 1. Lightweight LLM (uer/gpt2-chinese-cluecorpussmall)
 print("🔧 Loading generation model...")
-gen_pipe = pipeline("text-generation", model=MODEL_NAME, max_new_tokens=256)
-llm = HuggingFacePipeline(pipeline=gen_pipe)
-
-print("📚 Loading vector store...")
-embeddings = HuggingFaceEmbeddings(
-    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
-)
-vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
-
-retriever = vectordb.as_retriever(search_kwargs={"k": 3})
-qa_chain = RetrievalQA.from_chain_type(
-    llm=llm,
-    chain_type="stuff",
-    retriever=retriever,
-    return_source_documents=True
-)
+try:
+    # Load the tokenizer and model from the local path.
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto",
+    )
+    gen_pipe = pipeline(
+        task="text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=256,
+        temperature=0.5,
+        top_p=0.9,
+        do_sample=True,
+    )
+    llm = HuggingFacePipeline(pipeline=gen_pipe)
+    print("✅ Generation model loaded.")
+except Exception as e:
+    logging.error(f"Failed to load generation model: {e}", exc_info=True)
+    # If loading fails, consider exiting here or showing a friendly error.
+    llm = None  # keep llm as None so the checks below fail gracefully
+    print("❌ Failed to load the generation model; the app may not work properly.")
+
+
+# 2. Vector store and embedding model
+print("📚 Loading vector store and embedding model...")
+try:
+    embeddings = HuggingFaceEmbeddings(
+        model_name=EMBEDDING_MODEL_NAME  # local pre-downloaded embedding model path
+    )
+    vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
+    print("✅ Vector store loaded.")
+except Exception as e:
+    logging.error(f"Failed to load vector store: {e}", exc_info=True)
+    vectordb = None  # keep vectordb as None so the chain build is skipped
+    print("❌ Failed to load the vector store; RAG will be unavailable.")
+
 
+# 3. RAG QA chain
+qa_chain = None
+if llm and vectordb:  # build the chain only when both the LLM and the vector store loaded
+    try:
+        retriever = vectordb.as_retriever(search_kwargs={"k": 3})
+        qa_chain = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",
+            retriever=retriever,
+            return_source_documents=True
+        )
+        print("✅ RAG QA chain built.")
+    except Exception as e:
+        logging.error(f"Failed to build RAG QA chain: {e}", exc_info=True)
+        print("❌ Failed to build the RAG QA chain.")
+
+
+# 4. Q&A handler
 def qa_fn(query):
     if not query.strip():
         return "❌ Please enter a question."
+    if not qa_chain:  # bail out if the chain was never built
+        return "⚠️ The QA system has not fully loaded; try again later or check the logs."
     try:
         result = qa_chain({"query": query})
         answer = result["result"]
@@ -41,9 +98,11 @@ def qa_fn(query):
         )
         return f"📌 Answer: {answer.strip()}\n\n📚 Sources:\n{sources_text}"
     except Exception as e:
-        return f"An error occurred: {str(e)}"
+        logging.exception("QA failed: %s", e)
+        return f"❌ An error occurred: {str(e)}\nCheck the logs for details."
 
-with gr.Blocks(title="Math Knowledge QA Assistant") as demo:
+# 5. Gradio UI
+with gr.Blocks(title="Math Knowledge QA Assistant", theme=gr.themes.Base()) as demo:
     gr.Markdown("## 📘 Math Knowledge QA Assistant\nAsk a textbook-related question, e.g. 'What is the domain of a function?'")
     with gr.Row():
         query_input = gr.Textbox(label="Question", placeholder="Enter your question", lines=2)
@@ -52,5 +111,9 @@ with gr.Blocks(title="Math Knowledge QA Assistant") as demo:
 
     submit_btn.click(fn=qa_fn, inputs=query_input, outputs=output_box)
 
-    demo.launch()
+    gr.Markdown("---\nModel: uer/gpt2-chinese-cluecorpussmall + Chroma RAG | Powered by Hugging Face Spaces")
+
+
+if __name__ == "__main__":
+    demo.launch()
 
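A note on the unchanged call inside qa_fn: the new split-package imports target LangChain 0.2.x, where calling a chain directly as qa_chain({"query": query}) is deprecated in favor of .invoke(). A minimal sketch of the newer call style, using the qa_chain and result shape defined above (the sample question is hypothetical):

# Sketch: new-style invocation of the RetrievalQA chain built in app.py.
# Assumes the LangChain 0.2.x packages named in the new imports.
if qa_chain is not None:
    result = qa_chain.invoke({"query": "What is the domain of a function?"})  # hypothetical query
    print(result["result"])                  # the generated answer
    for doc in result["source_documents"]:   # present because return_source_documents=True
        print(doc.metadata)                  # e.g. the source of each retrieved chunk

Both call styles return the same dict, so the swap inside qa_fn should be a drop-in change.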
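Separately, app.py assumes a Chroma collection already persisted at ./vector_store. For context, here is a minimal sketch of how such a store could be built with the same embedding model; the corpus text, metadata, and splitter settings below are hypothetical, not taken from this repo:

# Sketch: creating the persisted Chroma store that app.py expects.
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
docs = [
    Document(
        page_content="The domain of a function is the set of inputs on which it is defined ...",
        metadata={"source": "textbook_ch1.txt"},  # hypothetical source document
    )
]
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
# langchain-chroma persists automatically when persist_directory is set.
Chroma.from_documents(chunks, embeddings, persist_directory="./vector_store")

The store must be built with the same embedding model that app.py loads; otherwise queries and documents end up in different embedding spaces and retrieval quality degrades.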