# app.py ── 2025-06-08  adapted for HuggingFace CPU Spaces
import os, logging, gradio as gr
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline
logging.basicConfig(level=logging.INFO)
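# Assumed dependencies, inferred from the imports above (pin exact versions in
# requirements.txt as needed): gradio, transformers, accelerate, torch,
# langchain, langchain-community, langchain-huggingface, chromadb,
# sentence-transformers.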
# ========= 1. Load the local vector store =========
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
VEC_DIR = "vector_store"
if not (os.path.isdir(VEC_DIR) and os.path.isfile(f"{VEC_DIR}/chroma.sqlite3")):
    raise RuntimeError(f"❌ Incomplete vector store at {VEC_DIR}; run build_vector_store.py first")
vectordb = Chroma(persist_directory=VEC_DIR, embedding_function=embedder)
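# build_vector_store.py is assumed (not included here) to have created the index
# with the same embedding model, roughly:
#   Chroma.from_documents(chunks, embedder, persist_directory=VEC_DIR)
# Loading with a different embedder than the one used at build time would make
# similarity search silently unreliable.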
# ========= 2. Load a lightweight LLM =========
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B params, runs on CPU
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # requires accelerate in requirements
    torch_dtype="auto",
)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.9,
)
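# Wrap the transformers pipeline so LangChain can drive it as a plain LLM:
# prompt string in, generated continuation out.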
llm = HuggingFacePipeline(pipeline=generator)
# ========= 3. Build the RAG QA chain =========
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vectordb.as_retriever(search_kwargs={"k": 3}),
)
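# chain_type="stuff" concatenates the k=3 retrieved chunks verbatim into a
# single prompt; simple, but long chunks can overflow TinyLlama's 2048-token
# context window.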
# ========= 4. Core functions =========
def simple_qa(user_query: str):
    if not user_query.strip():
        return "⚠️ Please enter a study question, e.g.: What is a definite integral?"
    try:
        return qa_chain.run(user_query)
    except Exception as e:
        logging.error(f"[QA ERROR] {e}")
        return f"⚠️ QA failed: {e}"
def generate_outline(topic: str):
    if not topic.strip():
        return "⚠️ Please enter a chapter or topic", ""
    try:
        docs = vectordb.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(topic)
        snippet = "\n---\n".join(d.page_content for d in docs)
        prompt = (
            f"Based on the material below, generate a structured study outline for \"{topic}\", formatted as:\n"
            f"I. Chapter title\n  1. Section title\n    (1) Key point…\n\n"
            f"Material:\n{snippet}\n\nOutline:"
        )
        outline = llm.invoke(prompt).strip()
        return outline, snippet
    except Exception as e:
        logging.error(f"[OUTLINE ERROR] {e}")
        return "⚠️ Generation failed", ""
def placeholder(*_):
    return "🚧 Feature under development…"
# ========= 5. Gradio UI =========
with gr.Blocks(title="Smart Study Assistant") as demo:
    gr.Markdown("# Smart Study Assistant v2.0\n💡 A RAG demo for college course material")
    with gr.Tabs():
        with gr.TabItem("Q&A"):
            chatbot = gr.Chatbot(height=350)
            msg = gr.Textbox(placeholder="Ask a question here…")
            def chat(m, hist):
                ans = simple_qa(m)
                hist.append((m, ans))
                return "", hist  # returning "" clears the input box
            msg.submit(chat, [msg, chatbot], [msg, chatbot])
        with gr.TabItem("Study Outline"):
            topic = gr.Textbox(label="Chapter / topic", placeholder="Advanced Mathematics, Chapter 6: Definite Integrals")
            outline = gr.Textbox(label="Study outline", lines=12)
            debug = gr.Textbox(label="Debug: retrieved snippets", lines=6)
            gen = gr.Button("Generate")
            gen.click(generate_outline, [topic], [outline, debug])
        with gr.TabItem("Quiz Generation"):
            gr.Markdown(placeholder())
        with gr.TabItem("Answer Grading"):
            gr.Markdown(placeholder())
    gr.Markdown("---\nPowered by LangChain • TinyLlama • Chroma")
if __name__ == "__main__":
    demo.launch()