Spaces:
Sleeping
Sleeping
File size: 4,741 Bytes
bee6ae5 a7b5cc2 1831b73 bee6ae5 e7e03e0 d56ced4 4e58501 1831b73 bee6ae5 c126f3f e7e03e0 d56ced4 bee6ae5 d56ced4 bee6ae5 c126f3f d56ced4 c126f3f d56ced4 e7e03e0 bee6ae5 d56ced4 c126f3f d56ced4 bee6ae5 c126f3f bee6ae5 c126f3f d56ced4 e7e03e0 bee6ae5 c126f3f cfa0432 c126f3f bee6ae5 c126f3f bee6ae5 c126f3f bee6ae5 1831b73 bee6ae5 cfa0432 c126f3f bee6ae5 e7e03e0 1831b73 c126f3f bee6ae5 d56ced4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import os
import gradio as gr
import torch
import logging
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from build_index import main as build_index_if_needed # 确保提交了 build_index.py
# ─── Configuration ────────────────────────────────────────────
# Show INFO-level messages so the startup progress logs are visible.
logging.basicConfig(level=logging.INFO)

# Directory holding the persisted vector store.
VECTOR_STORE_DIR = "./vector_store"
# Hugging Face id of the causal language model used for generation.
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
# Hugging Face id of the sentence-embedding model.
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# Build the index automatically when the directory is missing or empty.
# The `and` short-circuits, so os.listdir only runs on an existing path.
if not (os.path.exists(VECTOR_STORE_DIR) and os.listdir(VECTOR_STORE_DIR)):
    logging.info("向量库不存在,启动自动构建……")
    build_index_if_needed()
# ─── 1. Load the LLM ──────────────────────────────────────────
logging.info("🔧 加载生成模型…")

# Half precision only when CUDA is available; full precision otherwise.
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# NOTE(review): trust_remote_code=True permits code shipped in the model
# repository to run at load time — confirm MODEL_NAME actually needs it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=_model_dtype,
)

# Sampling settings handed to the text-generation pipeline.
_generation_settings = {
    "max_new_tokens": 256,
    "temperature": 0.5,
    "top_p": 0.9,
    "do_sample": True,
}
gen_pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    **_generation_settings,
)

# Wrap the transformers pipeline so LangChain can use it as an LLM.
llm = HuggingFacePipeline(pipeline=gen_pipe)
logging.info("✅ 生成模型加载成功。")
# ─── 2. Load the vector store ─────────────────────────────────
logging.info("📚 加载向量库…")

# The store is opened from VECTOR_STORE_DIR with the embedding model named
# by EMBEDDING_MODEL_NAME.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
vectordb = Chroma(
    embedding_function=embeddings,
    persist_directory=VECTOR_STORE_DIR,
)

# Retriever configured with k=3 in its search arguments.
retriever = vectordb.as_retriever(search_kwargs=dict(k=3))
logging.info("✅ 向量库加载成功。")
# ─── 3. Custom prompt ─────────────────────────────────────────
# Template text with two placeholders, {context} and {question}. It tells the
# model to act as a math teaching assistant, to answer only from the supplied
# reference material, and to say it does not know when the material is silent.
_PROMPT_TEXT = """你是一位专业的数学助教,请根据以下参考资料回答用户的问题。
如果资料中没有相关内容,请直接回答“我不知道”或“资料中未提及”,不要编造答案。
参考资料:
{context}
用户问题:
{question}
回答(只允许基于参考资料,不要编造):
"""

prompt_template = PromptTemplate.from_template(_PROMPT_TEXT)
# ─── 4. Build the RAG question-answering chain ────────────────
# Options for the chain: the "stuff" chain type with the custom prompt, and
# source documents returned alongside the answer.
_chain_options = {
    "llm": llm,
    "retriever": retriever,
    "chain_type": "stuff",
    "chain_type_kwargs": {"prompt": prompt_template},
    "return_source_documents": True,
}
qa_chain = RetrievalQA.from_chain_type(**_chain_options)
logging.info("✅ RAG 问答链构建成功。")
# ─── 5. Business logic ────────────────────────────────────────
def qa_fn(query: str) -> str:
    """Answer a question with the RAG chain and list the retrieved passages.

    Args:
        query: Question text from the UI. ``None``, an empty string, or
            whitespace-only text is treated as "no question".

    Returns:
        A display string, one of:
        - a request to enter a question when the input is empty;
        - an error notice when the chain raises;
        - a "nothing found" notice when no source documents come back;
        - the answer followed by the numbered source passages.
    """
    # `not query` covers None as well as "", so .strip() is never called on
    # a missing value.
    if not query or not query.strip():
        return "❌ 请输入问题内容。"

    try:
        # Chain.__call__ is deprecated in LangChain; invoke() is the
        # supported entry point and returns the same output dict.
        result = qa_chain.invoke({"query": query})
    except Exception:
        # UI callback boundary: keep the traceback in the logs and show the
        # user a readable message instead of an unhandled error.
        logging.exception("RAG 问答链调用失败")
        return "❌ 回答生成失败,请稍后重试。"

    # Treat a missing or None entry the same as an empty list.
    sources = result.get("source_documents") or []
    if not sources:
        return "📌 回答:未在知识库中找到相关内容,请尝试更换问题或补充教材。"

    answer = (result.get("result") or "").strip()
    # Number the passages from 1 for display.
    sources_text = "\n\n".join(
        f"【片段 {i}】\n{doc.page_content}" for i, doc in enumerate(sources, start=1)
    )
    return f"📌 回答:{answer}\n\n📚 参考:\n{sources_text}"
# ─── 6. Gradio UI ─────────────────────────────────────────────
with gr.Blocks(title="智能学习助手") as demo:
    gr.Markdown("## 📘 智能学习助手\n输入教材相关问题,例如:“什么是函数的定义域?”")

    # NOTE(review): both textboxes are assumed to sit inside the same row —
    # confirm against the intended layout.
    with gr.Row():
        query = gr.Textbox(label="问题", placeholder="请输入你的问题", lines=2)
        answer = gr.Textbox(label="回答", lines=12)

    # Clicking the button passes the question text to qa_fn and writes the
    # returned string into the answer box.
    ask_button = gr.Button("提问")
    ask_button.click(fn=qa_fn, inputs=query, outputs=answer)

    # Footer text, built from adjacent string literals.
    footer_markdown = (
        "---\n"
        "模型:UER/GPT2-Chinese-ClueCorpus + Sentence-Transformers RAG \n"
        "由 Hugging Face Spaces 提供算力支持"
    )
    gr.Markdown(footer_markdown)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|