import os
import gradio as gr
import torch
import logging
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
logging.basicConfig(level=logging.INFO)
# ─── Configuration ──────────────────────────────────────────────
VECTOR_STORE_DIR = "./vector_store"
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
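# Allow the store location to be overridden at deploy time (an added
# convenience using the already-imported os module; the VECTOR_STORE_DIR
# env var name is an assumption, not part of the original app).
VECTOR_STORE_DIR = os.environ.get("VECTOR_STORE_DIR", VECTOR_STORE_DIR)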
# ─── 1. Load the LLM ────────────────────────────────────────────
print("🔧 Loading generation model…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",  # requires the `accelerate` package
)
gen_pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.5,
    top_p=0.9,
    do_sample=True,
    return_full_text=False,  # return only the completion, not the echoed prompt
)
llm = HuggingFacePipeline(pipeline=gen_pipe)
print("✅ Generation model loaded.")
# ─── 2. Load the vector store ───────────────────────────────────
print("📚 Loading vector store…")
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
print("✅ Vector store loaded.")
# ─── 3. Custom prompt ───────────────────────────────────────────
# The prompt stays in Chinese because the underlying LLM is a Chinese model.
# Roughly: "You are a professional math teaching assistant. Answer only from
# the reference material below; if it is not covered, say 'I don't know' or
# 'not mentioned in the material' instead of inventing an answer."
prompt_template = PromptTemplate.from_template(
    """你是一位专业的数学助教,请根据以下参考资料回答用户的问题。
如果资料中没有相关内容,请直接回答“我不知道”或“资料中未提及”,不要编造答案。

参考资料:
{context}

用户问题:
{question}

回答(只允许基于参考资料,不要编造):
"""
)
# ─── 4. Build the RAG QA chain ──────────────────────────────────
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",  # stuff all retrieved chunks into a single prompt
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt_template},
    return_source_documents=True,
)
print("✅ RAG QA chain built.")
# ─── 5. Business logic ──────────────────────────────────────────
def qa_fn(query: str) -> str:
    if not query.strip():
        return "❌ 请输入问题内容。"
    # Run retrieval, then generation over the stuffed context.
    result = qa_chain.invoke({"query": query})
    answer = result["result"].strip()
    sources = result.get("source_documents", [])
    if not sources:
        return "📌 回答:未在知识库中找到相关内容,请尝试更换问题或补充教材。"
    # Concatenate the retrieved reference snippets for display.
    sources_text = "\n\n".join(
        [f"【片段 {i+1}】\n{doc.page_content}" for i, doc in enumerate(sources)]
    )
    return f"📌 回答:{answer}\n\n📚 参考:\n{sources_text}"
# ─── 6. Gradio UI ───────────────────────────────────────────────
with gr.Blocks(title="智能学习助手") as demo:
    gr.Markdown("## 📘 智能学习助手\n输入教材相关问题,例如:“什么是函数的定义域?”")
    with gr.Row():
        query = gr.Textbox(label="问题", placeholder="请输入你的问题", lines=2)
        answer = gr.Textbox(label="回答", lines=12)
    gr.Button("提问").click(fn=qa_fn, inputs=query, outputs=answer)
    gr.Markdown(
        "---\n"
        "模型:uer/gpt2-chinese-cluecorpussmall + Sentence-Transformers RAG  \n"
        "由 Hugging Face Spaces 提供算力支持"
    )
if __name__ == "__main__":
    demo.launch()