import os
import gradio as gr
import torch  # needed to pick dtype/device for the causal LM below
import logging

# New-style LangChain imports (post-0.2.x package split)
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
# Note: since 0.2.x, importing HuggingFacePipeline from langchain_huggingface is
# recommended; if that causes problems, try langchain_community.llms instead.
from langchain_huggingface.llms import HuggingFacePipeline  # or: from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Transformers library
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

logging.basicConfig(level=logging.INFO)  # INFO gives useful startup/progress logs

# Paths
# Make sure these match the folder names of your pre-downloaded models and vector store.
VECTOR_STORE_DIR = "./vector_store"
# Point the model names at the local pre-downloaded copies.
# Caveat: from_pretrained() expects a folder containing config.json directly; a raw
# Hugging Face cache folder ("models--org--name") keeps files under snapshots/<hash>/,
# so either point at that snapshot subfolder or pass the repo id with cache_dir.
MODEL_NAME = "./hf_models_cache/models--uer--gpt2-chinese-cluecorpussmall"
EMBEDDING_MODEL_NAME = "./hf_models_cache/models--sentence-transformers--paraphrase-multilingual-mpnet-base-v2"
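
# --- Optional helper: building the vector store (a minimal sketch, not called here) ---
# The app only *loads* ./vector_store; a store like it is typically built offline
# roughly as below. The corpus path "./docs" and the chunk sizes are illustrative
# assumptions, not part of the original app.
def build_vector_store(corpus_dir="./docs"):
    from langchain_community.document_loaders import DirectoryLoader, TextLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    # Load plain-text files and split them into overlapping chunks.
    docs = DirectoryLoader(corpus_dir, glob="**/*.txt", loader_cls=TextLoader).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(docs)

    # Embed the chunks and persist them so the app can reload the store later.
    emb = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    Chroma.from_documents(chunks, embedding=emb, persist_directory=VECTOR_STORE_DIR)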


# 1. Lightweight LLM (uer/gpt2-chinese-cluecorpussmall)
print("🔧 Loading generation model...")
try:
    # Make sure tokenizer and model are loaded from the correct local path.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
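    # Note: device_map="auto" requires the accelerate package; without a GPU the
    # model simply stays on CPU in float32 (per the torch_dtype selection above).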
    gen_pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        temperature=0.5,
        top_p=0.9,
        do_sample=True,
    )
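    # Sampling settings: do_sample with moderate temperature/top_p trades a bit
    # of determinism for less repetitive GPT-2 output.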
    llm = HuggingFacePipeline(pipeline=gen_pipe)
    print("✅ 生成模型加载成功。")
except Exception as e:
    logging.error(f"加载生成模型失败: {e}", exc_info=True)
    # 如果加载失败,可以考虑在这里退出或者给一个友好的错误提示
    llm = None # 确保 llm 为 None,避免后续报错
    print("❌ 生成模型加载失败,应用可能无法正常工作。")


# 2. Vector store and embedding model
print("📚 Loading vector store and embedding model...")
try:
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME  # local pre-downloaded embedding model path
    )
    vectordb = Chroma(persist_directory=VECTOR_STORE_DIR, embedding_function=embeddings)
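    # The persisted store must have been built with the same embedding model,
    # otherwise similarity search returns meaningless neighbors.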
    print("✅ 向量库加载成功。")
except Exception as e:
    logging.error(f"加载向量库失败: {e}", exc_info=True)
    vectordb = None # 确保 vectordb 为 None
    print("❌ 向量库加载失败,RAG功能将无法使用。")


# 3. RAG question-answering chain
qa_chain = None
if llm and vectordb:  # only build the RAG chain if both the LLM and vector store loaded
    try:
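        # k=3: fetch the three most similar chunks for each query.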
        retriever = vectordb.as_retriever(search_kwargs={"k": 3})
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )
        print("✅ RAG问答链构建成功。")
    except Exception as e:
        logging.error(f"构建RAG问答链失败: {e}", exc_info=True)
        print("❌ RAG问答链构建失败。")


# 4. Business logic
def qa_fn(query):
    if not query.strip():
        return "❌ Please enter a question."
    if not qa_chain:  # check that the QA chain was actually built
        return "⚠️ The QA system is not fully loaded; try again later or check the logs."
    try:
        result = qa_chain.invoke({"query": query})  # .invoke() replaces the deprecated __call__ usage
        answer = result["result"]
        sources = result.get("source_documents", [])
        sources_text = "\n\n".join(
            [f"[Chunk {i+1}]\n" + doc.page_content for i, doc in enumerate(sources)]
        )
        return f"📌 Answer: {answer.strip()}\n\n📚 Sources:\n{sources_text}"
    except Exception as e:
        logging.exception("QA failed: %s", e)
        return f"❌ An error occurred: {str(e)}\nCheck the logs for more details."

# 5. Gradio UI
with gr.Blocks(title="Math Knowledge QA Assistant", theme=gr.themes.Base()) as demo:
    gr.Markdown("## 📘 Math Knowledge QA Assistant\nAsk a textbook question, e.g. “什么是函数的定义域?” (What is the domain of a function?)")
    with gr.Row():
        query_input = gr.Textbox(label="Question", placeholder="Type your question", lines=2)
        output_box = gr.Textbox(label="Answer", lines=15)
    submit_btn = gr.Button("Ask")

    submit_btn.click(fn=qa_fn, inputs=query_input, outputs=output_box)

    gr.Markdown("---\nModel: uer/gpt2-chinese-cluecorpussmall + Chroma RAG | Powered by Hugging Face Spaces")


if __name__ == "__main__":
    demo.launch()
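
# Quick smoke test without the UI (assumes this file is saved as app.py):
#   python -c "import app; print(app.qa_fn('什么是函数的定义域?'))"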