# app.py

import os
import logging
import gradio as gr
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Embeddings and the vector store come from the new split-out packages
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

# The LLM still uses the Pipeline wrapper from the community package
from langchain_community.llms import HuggingFacePipeline

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

from build_index import main as build_index_if_needed  # build_index.py must live in the same directory as app.py

logging.basicConfig(level=logging.INFO)

# ─── Configuration ───────────────────────────────────────────
VECTOR_STORE_DIR        = "./vector_store"
MODEL_NAME              = "uer/gpt2-chinese-cluecorpussmall"
EMBEDDING_MODEL_NAME    = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
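# Note: EMBEDDING_MODEL_NAME must match the model build_index.py used to create the
# index; querying a Chroma store with a different embedding model returns meaningless
# nearest neighbours.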

# Build the vector store automatically at container startup (if the vector_store directory is empty)
if not os.path.exists(VECTOR_STORE_DIR) or not os.listdir(VECTOR_STORE_DIR):
    logging.info("Vector store not found; building it automatically…")
    build_index_if_needed()

# ─── 1. Load the generation model ──────────────────────────────
logging.info("🔧 Loading generation model…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
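# device_map="auto" requires the `accelerate` package; float16 is only used when a
# CUDA device is available, otherwise the model stays in float32 on CPU.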
gen_pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.5,
    top_p=0.9,
    do_sample=True,
    trust_remote_code=True,
)
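# temperature/top_p only take effect because do_sample=True; with sampling disabled
# the pipeline falls back to greedy decoding and ignores both parameters.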
llm = HuggingFacePipeline(pipeline=gen_pipe)
logging.info("✅ 生成模型加载成功。")

# ─── 2. Load the vector store ─────────────────────────────────
logging.info("📚 Loading vector store…")
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
vectordb   = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
retriever  = vectordb.as_retriever(search_kwargs={"k": 3})
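# k=3 returns the three nearest chunks per query; raising k widens the context the
# chain sees, at the cost of longer (and slower) map_reduce passes.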
logging.info("✅ 向量库加载成功。")

# ─── 3. Custom prompt ─────────────────────────────────────────
prompt_template = PromptTemplate.from_template(
    """You are a professional mathematics teaching assistant. Answer the user's question using the reference material below.
If the material does not cover the question, answer "I don't know" or "not mentioned in the material"; do not make up an answer.
Reference material:
{context}

User question:
{question}

Answer (based only on the reference material; do not fabricate):
"""
)

# ─── 4. Build the RAG QA chain (map_reduce) ────────────────────
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="map_reduce",               # map_reduce splits the context automatically, avoiding overlong prompts
    retriever=retriever,
    # Wire the custom prompt into the map step; without this the template above is never used.
    chain_type_kwargs={"question_prompt": prompt_template},
    return_source_documents=True,
)
logging.info("✅ RAG QA chain (map_reduce) built.")

# ─── 5. Business logic ─────────────────────────────────────────
def qa_fn(query: str):
    if not query or not query.strip():
        return "❌ 请输入问题内容。"
    try:
        result = qa_chain.invoke({"query": query})
    except Exception as e:
        logging.error(f"QA chain failed: {e}")
        return "Sorry, something went wrong while answering. Please try again later."

    answer  = result.get("result", "").strip()
    sources = result.get("source_documents", [])
    if not answer:
        return "📌 回答:未生成答案,请稍后再试。"
    if not sources:
        return f"📌 回答:{answer}\n\n(未检索到参考片段)"

    # Concatenate the retrieved reference chunks
    sources_text = "\n\n".join(
        [f"[Chunk {i+1}]\n{doc.page_content}" for i, doc in enumerate(sources)]
    )
    return f"📌 Answer: {answer}\n\n📚 References:\n{sources_text}"

# ─── 6. Gradio UI ─────────────────────────────────────────────
with gr.Blocks(title="Smart Study Assistant") as demo:
    gr.Markdown("## 📘 Smart Study Assistant\nAsk a textbook-related question, e.g. \"What is the domain of a function?\"")
    with gr.Row():
        query = gr.Textbox(label="Question", placeholder="Type your question here", lines=2)
        answer = gr.Textbox(label="Answer", lines=12)
    gr.Button("Ask").click(fn=qa_fn, inputs=query, outputs=answer)
    gr.Markdown(
        "---\n"
        "模型:UER/GPT2-Chinese-ClueCorpus + Sentence-Transformers RAG (map_reduce)  \n"
        "由 Hugging Face Spaces 提供算力支持"
    )

if __name__ == "__main__":
    demo.launch()
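    # Note: on a busy Space, demo.queue().launch() serializes requests so concurrent
    # users don't contend for the single model instance (optional).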