File size: 4,741 Bytes
bee6ae5
 
 
 
 
 
 
a7b5cc2
1831b73
bee6ae5
e7e03e0
 
d56ced4
 
4e58501
1831b73
bee6ae5
c126f3f
 
 
e7e03e0
d56ced4
 
 
 
 
bee6ae5
d56ced4
bee6ae5
c126f3f
 
 
 
 
 
 
 
 
 
 
 
 
d56ced4
c126f3f
 
d56ced4
e7e03e0
bee6ae5
d56ced4
c126f3f
 
 
d56ced4
bee6ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c126f3f
 
 
 
bee6ae5
c126f3f
 
d56ced4
e7e03e0
bee6ae5
c126f3f
cfa0432
 
c126f3f
bee6ae5
c126f3f
bee6ae5
 
c126f3f
 
 
bee6ae5
1831b73
bee6ae5
 
 
cfa0432
c126f3f
bee6ae5
 
 
 
 
 
 
e7e03e0
 
 
1831b73
c126f3f
bee6ae5
d56ced4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import gradio as gr
import torch
import logging

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from build_index import main as build_index_if_needed  # 确保提交了 build_index.py

logging.basicConfig(level=logging.INFO)

# ─── Configuration ──────────────────────────────────────────────
VECTOR_STORE_DIR = "./vector_store"  # persisted Chroma directory
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"  # causal LM used for answering
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"


def _vector_store_missing(path: str) -> bool:
    """Return True when *path* does not exist or has no entries in it."""
    return not os.path.exists(path) or not os.listdir(path)


# First start (or an emptied store): build the index before anything loads it.
if _vector_store_missing(VECTOR_STORE_DIR):
    logging.info("向量库不存在,启动自动构建……")
    build_index_if_needed()

# ─── 1. Load the generation model ───────────────────────────────
logging.info("🔧 加载生成模型…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # Half precision only when a CUDA device is available; CPU stays fp32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
# NOTE(review): GPT-2 style models have a fixed context window; prompt +
# 3 retrieved chunks + 256 new tokens may exceed it — confirm chunk sizes
# produced by build_index.py.
gen_pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.5,
    top_p=0.9,
    do_sample=True,
    # Return only the newly generated continuation. Otherwise the pipeline
    # echoes the whole prompt (instructions + retrieved context), and
    # HuggingFacePipeline has to cut it off itself. That is fragile if the
    # tokenizer's decode does not reproduce the prompt text exactly (e.g. it
    # may insert spaces between CJK characters).
    return_full_text=False,
    trust_remote_code=True,
)
llm = HuggingFacePipeline(pipeline=gen_pipe)
logging.info("✅ 生成模型加载成功。")

# ─── 2. Load the persisted vector store ─────────────────────────
logging.info("📚 加载向量库…")
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
# NOTE(review): this query-time embedding model presumably must match the
# one build_index.py used when writing the store — confirm.
vectordb = Chroma(
    embedding_function=embeddings,
    persist_directory=VECTOR_STORE_DIR,
)
# Top-3 chunks per query (k=3).
retriever = vectordb.as_retriever(search_kwargs=dict(k=3))
logging.info("✅ 向量库加载成功。")

# ─── 3. Custom prompt ───────────────────────────────────────────
# Placeholders: {context} receives the retrieved passages and {question}
# the user's query (expected to be filled by RetrievalQA's "stuff" chain).
QA_PROMPT_TEXT = """你是一位专业的数学助教,请根据以下参考资料回答用户的问题。
如果资料中没有相关内容,请直接回答“我不知道”或“资料中未提及”,不要编造答案。
参考资料:
{context}

用户问题:
{question}

回答(只允许基于参考资料,不要编造):
"""
prompt_template = PromptTemplate.from_template(QA_PROMPT_TEXT)

# ─── 4. Assemble the RAG question-answering chain ───────────────
# "stuff" places all retrieved chunks into one prompt; source documents are
# kept in the output so callers can show the supporting passages.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs=dict(prompt=prompt_template),
    return_source_documents=True,
)
logging.info("✅ RAG 问答链构建成功。")

# ─── 5. Request handler ─────────────────────────────────────────
def qa_fn(query: str) -> str:
    """Answer *query* with the RAG chain and append the retrieved passages.

    Args:
        query: Question text from the UI. ``None`` or whitespace-only input
            is treated as empty.

    Returns:
        A display string. It is one of:
        - a prompt to enter a question (empty input);
        - an error notice (the chain raised);
        - a "not found" notice (no passages were retrieved);
        - the answer followed by the numbered source passages.
    """
    if not query or not query.strip():
        return "❌ 请输入问题内容。"
    try:
        # ``invoke`` is the supported entry point; ``chain(...)`` is deprecated.
        result = qa_chain.invoke({"query": query})
    except Exception:
        # UI boundary: keep the traceback in the logs, show a short message.
        logging.exception("RAG chain failed for query %r", query)
        return "❌ 生成回答时出错,请稍后重试。"
    answer = result["result"].strip()
    sources = result.get("source_documents") or []
    if not sources:
        return "📌 回答:未在知识库中找到相关内容,请尝试更换问题或补充教材。"
    sources_text = "\n\n".join(
        f"【片段 {i}】\n{doc.page_content}" for i, doc in enumerate(sources, start=1)
    )
    return f"📌 回答:{answer}\n\n📚 参考:\n{sources_text}"

# ─── 6. Gradio UI ───────────────────────────────────────────────
INTRO_MD = "## 📘 智能学习助手\n输入教材相关问题,例如:“什么是函数的定义域?”"
FOOTER_MD = (
    "---\n"
    "模型:UER/GPT2-Chinese-ClueCorpus + Sentence-Transformers RAG  \n"
    "由 Hugging Face Spaces 提供算力支持"
)

with gr.Blocks(title="智能学习助手") as demo:
    gr.Markdown(INTRO_MD)
    # Question input and answer output side by side.
    with gr.Row():
        query = gr.Textbox(label="问题", placeholder="请输入你的问题", lines=2)
        answer = gr.Textbox(label="回答", lines=12)
    ask_button = gr.Button("提问")
    ask_button.click(fn=qa_fn, inputs=query, outputs=answer)
    gr.Markdown(FOOTER_MD)


def _main() -> None:
    """Start the Gradio server for the app defined above."""
    demo.launch()


if __name__ == "__main__":
    _main()