# app.py
import os
import logging
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Embeddings and the vector store come from the new split-out packages
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
# The LLM still uses the Pipeline wrapper from the community package
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from build_index import main as build_index_if_needed  # build_index.py must sit in the same directory as app.py
logging.basicConfig(level=logging.INFO)
# ─── Configuration ───────────────────────────────────────────
VECTOR_STORE_DIR = "./vector_store"
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
# Build the vector store automatically at container startup (if the vector_store directory is empty)
if not os.path.exists(VECTOR_STORE_DIR) or not os.listdir(VECTOR_STORE_DIR):
    logging.info("Vector store not found; building it now…")
    build_index_if_needed()
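# build_index_if_needed() is assumed to chunk the source documents, embed them with the
# same EMBEDDING_MODEL_NAME, and persist the result to VECTOR_STORE_DIR; query-time
# embeddings must come from the same model, or retrieval quality collapses.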
# ─── 1. Load the generation model ─────────────────────────────
logging.info("🔧 Loading generation model…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
gen_pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.5,
    top_p=0.9,
    do_sample=True,
    trust_remote_code=True,
)
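# Sampling notes: do_sample=True is what makes temperature/top_p take effect; at
# temperature 0.5 the model stays fairly close to the retrieved context.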
llm = HuggingFacePipeline(pipeline=gen_pipe)
logging.info("✅ 生成模型加载成功。")
# ─── 2. Load the vector store ─────────────────────────────────
logging.info("📚 Loading vector store…")
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
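# A possible variant (not used here): MMR retrieval trades a bit of raw relevance for
# diversity among the returned passages, which helps when the corpus has near-duplicates.
# retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 3, "fetch_k": 10})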
logging.info("✅ 向量库加载成功。")
# ─── 3. Custom prompt ─────────────────────────────────────────
prompt_template = PromptTemplate.from_template(
    """You are a professional math teaching assistant. Answer the user's question based on the reference material below.
If the material does not cover the question, reply "I don't know" or "not covered in the material"; do not make up an answer.
Reference material:
{context}
User question:
{question}
Answer (based strictly on the reference material; do not fabricate):
"""
)
# ─── 4. Build the RAG QA chain (map_reduce) ───────────────────
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="map_reduce",  # map_reduce processes each passage separately, avoiding overlong prompts
    retriever=retriever,
    return_source_documents=True,
    # Wire the custom prompt into the map step; without this, prompt_template is defined but never used
    chain_type_kwargs={"question_prompt": prompt_template},
)
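# A simpler variant (sketch, not used here): the "stuff" chain concatenates all k passages
# into a single prompt, which is fine only while they fit in the model's context window.
# qa_chain = RetrievalQA.from_chain_type(
#     llm=llm,
#     chain_type="stuff",
#     retriever=retriever,
#     return_source_documents=True,
#     chain_type_kwargs={"prompt": prompt_template},
# )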
logging.info("✅ RAG 问答链(map_reduce)构建成功。")
# ─── 5. Application logic ─────────────────────────────────────
def qa_fn(query: str):
    if not query or not query.strip():
        return "❌ Please enter a question."
    try:
        # .invoke() replaces the deprecated qa_chain({...}) call style
        result = qa_chain.invoke({"query": query})
    except Exception as e:
        logging.error(f"QA chain failed: {e}")
        return "Sorry, something went wrong while answering; please try again later."
    answer = result.get("result", "").strip()
    sources = result.get("source_documents", [])
    if not answer:
        return "📌 Answer: no answer was generated, please try again later."
    if not sources:
        return f"📌 Answer: {answer}\n\n(no reference passages retrieved)"
    # Concatenate the retrieved reference passages
    sources_text = "\n\n".join(
        [f"[Passage {i+1}]\n{doc.page_content}" for i, doc in enumerate(sources)]
    )
    return f"📌 Answer: {answer}\n\n📚 References:\n{sources_text}"
# ─── 6. Gradio UI ─────────────────────────────────────────────
with gr.Blocks(title="Smart Study Assistant") as demo:
    gr.Markdown("## 📘 Smart Study Assistant\nAsk a question about the textbook, e.g. \"What is the domain of a function?\"")
    with gr.Row():
        query = gr.Textbox(label="Question", placeholder="Type your question here", lines=2)
        answer = gr.Textbox(label="Answer", lines=12)
    gr.Button("Ask").click(fn=qa_fn, inputs=query, outputs=answer)
    gr.Markdown(
        "---\n"
        "Model: uer/gpt2-chinese-cluecorpussmall + Sentence-Transformers RAG (map_reduce) \n"
        "Compute provided by Hugging Face Spaces"
    )
if __name__ == "__main__":
    demo.launch()
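    # On Spaces the bare launch() is enough; for local debugging, something like
    # demo.launch(server_name="0.0.0.0", server_port=7860) exposes the app on the LAN.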