# (HuggingFace Hub file-viewer header captured along with the source — not part of app.py)
# ljy5946's picture · Update app.py · d4ceb14 verified · raw · history · blame · 3.87 kB
# app.py ── 2025-06-08 适配 HuggingFace CPU Space
import os, logging, gradio as gr
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
logging.basicConfig(level=logging.INFO)

# ========= 1. Load the local vector store =========
# Multilingual sentence-embedding model used when the store was built;
# the same model must be used for querying.
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)

VEC_DIR = "vector_store"
# Fail fast with a clear message if the persisted Chroma store is absent.
_db_file = os.path.join(VEC_DIR, "chroma.sqlite3")
if not os.path.isdir(VEC_DIR) or not os.path.isfile(_db_file):
    raise RuntimeError(f"❌ 未找到完整向量库 {VEC_DIR},请先执行 build_vector_store.py")

vectordb = Chroma(persist_directory=VEC_DIR, embedding_function=embedder)
# ========= 2. Load a lightweight LLM =========
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B params — small enough to run on CPU

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",   # needs `accelerate` in requirements
    torch_dtype="auto",
)

# Sampling parameters for text generation.
_gen_kwargs = {"max_new_tokens": 256, "temperature": 0.7, "top_p": 0.9}
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, **_gen_kwargs)

# Wrap the HF pipeline so LangChain can drive it.
llm = HuggingFacePipeline(pipeline=generator)
# ========= 3. Build the RAG QA chain =========
# "stuff" chain type: concatenate the top-k retrieved chunks into one prompt.
_retriever = vectordb.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=_retriever)
# ========= 4. Business functions =========
def simple_qa(user_query: str) -> str:
    """Answer a study question through the RAG chain.

    Returns the chain's answer string, or a user-facing warning when the
    query is blank or the chain raises.
    """
    if not user_query.strip():
        return "⚠️ 请输入学习问题,例如:什么是定积分?"
    try:
        return qa_chain.run(user_query)
    except Exception as e:  # UI boundary: report the error instead of crashing
        # FIX: logging.exception records the full traceback; the previous
        # logging.error(f"...") kept only str(e).
        logging.exception("[QA ERROR] %s", e)
        return f"⚠️ 问答失败:{e}"
def generate_outline(topic: str) -> tuple[str, str]:
    """Generate a structured study outline for *topic*.

    Retrieves the 3 most relevant chunks from the vector store, prompts the
    LLM to organize them into an outline, and returns
    ``(outline, retrieved_snippet)``. On blank input or any failure it
    returns a warning message and an empty snippet.
    """
    if not topic.strip():
        return "⚠️ 请输入章节或主题", ""
    try:
        docs = vectordb.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(topic)
        snippet = "\n---\n".join(d.page_content for d in docs)
        prompt = (
            f"请基于以下资料,为「{topic}」生成结构化学习大纲,格式:\n"
            f"一、章节标题\n 1. 节标题\n (1)要点…\n\n"
            f"资料:\n{snippet}\n\n大纲:"
        )
        outline = llm.invoke(prompt).strip()
        return outline, snippet
    except Exception as e:  # UI boundary: degrade gracefully
        # FIX: logging.exception keeps the traceback for debugging; the
        # previous logging.error(f"...") discarded it.
        logging.exception("[OUTLINE ERROR] %s", e)
        return "⚠️ 生成失败", ""
def placeholder(*_, **__):
    """Stub for features that are not implemented yet.

    Accepts and ignores any positional *and* keyword arguments so it can
    stand in for any handler signature.

    FIX: the original ``def placeholder(*_)`` rejected keyword arguments,
    so the ``placeholder(label="待开发")`` calls in the UI raised TypeError
    at import time and the app never started.
    """
    return "功能待开发…"
# ========= 5. Gradio UI =========
with gr.Blocks(title="智能学习助手") as demo:
    gr.Markdown("# 智能学习助手 v2.0\n💡 大学生专业课 RAG Demo")
    with gr.Tabs():
        with gr.TabItem("智能问答"):
            chatbot = gr.Chatbot(height=350)
            msg = gr.Textbox(placeholder="在此提问…")

            def chat(m, hist):
                # Append (question, answer) to the history and clear the input box.
                ans = simple_qa(m)
                hist.append((m, ans))
                return "", hist

            msg.submit(chat, [msg, chatbot], [msg, chatbot])

        with gr.TabItem("生成学习大纲"):
            topic = gr.Textbox(label="章节主题", placeholder="高等数学 第六章 定积分")
            outline = gr.Textbox(label="学习大纲", lines=12)
            debug = gr.Textbox(label="调试:检索片段", lines=6)
            gen = gr.Button("生成")
            gen.click(generate_outline, [topic], [outline, debug])

        with gr.TabItem("自动出题"):
            # FIX: `placeholder(label="待开发")` raised TypeError (placeholder
            # accepts no keyword args) and its return value was discarded, so
            # nothing was rendered. Show the stub text in a visible component.
            gr.Markdown(placeholder())

        with gr.TabItem("答案批改"):
            gr.Markdown(placeholder())  # FIX: same TypeError / discarded-return bug

    gr.Markdown("---\nPowered by LangChain • TinyLlama • Chroma")

if __name__ == "__main__":
    demo.launch()