ljy5946 committed on
Commit
853beb6
·
verified ·
1 Parent(s): 75d6e02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -52
app.py CHANGED
@@ -1,105 +1,132 @@
1
- # app.py ── 2025-06-08 适配 HuggingFace CPU Space
2
- import os, logging, gradio as gr
 
 
3
  from langchain_community.vectorstores import Chroma
4
  from langchain_huggingface import HuggingFaceEmbeddings
5
  from langchain.chains import RetrievalQA
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  from langchain.llms import HuggingFacePipeline
8
 
9
  logging.basicConfig(level=logging.INFO)
10
 
11
- # ========= 1. 载入本地向量库 =========
12
- embedder = HuggingFaceEmbeddings(
13
  model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
14
  )
15
- VEC_DIR = "vector_store"
16
- if not (os.path.isdir(VEC_DIR) and os.path.isfile(f"{VEC_DIR}/chroma.sqlite3")):
17
- raise RuntimeError(f"❌ 未找到完整向量库 {VEC_DIR},请先执行 build_vector_store.py")
18
-
19
- vectordb = Chroma(persist_directory=VEC_DIR, embedding_function=embedder)
20
 
21
- # ========= 2. 载入轻量 LLM =========
22
- model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # 1.1B CPU 可跑
23
  tokenizer = AutoTokenizer.from_pretrained(model_id)
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_id,
26
- device_map="auto", # 需要 requirements 里有 accelerate
27
  torch_dtype="auto",
28
  )
29
- generator = pipeline(
30
  "text-generation",
31
  model=model,
32
  tokenizer=tokenizer,
33
- max_new_tokens=256,
34
  temperature=0.7,
35
  top_p=0.9,
36
  )
37
- llm = HuggingFacePipeline(pipeline=generator)
38
 
39
- # ========= 3. 构建 RAG 问答链 =========
40
  qa_chain = RetrievalQA.from_chain_type(
41
  llm=llm,
42
  chain_type="stuff",
43
- retriever=vectordb.as_retriever(search_kwargs={"k": 3}),
44
  )
45
 
46
- # ========= 4. 业务函数 =========
47
- def simple_qa(user_query: str):
48
  if not user_query.strip():
49
  return "⚠️ 请输入学习问题,例如:什么是定积分?"
50
  try:
51
- return qa_chain.run(user_query)
 
52
  except Exception as e:
53
- logging.error(f"[QA ERROR] {e}")
54
- return f"⚠️ 问答失败:{e}"
55
 
 
56
  def generate_outline(topic: str):
57
  if not topic.strip():
58
- return "⚠️ 请输入章节或主题", ""
59
  try:
60
- docs = vectordb.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(topic)
61
- snippet = "\n---\n".join(d.page_content for d in docs)
 
 
62
  prompt = (
63
- f"请基于以下资料,为「{topic}」生成结构化学习大纲,格式:\n"
64
- f"一、章节标题\n 1. 节标题\n (1)要点…\n\n"
65
- f"资料:\n{snippet}\n\n大纲:"
66
  )
67
- outline = llm.invoke(prompt).strip()
68
- return outline, snippet
69
  except Exception as e:
70
- logging.error(f"[OUTLINE ERROR] {e}")
71
- return "⚠️ 生成失败", ""
72
 
73
- def placeholder(*_):
74
- return "功能待开发…"
 
 
 
 
 
75
 
76
- # ========= 5. Gradio UI =========
77
- with gr.Blocks(title="智能学习助手") as demo:
78
- gr.Markdown("# 智能学习助手 v2.0\n💡 大学生专业课 RAG Demo")
79
  with gr.Tabs():
 
80
  with gr.TabItem("智能问答"):
81
- chatbot = gr.Chatbot(height=350)
82
- msg = gr.Textbox(placeholder="在此提问…")
83
- def chat(m, hist):
84
- ans = simple_qa(m)
85
- hist.append((m, ans))
86
- return "", hist
87
- msg.submit(chat, [msg, chatbot], [msg, chatbot])
 
 
 
 
 
88
 
 
89
  with gr.TabItem("生成学习大纲"):
90
- topic = gr.Textbox(label="章节主题", placeholder="高等数学 第六章 定积分")
91
- outline = gr.Textbox(label="学习大纲", lines=12)
92
- debug = gr.Textbox(label="调试:检索片段", lines=6)
93
- gen = gr.Button("生成")
94
- gen.click(generate_outline, [topic], [outline, debug])
 
95
 
 
96
  with gr.TabItem("自动出题"):
97
- placeholder(label="待开发")
 
 
 
 
 
98
 
 
99
  with gr.TabItem("答案批改"):
100
- placeholder(label="待开发")
 
 
 
 
101
 
102
- gr.Markdown("---\nPowered by LangChainTinyLlama • Chroma")
103
 
104
  if __name__ == "__main__":
105
  demo.launch()
 
1
+ # app.py
2
+ import gradio as gr
3
+ import logging
4
+
5
  from langchain_community.vectorstores import Chroma
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from langchain.chains import RetrievalQA
8
+ from langchain.schema import Document
9
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
10
  from langchain.llms import HuggingFacePipeline
11
 
12
  logging.basicConfig(level=logging.INFO)
13
 
14
# === Load the persisted vector store ===
# Multilingual MPNet embeddings so Chinese course material can be searched;
# the Chroma store must already exist on disk (built by a separate script).
_EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
_VECTOR_STORE_DIR = "vector_store"

embedding_model = HuggingFaceEmbeddings(model_name=_EMBED_MODEL_NAME)
vector_store = Chroma(
    persist_directory=_VECTOR_STORE_DIR,
    embedding_function=embedding_model,
)
 
22
 
23
# === Load the chat LLM (openchat) ===
model_id = "openchat/openchat-3.5-0106"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",   # requires `accelerate` to place weights automatically
    torch_dtype="auto",  # let the checkpoint decide the dtype
)

# Sampling configuration for the text-generation pipeline.
_GENERATION_KWARGS = {
    "max_new_tokens": 512,
    "temperature": 0.7,
    "top_p": 0.9,
}
gen_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **_GENERATION_KWARGS,
)

# LangChain-compatible wrapper around the HF pipeline.
llm = HuggingFacePipeline(pipeline=gen_pipe)
40
 
41
# === Build the RAG question-answering chain ===
# "stuff" chain type: the top-k retrieved chunks are stuffed into one prompt.
_qa_retriever = vector_store.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=_qa_retriever,
)
47
 
48
# === Q&A entry point ===
def simple_qa(user_query):
    """Answer a study question via the RAG chain.

    Returns a display string for the chat UI; failures are logged and
    reported as a warning string instead of raising.
    """
    # Guard clause: reject blank / whitespace-only questions.
    if not user_query.strip():
        return "⚠️ 请输入学习问题,例如:什么是定积分?"
    try:
        return qa_chain.run(user_query)
    except Exception as e:  # surface errors to the UI rather than crashing
        logging.error(f"问答失败: {e}")
        return f"⚠️ 问答失败,请稍后再试。\n[调试信息] {e}"
58
 
59
# === Outline generation ===
def generate_outline(topic: str):
    """Generate a structured study outline for *topic* from retrieved notes.

    Returns a ``(outline, snippet)`` tuple of strings; *snippet* is the raw
    retrieved text, shown in the UI to debug the retrieval step. On any
    failure a warning string is returned instead of raising.
    """
    # Guard clause: reject blank / whitespace-only topics.
    if not topic.strip():
        return "⚠️ 请输入章节或主题,例如:高等数学 第六章 定积分", ""
    try:
        docs = vector_store.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(topic)
        if not docs:
            return "⚠️ 没有找到相关内容,请换个关键词试试。", ""
        snippet = "\n".join(doc.page_content for doc in docs)
        prompt = (
            f"根据以下内容,为“{topic}”生成大学本科层次的结构化学习大纲,格式如下:\n"
            f"一、章节标题\n 1. 节标题\n (1)要点描述\n...\n\n"
            f"文档内容:\n{snippet}\n\n学习大纲:"
        )
        # BUG FIX: the previous `llm.generate(prompt)` passed a bare string,
        # but LangChain's BaseLLM.generate requires a *list* of prompts and
        # raises ValueError, so every request hit the except branch below.
        # `invoke` accepts a single prompt and returns the generated text.
        result = llm.invoke(prompt).strip()
        return result, snippet
    except Exception as e:  # surface errors to the UI rather than crashing
        logging.error(f"大纲生成失败: {e}")
        return "⚠️ 抱歉,生成失败,请稍后再试。", ""
78
 
79
# === Placeholder for unfinished modules ===
def placeholder_fn(*args, **kwargs):
    """Stand-in handler for tabs that are not implemented yet.

    Accepts and ignores any arguments so it can be wired to any widget.
    """
    del args, kwargs  # inputs are intentionally unused
    return "功能尚未实现,请等待后续更新。"
82
+
83
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("# 智能学习助手 v2.0\n— 大学生专业课学习助手 —")

    with gr.Tabs():
        # Module A: RAG-backed Q&A chat.
        with gr.TabItem("智能问答"):
            gr.Markdown("> 示例:什么是函数的定义域?")
            chat_window = gr.Chatbot()
            question_box = gr.Textbox(placeholder="输入您的学习问题,然后按回车或点击发送")
            ask_button = gr.Button("发送")

            def update_chat(message, chat_history):
                # Append the (question, answer) pair, then clear the input box.
                chat_history.append((message, simple_qa(message)))
                return "", chat_history

            # Both the send button and pressing Enter run the same handler.
            ask_button.click(
                update_chat,
                inputs=[question_box, chat_window],
                outputs=[question_box, chat_window],
            )
            question_box.submit(
                update_chat,
                inputs=[question_box, chat_window],
                outputs=[question_box, chat_window],
            )

        # Module B: study-outline generation.
        with gr.TabItem("生成学习大纲"):
            gr.Markdown("> 示例:高等数学 第六章 定积分")
            outline_topic = gr.Textbox(label="章节主题", placeholder="请输入章节名")
            outline_text = gr.Textbox(label="系统生成的大纲", lines=12)
            retrieved_text = gr.Textbox(label="[调试] 检索片段展示", lines=6)
            outline_button = gr.Button("生成大纲")
            outline_button.click(
                fn=generate_outline,
                inputs=outline_topic,
                outputs=[outline_text, retrieved_text],
            )

        # Module C: quiz generation (placeholder only).
        with gr.TabItem("自动出题"):
            gr.Markdown("(出题模块,待开发)")
            quiz_topic = gr.Textbox(label="知识点/主题", placeholder="如:高数 第三章 多元函数")
            quiz_difficulty = gr.Dropdown(choices=["简单", "中等", "困难"], label="难度")
            quiz_count = gr.Slider(1, 10, step=1, label="题目数量")
            quiz_button = gr.Button("开始出题")
            # The placeholder handler writes its notice into the topic box.
            quiz_button.click(
                placeholder_fn,
                inputs=[quiz_topic, quiz_difficulty, quiz_count],
                outputs=quiz_topic,
            )

        # Module D: answer grading (placeholder only).
        with gr.TabItem("答案批改"):
            gr.Markdown("(批改模块,待开发)")
            reference_answer = gr.Textbox(label="标准答案", lines=5)
            student_answer = gr.Textbox(label="您的作答", lines=5)
            grade_button = gr.Button("开始批改")
            grade_button.click(
                placeholder_fn,
                inputs=[student_answer, reference_answer],
                outputs=student_answer,
            )

    gr.Markdown("---\n由 HuggingFace 提供支持版本 2.0")
130
 
131
# Script entry point: launch the Gradio app (used by the HF Space runtime).
if __name__ == "__main__":
    demo.launch()