# Requires: gradio, torch, transformers, accelerate, langchain,
# langchain-community, langchain-chroma, sentence-transformers
import logging

import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings  # ← updated import path (newer releases move this to langchain_huggingface)
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

logging.basicConfig(level=logging.INFO)

VECTOR_STORE_DIR = "./vector_store"
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# ─── 1. Load the LLM ───
print("🔧 Loading generation model…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
gen_pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.5,
    top_p=0.9,
    do_sample=True,
)
llm = HuggingFacePipeline(pipeline=gen_pipe)

# ─── 2. Load the vector store ───
print("📚 Loading vector store…")
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)

# ─── 3. Build the RAG QA chain ───
# GPT-2 has a 1024-token context window, so keep k and max_new_tokens modest.
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",  # concatenate all retrieved chunks into a single prompt
    retriever=retriever,
    return_source_documents=True,
)

# ─── 4. QA handler ───
def qa_fn(query: str) -> str:
    if not query.strip():
        return "❌ 请输入问题内容。"
    result = qa_chain.invoke({"query": query})
    answer = result["result"]
    sources = result.get("source_documents", [])
    sources_text = "\n\n".join(
        [f"【片段 {i+1}】\n{doc.page_content}" for i, doc in enumerate(sources)]
    )
    return f"📌 回答:{answer.strip()}\n\n📚 参考:\n{sources_text}"

# ─── 5. Gradio UI ───
with gr.Blocks(title="数学知识问答助手") as demo:
    gr.Markdown("## 📘 数学知识问答助手\n输入教材相关问题,例如:“什么是函数的定义域?”")
    with gr.Row():
        query = gr.Textbox(label="问题", placeholder="请输入你的问题", lines=2)
        answer = gr.Textbox(label="回答", lines=15)
    gr.Button("提问").click(qa_fn, inputs=query, outputs=answer)
    gr.Markdown("---\n模型:uer/gpt2-chinese-cluecorpussmall + Chroma RAG\nPowered by Hugging Face Spaces")

if __name__ == "__main__":
    demo.launch()
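
# ─── Appendix: one-off ingestion helper (not called by the app) ───
# A minimal sketch of how ./vector_store could be built before launching this
# app. The ./docs directory, the *.txt glob, and the chunking parameters are
# illustrative assumptions rather than part of the original project; adjust
# them to your corpus. It reuses VECTOR_STORE_DIR and the same embedding model
# so the persisted Chroma index matches what the retriever above expects.
def build_vector_store(docs_dir: str = "./docs") -> None:
    from langchain_community.document_loaders import DirectoryLoader, TextLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    # Load plain-text sources and split them into overlapping chunks.
    docs = DirectoryLoader(
        docs_dir,
        glob="**/*.txt",
        loader_cls=TextLoader,
        loader_kwargs={"encoding": "utf-8"},
    ).load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50
    ).split_documents(docs)
    # Embed and persist with the same model the app uses for retrieval.
    Chroma.from_documents(
        documents=chunks,
        embedding=HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME),
        persist_directory=VECTOR_STORE_DIR,
    )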