Spaces:
Sleeping
Sleeping
File size: 5,350 Bytes
853beb6 536e921 2999a2c 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 536e921 853beb6 ac710d7 853beb6 14517da 853beb6 14517da 853beb6 14517da 853beb6 14517da 853beb6 14517da 853beb6 14517da 853beb6 ac710d7 853beb6 ac710d7 853beb6 14517da 853beb6 ac710d7 853beb6 ac710d7 853beb6 ac710d7 853beb6 ac710d7 853beb6 ac710d7 853beb6 536e921 853beb6 ac710d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# app.py
import gradio as gr
import logging
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.schema import Document
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
logging.basicConfig(level=logging.INFO)

# === Vector store ===
# Multilingual sentence-transformers model (the UI and prompts are Chinese).
# NOTE(review): assumes the persisted collection was built with this same
# embedding model — TODO confirm, otherwise similarity scores are meaningless.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
)

# Open the Chroma collection persisted in the "vector_store" directory
# (a relative path, so it is resolved against the process working directory).
vector_store = Chroma(embedding_function=embedding_model, persist_directory="vector_store")
# === Load the LLM (OpenChat 3.5) ===
# NOTE(review): prompts elsewhere are sent as raw text, not in OpenChat's chat
# template ("GPT4 Correct User: ...<|end_of_turn|>GPT4 Correct Assistant:");
# answer quality may improve if they are wrapped — TODO evaluate.
model_id = "openchat/openchat-3.5-0106"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let accelerate place the weights on the available device(s)
    torch_dtype="auto",  # use the dtype recorded in the checkpoint instead of float32
)
gen_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    # temperature/top_p only apply to sampling. Without do_sample=True,
    # generate() uses greedy decoding and ignores (and warns about) them,
    # unless the checkpoint's generation_config already enables sampling.
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
# Wrap the transformers pipeline so LangChain chains can call it as an LLM.
llm = HuggingFacePipeline(pipeline=gen_pipe)
# === Build the retrieval QA chain ===
# Top-3 chunks from the vector store are "stuffed" into a single prompt for the LLM.
_qa_retriever = vector_store.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=_qa_retriever,
    chain_type="stuff",
)
# === Smart Q&A ===
def simple_qa(user_query):
    """Answer a study question with the retrieval-augmented QA chain.

    Args:
        user_query: Question typed by the user; may be empty, blank or None.

    Returns:
        The chain's answer text, a warning when the input is empty, or a
        user-facing error message (with the exception text) if the chain fails.
    """
    if not user_query or not user_query.strip():
        return "⚠️ 请输入学习问题,例如:什么是定积分?"
    try:
        # RetrievalQA reads its question from "query" and writes the answer to
        # "result"; invoke() replaces the deprecated Chain.run().
        answer = qa_chain.invoke({"query": user_query})["result"]
    except Exception as e:
        # UI boundary: keep the app alive, but log the full traceback.
        logging.exception("问答失败: %s", e)
        return f"⚠️ 问答失败,请稍后再试。\n[调试信息] {e}"
    return answer
# === Study-outline generation ===
def generate_outline(topic: str):
    """Generate a structured study outline for *topic* from retrieved documents.

    Retrieves the 3 most similar chunks from ``vector_store``. The LLM turns
    them into a chapter / section / point outline. The raw retrieved text is
    returned alongside the outline so it can be displayed for debugging.

    Args:
        topic: Chapter or subject entered by the user; may be empty or None.

    Returns:
        An ``(outline, snippet)`` tuple. On empty input, no hits, or any
        failure, ``outline`` is a user-facing warning and ``snippet`` is "".
    """
    if not topic or not topic.strip():
        return "⚠️ 请输入章节或主题,例如:高等数学 第六章 定积分", ""
    try:
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        docs = retriever.invoke(topic)  # replaces deprecated get_relevant_documents()
        if not docs:
            return "⚠️ 没有找到相关内容,请换个关键词试试。", ""
        snippet = "\n".join(doc.page_content for doc in docs)
        prompt = (
            f"根据以下内容,为“{topic}”生成大学本科层次的结构化学习大纲,格式如下:\n"
            "一、章节标题\n 1. 节标题\n (1)要点描述\n...\n\n"
            f"文档内容:\n{snippet}\n\n学习大纲:"
        )
        # BaseLLM.generate() expects a *list* of prompts. A bare string either
        # raises or is treated as one prompt per character. invoke() takes a
        # single prompt and returns the completion text.
        result = llm.invoke(prompt).strip()
        return result, snippet
    except Exception as e:
        # UI boundary: keep the app alive, but log the full traceback.
        logging.exception("大纲生成失败: %s", e)
        return "⚠️ 抱歉,生成失败,请稍后再试。", ""
# === Placeholder handler ===
def placeholder_fn(*args, **kwargs):
    """Stand-in event handler for features that are not implemented yet.

    Accepts and ignores any positional/keyword arguments, so it can be wired to
    any Gradio event, and always returns the same "coming soon" notice.
    """
    del args, kwargs  # intentionally unused
    notice = "功能尚未实现,请等待后续更新。"
    return notice
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("# 智能学习助手 v2.0\n— 大学生专业课学习助手 —")
    with gr.Tabs():
        # --- Module A: smart Q&A ---
        with gr.TabItem("智能问答"):
            gr.Markdown("> 示例:什么是函数的定义域?")
            chatbot = gr.Chatbot()
            user_msg = gr.Textbox(placeholder="输入您的学习问题,然后按回车或点击发送")
            send_btn = gr.Button("发送")

            def update_chat(message, chat_history):
                """Append one (question, answer) turn to the chat and clear the input box."""
                # Build a new list instead of mutating the incoming value; also
                # tolerate a None history.
                history = list(chat_history or [])
                history.append((message, simple_qa(message)))
                return "", history

            send_btn.click(update_chat, inputs=[user_msg, chatbot], outputs=[user_msg, chatbot])
            user_msg.submit(update_chat, inputs=[user_msg, chatbot], outputs=[user_msg, chatbot])

        # --- Module B: generate a study outline ---
        with gr.TabItem("生成学习大纲"):
            gr.Markdown("> 示例:高等数学 第六章 定积分")
            topic_input = gr.Textbox(label="章节主题", placeholder="请输入章节名")
            outline_output = gr.Textbox(label="系统生成的大纲", lines=12)
            snippet_output = gr.Textbox(label="[调试] 检索片段展示", lines=6)
            gen_outline_btn = gr.Button("生成大纲")
            gen_outline_btn.click(fn=generate_outline, inputs=topic_input, outputs=[outline_output, snippet_output])

        # --- Module C: automatic question generation (placeholder) ---
        with gr.TabItem("自动出题"):
            gr.Markdown("(出题模块,待开发)")
            topic2 = gr.Textbox(label="知识点/主题", placeholder="如:高数 第三章 多元函数")
            difficulty2 = gr.Dropdown(choices=["简单", "中等", "困难"], label="难度")
            count2 = gr.Slider(1, 10, step=1, label="题目数量")
            gen_q_btn = gr.Button("开始出题")
            # Dedicated read-only output. Writing into topic2 would erase the
            # user's topic.
            questions_output = gr.Textbox(label="生成的题目", interactive=False)
            gen_q_btn.click(placeholder_fn, inputs=[topic2, difficulty2, count2], outputs=questions_output)

        # --- Module D: answer grading (placeholder) ---
        with gr.TabItem("答案批改"):
            gr.Markdown("(批改模块,待开发)")
            std_ans = gr.Textbox(label="标准答案", lines=5)
            user_ans = gr.Textbox(label="您的作答", lines=5)
            grade_btn = gr.Button("开始批改")
            # Dedicated read-only output. Writing into user_ans would erase the
            # user's answer.
            grade_output = gr.Textbox(label="批改结果", interactive=False)
            grade_btn.click(placeholder_fn, inputs=[user_ans, std_ans], outputs=grade_output)

    gr.Markdown("---\n由 HuggingFace 提供支持 • 版本 2.0")

if __name__ == "__main__":
    demo.launch()
|