File size: 3,998 Bytes
cfa0432
1831b73
4e58501
e7e03e0
 
4e58501
e7e03e0
b7a5a02
a7b5cc2
1831b73
e7e03e0
 
 
 
4e58501
1831b73
cfa0432
a7b5cc2
 
 
 
e7e03e0
 
cfa0432
e7e03e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e58501
e7e03e0
 
 
 
 
 
4e58501
e7e03e0
 
 
 
 
4e58501
e7e03e0
 
 
 
4e58501
e7e03e0
 
 
 
 
 
 
 
 
 
 
 
 
 
cfa0432
 
 
4e58501
e7e03e0
1831b73
cfa0432
 
 
 
 
1831b73
cfa0432
1831b73
e7e03e0
 
1831b73
e7e03e0
 
cfa0432
 
 
 
 
1831b73
cfa0432
1831b73
e7e03e0
 
 
 
1831b73
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import gradio as gr
import torch
import logging

# LangChain 0.1.x 系列的导入方式
from langchain_chroma import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Transformers 库
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Models are referenced by their Hugging Face Hub IDs rather than local cache
# paths; only the ChromaDB persistence directory lives on local disk.
VECTOR_STORE_DIR = "./vector_store"  # ChromaDB persistence directory
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"  # Hub ID of the generation model
EMBEDDING_MODEL_NAME = (
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)  # Hub ID of the embedding model used for retrieval

# Configure the root logger at INFO so this module's logging calls are emitted.
logging.basicConfig(level=logging.INFO)

# 1. Lightweight generation LLM loaded from MODEL_NAME and wrapped for LangChain.
#    On failure `llm` is left as None so later stages can detect it.
print("🔧 加载生成模型...")
try:
    use_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        # float16 on GPU, float32 on CPU.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        # device_map="auto" depends on the `accelerate` package; on a CPU-only
        # host it adds nothing, so skip it there rather than fail to load when
        # accelerate is not installed.
        device_map="auto" if use_cuda else None,
    )
    gen_pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        temperature=0.5,
        top_p=0.9,
        do_sample=True,
        # Return only the generated continuation. With the default (full text)
        # the whole RAG prompt (retrieved context + question) can be echoed
        # back as part of the answer.
        return_full_text=False,
    )
    llm = HuggingFacePipeline(pipeline=gen_pipe)
    print("✅ 生成模型加载成功。")
except Exception as e:
    logging.error(f"加载生成模型失败: {e}", exc_info=True)
    llm = None
    print("❌ 生成模型加载失败,应用可能无法正常工作。")

# 2. Vector store and embedding model. On failure `vectordb` is None, which
#    disables the RAG chain below.
print("📚 加载向量库和嵌入模型...")
try:
    # Chroma creates a brand-new, empty store when persist_directory does not
    # exist, which would report success while every retrieval returns nothing.
    # Treat a missing directory as a load failure instead (checked before the
    # embedding model is downloaded).
    if not os.path.isdir(VECTOR_STORE_DIR):
        raise FileNotFoundError(f"向量库目录不存在: {VECTOR_STORE_DIR}")
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME
    )
    vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
    print("✅ 向量库加载成功。")
except Exception as e:
    logging.error(f"加载向量库失败: {e}", exc_info=True)
    vectordb = None
    print("❌ 向量库加载失败,RAG功能将无法使用。")

# 3. RetrievalQA chain. Built only when both the LLM and the vector store
#    loaded; otherwise (or if construction fails) `qa_chain` stays None.
qa_chain = None
if llm and vectordb:
    top_k = {"k": 3}  # number of passages retrieved per question
    try:
        retriever = vectordb.as_retriever(search_kwargs=top_k)
        chain_options = dict(
            llm=llm,
            chain_type="stuff",  # concatenate all retrieved passages into one prompt
            retriever=retriever,
            return_source_documents=True,  # qa_fn displays these as references
        )
        qa_chain = RetrievalQA.from_chain_type(**chain_options)
        print("✅ RAG问答链构建成功。")
    except Exception as err:
        logging.error(f"构建RAG问答链失败: {err}", exc_info=True)
        print("❌ RAG问答链构建失败。")

# 4. Request handler wired to the Gradio UI.
def qa_fn(query):
    """Answer ``query`` with the RAG chain and append the retrieved passages.

    Args:
        query: Question text from the input textbox. ``None`` or
            whitespace-only input is rejected with a prompt message.

    Returns:
        A display string: the answer followed by the numbered source snippets,
        or a user-facing warning/error message. Chain failures are logged and
        reported in the returned text rather than raised.
    """
    # Treat None the same as blank input instead of crashing on None.strip().
    question = (query or "").strip()
    if not question:
        return "❌ 请输入问题内容。"
    if not qa_chain:
        return "⚠️ 问答系统未完全加载,请稍后再试或检查日志。"
    try:
        # `invoke` is the supported entry point; calling the chain object
        # directly (Chain.__call__) is deprecated in LangChain.
        result = qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = result.get("source_documents") or []
        sources_text = "\n\n".join(
            f"【片段 {idx}】\n{doc.page_content}"
            for idx, doc in enumerate(sources, start=1)
        )
        return f"📌 回答:{answer.strip()}\n\n📚 参考:\n{sources_text}"
    except Exception as e:
        logging.exception("问答失败:%s", e)
        return f"❌ 出现错误:{e}\n请检查日志获取更多信息。"

# 5. Gradio UI: a question box and an answer box side by side, plus a submit
#    button that routes the question through qa_fn.
with gr.Blocks(title="数学知识问答助手", theme=gr.themes.Base()) as demo:
    gr.Markdown("## 📘 数学知识问答助手\n输入教材相关问题,例如:“什么是函数的定义域?”")
    with gr.Row():
        query_input = gr.Textbox(label="问题", placeholder="请输入你的问题", lines=2)
        output_box = gr.Textbox(label="回答", lines=15)
    submit_btn = gr.Button("提问")

    submit_btn.click(fn=qa_fn, inputs=query_input, outputs=output_box)

    # Build the footer from MODEL_NAME so it cannot drift from the model loaded.
    gr.Markdown(f"---\n模型:{MODEL_NAME} + Chroma RAG | Powered by Hugging Face Spaces")

if __name__ == "__main__":
    # Launch the Gradio server only when executed as a script.
    demo.launch()