ljy5946 commited on
Commit
cfa0432
·
verified ·
1 Parent(s): 15dcd94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -111
app.py CHANGED
@@ -1,131 +1,56 @@
1
- import logging
2
  import gradio as gr
3
- import torch
4
- from langchain_chroma import Chroma
5
- from langchain_huggingface import HuggingFaceEmbeddings
6
- from langchain_community.llms import HuggingFacePipeline
7
  from langchain.chains import RetrievalQA
8
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
9
 
10
- logging.basicConfig(level=logging.INFO)
 
 
11
 
12
- # 1. Load vector store
13
- embedding_model = HuggingFaceEmbeddings(
14
- model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
15
- )
16
- vector_store = Chroma(
17
- persist_directory="vector_store",
18
- embedding_function=embedding_model,
19
- )
20
 
21
- # 2. Load lightweight LLM (Phi-2)
22
- model_id = "microsoft/phi-2"
23
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
24
- model = AutoModelForCausalLM.from_pretrained(
25
- model_id,
26
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
- device_map="auto",
28
- )
29
- gen_pipe = pipeline(
30
- task="text-generation",
31
- model=model,
32
- tokenizer=tokenizer,
33
- max_new_tokens=256,
34
- temperature=0.5,
35
- top_p=0.9,
36
- do_sample=True,
37
  )
38
- llm = HuggingFacePipeline(pipeline=gen_pipe)
39
 
40
- # 3. Build RAG QA chain
41
- retriever = vector_store.as_retriever(search_kwargs={"k": 3})
42
  qa_chain = RetrievalQA.from_chain_type(
43
  llm=llm,
44
  chain_type="stuff",
45
  retriever=retriever,
 
46
  )
47
 
48
- # 4. Business functions
49
- def simple_qa(user_query: str) -> str:
50
- if not user_query.strip():
51
- return "⚠️ 请输入学习问题,例如:什么是定积分?"
52
  try:
53
- return qa_chain.run(user_query)
54
- except Exception as e:
55
- logging.exception("问答失败:%s", e)
56
- return f"⚠️ 问答失败,请稍后再试。\n[调试信息] {e}"
57
-
58
- def generate_outline(topic: str):
59
- if not topic.strip():
60
- yield "⚠️ 请输入章节或主题,例如:高等数学 第六章 定积分", ""
61
- return
62
-
63
- yield "⌛ 正在检索/生成,请稍候…", ""
64
-
65
- try:
66
- docs = retriever.get_relevant_documents(topic)
67
- if not docs:
68
- yield "⚠️ 没有找到相关内容,请换个关键词试试。", ""
69
- return
70
-
71
- snippet = "\n".join(d.page_content for d in docs)
72
- prompt = (
73
- f"根据以下内容,为“{topic}”生成大学本科层次的结构化学习大纲,格式示例:\n"
74
- f"一、章节标题\n 1. 节标题\n (1)要点描述\n...\n\n"
75
- f"文档内容:\n{snippet}\n\n学习大纲:"
76
  )
77
- raw = gen_pipe(prompt, max_new_tokens=512)[0]["generated_text"]
78
- outline = raw.split("学习大纲:")[-1].strip()
79
- yield outline, snippet
80
  except Exception as e:
81
- logging.exception("大纲生成失败:%s", e)
82
- yield "⚠️ 抱歉,生成失败,请稍后再试。", ""
83
-
84
- def placeholder_fn(*args, **kwargs):
85
- return "功能尚未实现,请等待后续更新。"
86
-
87
- # 5. Gradio UI
88
- with gr.Blocks(title="智能学习助手", theme=gr.themes.Base()) as demo:
89
- gr.Markdown("# 📚 智能学习助手 v2.0\n— 专业课向量问答与大纲生成 —")
90
-
91
- with gr.Tabs():
92
- # Chat tab
93
- with gr.TabItem("💬 智能问答"):
94
- chatbot = gr.Chatbot(show_label=False, height=400)
95
- user_msg = gr.Textbox(placeholder="输入学习问题", show_label=False)
96
- send_btn = gr.Button("发送", variant="primary")
97
-
98
- def chat_flow(message, history):
99
- history.append((message, "🤔 正在思考中,请稍后…"))
100
- yield "", history
101
- ans = simple_qa(message)
102
- history[-1] = (message, ans)
103
- yield "", history
104
-
105
- send_btn.click(chat_flow, [user_msg, chatbot], [user_msg, chatbot])
106
- user_msg.submit(chat_flow, [user_msg, chatbot], [user_msg, chatbot])
107
-
108
- # Outline tab
109
- with gr.TabItem("📝 生成学习大纲"):
110
- topic_in = gr.Textbox(label="章节主题", placeholder="例如:定积分")
111
- outline_out = gr.Textbox(label="系统生成的大纲", lines=12)
112
- snippet_out = gr.Textbox(label="[调试] 检索片段", lines=6, visible=False)
113
- gen_btn = gr.Button("生成大纲", variant="primary")
114
- gen_btn.click(generate_outline, inputs=topic_in, outputs=[outline_out, snippet_out])
115
 
116
- # Placeholder tabs
117
- with gr.TabItem(" 自动出题"):
118
- gr.Textbox(label="知识点").render()
119
- gr.Dropdown(["简单", "中等", "困难"], label="难度").render()
120
- gr.Slider(1, 10, step=1, label="题目数量").render()
121
- gr.Button("开始出题").click(placeholder_fn, [], [])
122
 
123
- with gr.TabItem("✅ 答案批改"):
124
- gr.Textbox(label="标准答案", lines=4).render()
125
- gr.Textbox(label="学生答案", lines=4).render()
126
- gr.Button("开始批改").click(placeholder_fn, [], [])
127
 
128
- gr.Markdown("---\n模型:Phi-2 + 向量库检索 | Powered by Hugging Face Spaces")
129
 
130
- if __name__ == "__main__":
131
- demo.launch()
 
1
+ import os
2
  import gradio as gr
3
+ from langchain.vectorstores import Chroma
4
+ from langchain.embeddings import HuggingFaceEmbeddings
 
 
5
  from langchain.chains import RetrievalQA
6
+ from transformers import pipeline
7
+ from langchain.llms import HuggingFacePipeline
8
 
9
+ # 设置路径
10
+ VECTOR_STORE_DIR = "./vector_store"
11
+ MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
12
 
13
+ # 设置 LLM 和检索器
14
+ print("🔧 加载生成模型...")
15
+ gen_pipe = pipeline("text-generation", model=MODEL_NAME, max_new_tokens=256)
16
+ llm = HuggingFacePipeline(pipeline=gen_pipe)
 
 
 
 
17
 
18
+ print("📚 加载向量库...")
19
+ embeddings = HuggingFaceEmbeddings(
20
+ model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
22
+ vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
23
 
24
+ retriever = vectordb.as_retriever(search_kwargs={"k": 3})
 
25
  qa_chain = RetrievalQA.from_chain_type(
26
  llm=llm,
27
  chain_type="stuff",
28
  retriever=retriever,
29
+ return_source_documents=True
30
  )
31
 
32
+ def qa_fn(query):
33
+ if not query.strip():
34
+ return "❌ 请输入问题内容。"
 
35
  try:
36
+ result = qa_chain({"query": query})
37
+ answer = result["result"]
38
+ sources = result.get("source_documents", [])
39
+ sources_text = "\n\n".join(
40
+ [f"【片段 {i+1}】\n" + doc.page_content for i, doc in enumerate(sources)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  )
42
+ return f"📌 回答:{answer.strip()}\n\n📚 参考:\n{sources_text}"
 
 
43
  except Exception as e:
44
+ return f" 出现错误:{str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ with gr.Blocks(title="数学知识问答助手") as demo:
47
+ gr.Markdown("## 📘 数学知识问答助手\n输入教材相关问题,例如:“什么是函数的定义域?”")
48
+ with gr.Row():
49
+ query_input = gr.Textbox(label="问题", placeholder="请输入你的问题", lines=2)
50
+ output_box = gr.Textbox(label="回答", lines=15)
51
+ submit_btn = gr.Button("提问")
52
 
53
+ submit_btn.click(fn=qa_fn, inputs=query_input, outputs=output_box)
 
 
 
54
 
55
+ demo.launch()
56