# app.py ── 2025-06-08, adapted for a Hugging Face CPU Space
import os, logging, gradio as gr
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
logging.basicConfig(level=logging.INFO)
# ========= 1. Load the local vector store =========
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
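# NOTE: this must be the same embedding model that build_vector_store.py used;
# querying a Chroma collection with a different embedder yields meaningless
# similarity scores.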
VEC_DIR = "vector_store"
if not (os.path.isdir(VEC_DIR) and os.path.isfile(f"{VEC_DIR}/chroma.sqlite3")):
    raise RuntimeError(f"❌ Vector store {VEC_DIR} is missing or incomplete; run build_vector_store.py first")
vectordb = Chroma(persist_directory=VEC_DIR, embedding_function=embedder)
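# Chroma reopens the persisted collection as-is; no re-indexing happens here.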
# ========= 2. Load a lightweight LLM =========
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B params, feasible on CPU
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # requires `accelerate` in requirements.txt
    torch_dtype="auto",
)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,  # without this, transformers ignores temperature/top_p
    temperature=0.7,
    top_p=0.9,
)
llm = HuggingFacePipeline(pipeline=generator)
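# Optional smoke test (uncomment to verify generation works at startup):
# logging.info(llm.invoke("Say hello in one sentence."))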
# ========= 3. Build the RAG QA chain =========
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",  # "stuff" packs all retrieved chunks into one prompt
    retriever=vectordb.as_retriever(search_kwargs={"k": 3}),
)
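# NOTE: RetrievalQA ships with a default English prompt. For a non-English
# corpus you can supply your own PromptTemplate via
# chain_type_kwargs={"prompt": ...} if answers come back in the wrong language.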
# ========= 4. Business logic =========
def simple_qa(user_query: str):
    if not user_query.strip():
        return "⚠️ Please enter a study question, e.g.: What is a definite integral?"
    try:
        # RetrievalQA.invoke returns {"query": ..., "result": ...}
        return qa_chain.invoke({"query": user_query})["result"]
    except Exception as e:
        logging.error(f"[QA ERROR] {e}")
        return f"⚠️ QA failed: {e}"
def generate_outline(topic: str):
    if not topic.strip():
        return "⚠️ Please enter a chapter or topic", ""
    try:
        docs = vectordb.as_retriever(search_kwargs={"k": 3}).invoke(topic)
        snippet = "\n---\n".join(d.page_content for d in docs)
        prompt = (
            f"Based on the material below, produce a structured study outline for \"{topic}\", in the format:\n"
            f"I. Chapter title\n  1. Section title\n    (1) Key point…\n\n"
            f"Material:\n{snippet}\n\nOutline:"
        )
        outline = llm.invoke(prompt).strip()
        return outline, snippet
    except Exception as e:
        logging.error(f"[OUTLINE ERROR] {e}")
        return "⚠️ Generation failed", ""
def placeholder(*_):
    return "Feature under development…"
# ========= 5. Gradio UI =========
with gr.Blocks(title="Smart Learning Assistant") as demo:
    gr.Markdown("# Smart Learning Assistant v2.0\n💡 A RAG demo for university course material")
    with gr.Tabs():
        with gr.TabItem("Q&A"):
            chatbot = gr.Chatbot(height=350)
            msg = gr.Textbox(placeholder="Ask a question here…")
            def chat(m, hist):
                ans = simple_qa(m)
                hist.append((m, ans))
                return "", hist
            msg.submit(chat, [msg, chatbot], [msg, chatbot])
with gr.TabItem("生成学习大纲"): | |
topic = gr.Textbox(label="章节主题", placeholder="高等数学 第六章 定积分") | |
outline = gr.Textbox(label="学习大纲", lines=12) | |
debug = gr.Textbox(label="调试:检索片段", lines=6) | |
gen = gr.Button("生成") | |
gen.click(generate_outline, [topic], [outline, debug]) | |
with gr.TabItem("自动出题"): | |
placeholder(label="待开发") | |
with gr.TabItem("答案批改"): | |
placeholder(label="待开发") | |
gr.Markdown("---\nPowered by LangChain • TinyLlama • Chroma") | |
if __name__ == "__main__": | |
demo.launch() | |