ljy5946 committed on
Commit
d4ceb14
·
verified ·
1 Parent(s): 2999a2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -75
app.py CHANGED
@@ -1,6 +1,5 @@
1
- import gradio as gr
2
- import logging
3
-
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_huggingface import HuggingFaceEmbeddings
6
  from langchain.chains import RetrievalQA
@@ -9,120 +8,98 @@ from langchain.llms import HuggingFacePipeline
9
 
10
  logging.basicConfig(level=logging.INFO)
11
 
12
- # === 加载向量库 ===
13
- embedding_model = HuggingFaceEmbeddings(
14
  model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
15
  )
16
- vector_store = Chroma(
17
- persist_directory="vector_store",
18
- embedding_function=embedding_model,
19
- )
 
20
 
21
- # === 加载 LLM 模型(openchat) ===
22
- model_id = "openchat/openchat-3.5-0106"
23
  tokenizer = AutoTokenizer.from_pretrained(model_id)
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_id,
26
- device_map="auto",
27
  torch_dtype="auto",
28
  )
29
- gen_pipe = pipeline(
30
  "text-generation",
31
  model=model,
32
  tokenizer=tokenizer,
33
- max_new_tokens=512,
34
  temperature=0.7,
35
  top_p=0.9,
36
  )
37
- llm = HuggingFacePipeline(pipeline=gen_pipe)
38
 
39
- # === 构建问答链 ===
40
  qa_chain = RetrievalQA.from_chain_type(
41
  llm=llm,
42
  chain_type="stuff",
43
- retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
44
  )
45
 
46
- # === 智能问答函数 ===
47
- def simple_qa(user_query):
48
  if not user_query.strip():
49
  return "⚠️ 请输入学习问题,例如:什么是定积分?"
50
  try:
51
- answer = qa_chain.run(user_query)
52
- return answer
53
  except Exception as e:
54
- logging.error(f"问答失败: {e}")
55
- return f"⚠️ 问答失败,请稍后再试。\n[调试信息] {e}"
56
 
57
- # === 大纲生成函数 ===
58
  def generate_outline(topic: str):
59
  if not topic.strip():
60
- return "⚠️ 请输入章节或主题,例如:高等数学 第六章 定积分", ""
61
  try:
62
- docs = vector_store.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(topic)
63
- snippet = "\n".join([doc.page_content for doc in docs])
64
  prompt = (
65
- f"根据以下内容,为“{topic}”生成大学本科层次的结构化学习大纲,格式如下:\n"
66
- f"一、章节标题\n 1. 节标题\n (1)要点描述\n...\n\n"
67
- f"文档内容:\n{snippet}\n\n学习大纲:"
68
  )
69
- result = llm.generate(prompt).generations[0][0].text.strip()
70
- return result, snippet
71
  except Exception as e:
72
- logging.error(f"大纲生成失败: {e}")
73
- return "⚠️ 抱歉,生成失败,请稍后再试。", ""
74
 
75
- # === 占位函数 ===
76
- def placeholder_fn(*args, **kwargs):
77
- return "功能尚未实现,请等待后续更新。"
78
-
79
- # === Gradio UI ===
80
- with gr.Blocks() as demo:
81
- gr.Markdown("# 智能学习助手 v2.0\n— 大学生专业课学习助手 —")
82
 
 
 
 
83
  with gr.Tabs():
84
- # --- 模块 A:智能问答 ---
85
  with gr.TabItem("智能问答"):
86
- gr.Markdown("> 示例:什么是函数的定义域?")
87
- chatbot = gr.Chatbot()
88
- user_msg = gr.Textbox(placeholder="输入您的学习问题,然后按回车或点击发送")
89
- send_btn = gr.Button("发送")
90
-
91
- def update_chat(message, chat_history):
92
- reply = simple_qa(message)
93
- chat_history.append((message, reply))
94
- return "", chat_history
95
-
96
- send_btn.click(update_chat, inputs=[user_msg, chatbot], outputs=[user_msg, chatbot])
97
- user_msg.submit(update_chat, inputs=[user_msg, chatbot], outputs=[user_msg, chatbot])
98
 
99
- # --- 模块 B:生成学习大纲 ---
100
  with gr.TabItem("生成学习大纲"):
101
- gr.Markdown("> 示例:高等数学 第六章 定积分")
102
- topic_input = gr.Textbox(label="章节主题", placeholder="请输入章节名")
103
- outline_output = gr.Textbox(label="系统生成的大纲", lines=12)
104
- snippet_output = gr.Textbox(label="[调试] 检索片段展示", lines=6)
105
- gen_outline_btn = gr.Button("生成大纲")
106
- gen_outline_btn.click(fn=generate_outline, inputs=topic_input, outputs=[outline_output, snippet_output])
107
 
108
- # --- 模块 C:自动出题(占位) ---
109
  with gr.TabItem("自动出题"):
110
- gr.Markdown("(出题模块,待开发)")
111
- topic2 = gr.Textbox(label="知识点/主题", placeholder="如:高数 第三章 多元函数")
112
- difficulty2 = gr.Dropdown(choices=["简单", "中等", "困难"], label="难度")
113
- count2 = gr.Slider(1, 10, step=1, label="题目数量")
114
- gen_q_btn = gr.Button("开始出题")
115
- gen_q_btn.click(placeholder_fn, inputs=[topic2, difficulty2, count2], outputs=topic2)
116
 
117
- # --- 模块 D:答案批改(占位) ---
118
  with gr.TabItem("答案批改"):
119
- gr.Markdown("(批改模块,待开发)")
120
- std_ans = gr.Textbox(label="标准答案", lines=5)
121
- user_ans = gr.Textbox(label="您的作答", lines=5)
122
- grade_btn = gr.Button("开始批改")
123
- grade_btn.click(placeholder_fn, inputs=[user_ans, std_ans], outputs=user_ans)
124
 
125
- gr.Markdown("---\n由 HuggingFace 提供支持版本 2.0")
126
 
127
  if __name__ == "__main__":
128
  demo.launch()
 
1
+ # app.py ── 2025-06-08 适配 HuggingFace CPU Space
2
+ import os, logging, gradio as gr
 
3
  from langchain_community.vectorstores import Chroma
4
  from langchain_huggingface import HuggingFaceEmbeddings
5
  from langchain.chains import RetrievalQA
 
8
 
9
  logging.basicConfig(level=logging.INFO)
10
 
11
# ========= 1. Load the persisted local vector store =========
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)

VEC_DIR = "vector_store"
# Fail fast with a clear message if the vector store was never built.
_db_file = os.path.join(VEC_DIR, "chroma.sqlite3")
if not os.path.isdir(VEC_DIR) or not os.path.isfile(_db_file):
    raise RuntimeError(f"❌ 未找到完整向量库 {VEC_DIR},请先执行 build_vector_store.py")

vectordb = Chroma(persist_directory=VEC_DIR, embedding_function=embedder)
20
 
21
# ========= 2. Load a lightweight LLM =========
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B params — small enough for CPU

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # requires `accelerate` in requirements
    torch_dtype="auto",
)

# Wrap model + tokenizer in a text-generation pipeline for LangChain.
_gen_kwargs = dict(max_new_tokens=256, temperature=0.7, top_p=0.9)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **_gen_kwargs,
)
llm = HuggingFacePipeline(pipeline=generator)
38
 
39
# ========= 3. Build the RAG question-answering chain =========
# "stuff" chain type: the k=3 retrieved documents are concatenated
# directly into the LLM prompt.
_retriever = vectordb.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=_retriever,
)
45
 
46
# ========= 4. Business functions =========
def simple_qa(user_query: str):
    """Answer a study question through the RAG chain.

    Returns the chain's answer string, or a user-facing warning/error
    message. Never raises, so the Gradio callback stays alive.
    """
    # Guard against None as well as blank input: Gradio callbacks can
    # deliver None, and `.strip()` on None would raise AttributeError.
    if not user_query or not user_query.strip():
        return "⚠️ 请输入学习问题,例如:什么是定积分?"
    try:
        return qa_chain.run(user_query)
    except Exception as e:
        # Broad catch is deliberate: any retrieval/LLM failure is turned
        # into a friendly message instead of crashing the UI.
        logging.error("[QA ERROR] %s", e)
        return f"⚠️ 问答失败:{e}"
 
 
56
def generate_outline(topic: str):
    """Generate a structured study outline for *topic* from retrieved context.

    Returns a (outline, retrieved_snippet) tuple; on any failure returns a
    warning tuple instead of raising, so the Gradio callback stays alive.
    """
    # Guard against None as well as blank/whitespace-only topics:
    # `.strip()` on None would raise AttributeError.
    if not topic or not topic.strip():
        return "⚠️ 请输入章节或主题", ""
    try:
        docs = vectordb.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(topic)
        snippet = "\n---\n".join(d.page_content for d in docs)
        prompt = (
            f"请基于以下资料,为「{topic}」生成结构化学习大纲,格式:\n"
            f"一、章节标题\n 1. 节标题\n (1)要点…\n\n"
            f"资料:\n{snippet}\n\n大纲:"
        )
        outline = llm.invoke(prompt).strip()
        return outline, snippet
    except Exception as e:
        logging.error("[OUTLINE ERROR] %s", e)
        return "⚠️ 生成失败", ""
72
 
73
def placeholder(*_, **__):
    """Placeholder for unimplemented features.

    Accepts and ignores any positional AND keyword arguments — the UI
    calls it as ``placeholder(label="待开发")``, which raised TypeError
    with the old ``*_``-only signature.
    """
    return "功能待开发…"
 
 
 
 
 
75
 
76
# ========= 5. Gradio UI =========
with gr.Blocks(title="智能学习助手") as demo:
    gr.Markdown("# 智能学习助手 v2.0\n💡 大学生专业课 RAG Demo")
    with gr.Tabs():

        with gr.TabItem("智能问答"):
            chatbot = gr.Chatbot(height=350)
            msg = gr.Textbox(placeholder="在此提问…")

            def chat(m, hist):
                # Append (question, answer) to history; clear the textbox.
                ans = simple_qa(m)
                hist.append((m, ans))
                return "", hist

            msg.submit(chat, [msg, chatbot], [msg, chatbot])

        with gr.TabItem("生成学习大纲"):
            topic = gr.Textbox(label="章节主题", placeholder="高等数学 第六章 定积分")
            outline = gr.Textbox(label="学习大纲", lines=12)
            debug = gr.Textbox(label="调试:检索片段", lines=6)
            gen = gr.Button("生成")
            gen.click(generate_outline, [topic], [outline, debug])

        with gr.TabItem("自动出题"):
            # BUG FIX: `placeholder(label="待开发")` raised TypeError (the
            # function takes no keyword args), and a bare string is not a
            # Gradio component anyway — render the text via gr.Markdown.
            gr.Markdown(placeholder())

        with gr.TabItem("答案批改"):
            gr.Markdown(placeholder())

    # BUG FIX: footer was missing the separator between "LangChain" and
    # "TinyLlama" ("LangChainTinyLlama").
    gr.Markdown("---\nPowered by LangChain • TinyLlama • Chroma")

if __name__ == "__main__":
    demo.launch()