File size: 3,871 Bytes
d4ceb14
 
536e921
2999a2c
536e921
 
 
 
 
 
d4ceb14
 
536e921
 
d4ceb14
 
 
 
 
536e921
d4ceb14
 
536e921
 
 
d4ceb14
536e921
 
d4ceb14
536e921
 
 
d4ceb14
536e921
 
 
d4ceb14
536e921
d4ceb14
536e921
 
 
d4ceb14
536e921
 
d4ceb14
 
536e921
 
 
d4ceb14
536e921
d4ceb14
 
ac710d7
14517da
 
d4ceb14
14517da
d4ceb14
 
14517da
d4ceb14
 
 
14517da
d4ceb14
 
14517da
d4ceb14
 
14517da
d4ceb14
 
ac710d7
d4ceb14
 
 
ac710d7
 
d4ceb14
 
 
 
 
 
 
14517da
ac710d7
d4ceb14
 
 
 
 
ac710d7
 
d4ceb14
ac710d7
 
d4ceb14
536e921
d4ceb14
ac710d7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# app.py  ── 2025-06-08 适配 HuggingFace CPU Space
import os, logging, gradio as gr
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline

logging.basicConfig(level=logging.INFO)

# ========= 1. Load the local vector store =========
# Multilingual sentence embedder used for query-time embedding.
# NOTE(review): presumably the same model was used when the store was
# built by build_vector_store.py — confirm, or retrieval quality degrades.
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
VEC_DIR = "vector_store"
# Fail fast at import time if the persisted Chroma store is missing or
# incomplete (chroma.sqlite3 is the store's main database file).
if not (os.path.isdir(VEC_DIR) and os.path.isfile(f"{VEC_DIR}/chroma.sqlite3")):
    raise RuntimeError(f"❌ 未找到完整向量库 {VEC_DIR},请先执行 build_vector_store.py")

vectordb = Chroma(persist_directory=VEC_DIR, embedding_function=embedder)

# ========= 2. Load a lightweight LLM =========
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"     # 1.1B params — small enough for CPU
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # requires `accelerate` in requirements
    torch_dtype="auto",
)
# HF text-generation pipeline; wrapped below so LangChain can drive it.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.9,
)
llm = HuggingFacePipeline(pipeline=generator)

# ========= 3. Build the RAG QA chain =========
# "stuff" chain type: the top-k retrieved chunks are concatenated
# directly into a single prompt for the LLM.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vectordb.as_retriever(search_kwargs={"k": 3}),
)

# ========= 4. Business-logic functions =========
def simple_qa(user_query: str):
    """Answer a study question through the RAG chain.

    Always returns a user-facing string: the chain's answer on success,
    or a warning/error message otherwise (exceptions are logged, never
    propagated to the UI).
    """
    if not user_query.strip():
        return "⚠️ 请输入学习问题,例如:什么是定积分?"
    try:
        answer = qa_chain.run(user_query)
    except Exception as e:  # best-effort: surface the failure as text
        logging.error(f"[QA ERROR] {e}")
        return f"⚠️ 问答失败:{e}"
    return answer

def generate_outline(topic: str):
    """Generate a structured study outline for *topic*.

    Returns a ``(outline, retrieved_snippets)`` pair of strings; on any
    failure the error is logged and a warning pair is returned instead.
    """
    if not topic.strip():
        return "⚠️ 请输入章节或主题", ""
    try:
        retriever = vectordb.as_retriever(search_kwargs={"k": 3})
        docs = retriever.get_relevant_documents(topic)
        # Join the retrieved chunks so they can be shown in the debug box
        # and embedded into the generation prompt.
        snippet = "\n---\n".join([d.page_content for d in docs])
        prompt = (
            f"请基于以下资料,为「{topic}」生成结构化学习大纲,格式:\n"
            f"一、章节标题\n  1. 节标题\n    (1)要点…\n\n"
            f"资料:\n{snippet}\n\n大纲:"
        )
        return llm.invoke(prompt).strip(), snippet
    except Exception as e:  # keep the UI alive; report failure as text
        logging.error(f"[OUTLINE ERROR] {e}")
        return "⚠️ 生成失败", ""

def placeholder(*_args, **_kwargs):
    """Stub for not-yet-implemented features; ignores all arguments.

    BUG FIX: the original signature ``(*_)`` rejected keyword arguments,
    so the UI calls ``placeholder(label="待开发")`` raised ``TypeError``
    at app startup. Accepting ``**_kwargs`` keeps every existing call
    site working while remaining backward-compatible.
    """
    return "功能待开发…"

# ========= 5. Gradio UI =========
with gr.Blocks(title="智能学习助手") as demo:
    gr.Markdown("# 智能学习助手 v2.0\n💡 大学生专业课 RAG Demo")
    with gr.Tabs():
        with gr.TabItem("智能问答"):
            chatbot = gr.Chatbot(height=350)
            msg = gr.Textbox(placeholder="在此提问…")

            def chat(m, hist):
                # Append the (question, answer) pair and clear the input box.
                ans = simple_qa(m)
                hist.append((m, ans))
                return "", hist

            msg.submit(chat, [msg, chatbot], [msg, chatbot])

        with gr.TabItem("生成学习大纲"):
            topic = gr.Textbox(label="章节主题", placeholder="高等数学 第六章 定积分")
            outline = gr.Textbox(label="学习大纲", lines=12)
            debug = gr.Textbox(label="调试:检索片段", lines=6)
            gen = gr.Button("生成")
            gen.click(generate_outline, [topic], [outline, debug])

        # BUG FIX: the original called `placeholder(label="待开发")`, which
        # raised TypeError (the stub takes no keyword arguments) and, even
        # if it hadn't, its return value was discarded so the tab rendered
        # empty. Render the stub's notice as actual tab content instead.
        with gr.TabItem("自动出题"):
            gr.Markdown(placeholder())

        with gr.TabItem("答案批改"):
            gr.Markdown(placeholder())

    gr.Markdown("---\nPowered by LangChain • TinyLlama • Chroma")

# Script entry point: start the Gradio server.
if __name__ == "__main__":
    demo.launch()