ljy5946 commited on
Commit
1831b73
·
verified ·
1 Parent(s): 3a2920b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +131 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import gradio as gr
3
+ import torch
4
+ from langchain_chroma import Chroma
5
+ from langchain_huggingface import HuggingFaceEmbeddings
6
+ from langchain_community.llms import HuggingFacePipeline
7
+ from langchain.chains import RetrievalQA
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
+
10
+ logging.basicConfig(level=logging.INFO)
11
+
12
+ # 1. Load vector store
13
+ embedding_model = HuggingFaceEmbeddings(
14
+ model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
15
+ )
16
+ vector_store = Chroma(
17
+ persist_directory="vector_store",
18
+ embedding_function=embedding_model,
19
+ )
20
+
21
+ # 2. Load lightweight LLM (Phi-2)
22
+ model_id = "microsoft/phi-2"
23
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ model_id,
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
+ device_map="auto",
28
+ )
29
+ gen_pipe = pipeline(
30
+ task="text-generation",
31
+ model=model,
32
+ tokenizer=tokenizer,
33
+ max_new_tokens=256,
34
+ temperature=0.5,
35
+ top_p=0.9,
36
+ do_sample=True,
37
+ )
38
+ llm = HuggingFacePipeline(pipeline=gen_pipe)
39
+
40
+ # 3. Build RAG QA chain
41
+ retriever = vector_store.as_retriever(search_kwargs={"k": 3})
42
+ qa_chain = RetrievalQA.from_chain_type(
43
+ llm=llm,
44
+ chain_type="stuff",
45
+ retriever=retriever,
46
+ )
47
+
48
+ # 4. Business functions
49
+ def simple_qa(user_query: str) -> str:
50
+ if not user_query.strip():
51
+ return "⚠️ 请输入学习问题,例如:什么是定积分?"
52
+ try:
53
+ return qa_chain.run(user_query)
54
+ except Exception as e:
55
+ logging.exception("问答失败:%s", e)
56
+ return f"⚠️ 问答失败,请稍后再试。\n[调试信息] {e}"
57
+
58
+ def generate_outline(topic: str):
59
+ if not topic.strip():
60
+ yield "⚠️ 请输入章节或主题,例如:高等数学 第六章 定积分", ""
61
+ return
62
+
63
+ yield "⌛ 正在检索/生成,请稍候…", ""
64
+
65
+ try:
66
+ docs = retriever.get_relevant_documents(topic)
67
+ if not docs:
68
+ yield "⚠️ 没有找到相关内容,请换个关键词试试。", ""
69
+ return
70
+
71
+ snippet = "\n".join(d.page_content for d in docs)
72
+ prompt = (
73
+ f"根据以下内容,为“{topic}”生成大学本科层次的结构化学习大纲,格式示例:\n"
74
+ f"一、章节标题\n 1. 节标题\n (1)要点描述\n...\n\n"
75
+ f"文档内容:\n{snippet}\n\n学习大纲:"
76
+ )
77
+ raw = gen_pipe(prompt, max_new_tokens=512)[0]["generated_text"]
78
+ outline = raw.split("学习大纲:")[-1].strip()
79
+ yield outline, snippet
80
+ except Exception as e:
81
+ logging.exception("大纲生成失败:%s", e)
82
+ yield "⚠️ 抱歉,生成失败,请稍后再试。", ""
83
+
84
+ def placeholder_fn(*args, **kwargs):
85
+ return "功能尚未实现,请等待后续更新。"
86
+
87
+ # 5. Gradio UI
88
+ with gr.Blocks(title="智能学习助手", theme=gr.themes.Base()) as demo:
89
+ gr.Markdown("# 📚 智能学习助手 v2.0\n— 专业课向量问答与大纲生成 —")
90
+
91
+ with gr.Tabs():
92
+ # Chat tab
93
+ with gr.TabItem("💬 智能问答"):
94
+ chatbot = gr.Chatbot(show_label=False, height=400)
95
+ user_msg = gr.Textbox(placeholder="输入学习问题", show_label=False)
96
+ send_btn = gr.Button("发送", variant="primary")
97
+
98
+ def chat_flow(message, history):
99
+ history.append((message, "🤔 正在思考中,请稍后…"))
100
+ yield "", history
101
+ ans = simple_qa(message)
102
+ history[-1] = (message, ans)
103
+ yield "", history
104
+
105
+ send_btn.click(chat_flow, [user_msg, chatbot], [user_msg, chatbot])
106
+ user_msg.submit(chat_flow, [user_msg, chatbot], [user_msg, chatbot])
107
+
108
+ # Outline tab
109
+ with gr.TabItem("📝 生成学习大纲"):
110
+ topic_in = gr.Textbox(label="章节主题", placeholder="例如:定积分")
111
+ outline_out = gr.Textbox(label="系统生成的大纲", lines=12)
112
+ snippet_out = gr.Textbox(label="[调试] 检索片段", lines=6, visible=False)
113
+ gen_btn = gr.Button("生成大纲", variant="primary")
114
+ gen_btn.click(generate_outline, inputs=topic_in, outputs=[outline_out, snippet_out])
115
+
116
+ # Placeholder tabs
117
+ with gr.TabItem("❓ 自动出题"):
118
+ gr.Textbox(label="知识点").render()
119
+ gr.Dropdown(["简单", "中等", "困难"], label="难度").render()
120
+ gr.Slider(1, 10, step=1, label="题目数量").render()
121
+ gr.Button("开始出题").click(placeholder_fn, [], [])
122
+
123
+ with gr.TabItem("✅ 答案批改"):
124
+ gr.Textbox(label="标准答案", lines=4).render()
125
+ gr.Textbox(label="学生答案", lines=4).render()
126
+ gr.Button("开始批改").click(placeholder_fn, [], [])
127
+
128
+ gr.Markdown("---\n模型:Phi-2 + 向量库检索 | Powered by Hugging Face Spaces")
129
+
130
+ if __name__ == "__main__":
131
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ langchain>=0.2.9
2
+ langchain-huggingface>=0.0.6
3
+ langchain-chroma>=0.0.6
4
+ chromadb>=0.4.24
5
+ transformers>=4.40.0
6
+ accelerate
7
+ torch>=2.1.0
8
+ gradio>=4.24.0
9
+ sentence-transformers