Spaces:
Sleeping
Sleeping
File size: 5,081 Bytes
536e921 ac710d7 536e921 ac710d7 536e921 14517da 536e921 14517da 536e921 14517da 536e921 14517da 536e921 ac710d7 14517da ac710d7 14517da ac710d7 536e921 ac710d7 14517da ac710d7 14517da 536e921 ac710d7 14517da ac710d7 14517da 536e921 14517da ac710d7 14517da ac710d7 536e921 ac710d7 536e921 ac710d7 14517da ac710d7 536e921 ac710d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# app.py
import gradio as gr
import logging
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
logging.basicConfig(level=logging.INFO)
# === 加载向量库 ===
embedding_model = HuggingFaceEmbeddings(
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
vector_store = Chroma(
persist_directory="vector_store",
embedding_function=embedding_model,
)
# === 加载 LLM 模型(openchat) ===
model_id = "openchat/openchat-3.5-0106"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype="auto",
)
gen_pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=512,
temperature=0.7,
top_p=0.9,
)
llm = HuggingFacePipeline(pipeline=gen_pipe)
# === 构建问答链 ===
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
)
# === 智能问答函数 ===
def simple_qa(user_query):
if not user_query.strip():
return "⚠️ 请输入学习问题,例如:什么是定积分?"
try:
answer = qa_chain.run(user_query)
return answer
except Exception as e:
logging.error(f"问答失败: {e}")
return "抱歉,暂时无法回答,请稍后再试。"
# === 大纲生成函数 ===
def generate_outline(topic: str):
if not topic.strip():
return "⚠️ 请输入章节或主题,例如:高等数学 第六章 定积分"
try:
docs = vector_store.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(topic)
snippet = "\n".join([doc.page_content for doc in docs])
prompt = (
f"根据以下内容,为“{topic}”生成大学本科层次的结构化学习大纲,格式如下:\n"
f"一、章节标题\n 1. 节标题\n (1)要点描述\n...\n\n"
f"文档内容:\n{snippet}\n\n学习大纲:"
)
result = llm.generate(prompt).generations[0][0].text.strip()
return result
except Exception as e:
logging.error(f"大纲生成失败: {e}")
return "⚠️ 抱歉,生成失败,请稍后再试。"
# === 占位函数 ===
def placeholder_fn(*args, **kwargs):
return "功能尚未实现,请等待后续更新。"
# === Gradio UI ===
with gr.Blocks() as demo:
gr.Markdown("# 智能学习助手 v2.0\n— 大学生专业课学习助手 —")
with gr.Tabs():
# --- 模块 A:智能问答 ---
with gr.TabItem("智能问答"):
gr.Markdown("> 示例:什么是函数的定义域?")
chatbot = gr.Chatbot()
user_msg = gr.Textbox(placeholder="输入您的学习问题,然后按回车或点击发送")
send_btn = gr.Button("发送")
def update_chat(message, chat_history):
reply = simple_qa(message)
chat_history.append((message, reply))
return "", chat_history
send_btn.click(update_chat, inputs=[user_msg, chatbot], outputs=[user_msg, chatbot])
user_msg.submit(update_chat, inputs=[user_msg, chatbot], outputs=[user_msg, chatbot])
# --- 模块 B:生成学习大纲 ---
with gr.TabItem("生成学习大纲"):
gr.Markdown("> 示例:高等数学 第六章 定积分")
topic_input = gr.Textbox(label="章节主题", placeholder="请输入章节名")
outline_output = gr.Textbox(label="系统生成的大纲", lines=12)
gen_outline_btn = gr.Button("生成大纲")
gen_outline_btn.click(fn=generate_outline, inputs=topic_input, outputs=outline_output)
# --- 模块 C:自动出题(占位) ---
with gr.TabItem("自动出题"):
gr.Markdown("(出题模块,待开发)")
topic2 = gr.Textbox(label="知识点/主题", placeholder="如:高数 第三章 多元函数")
difficulty2 = gr.Dropdown(choices=["简单", "中等", "困难"], label="难度")
count2 = gr.Slider(1, 10, step=1, label="题目数量")
gen_q_btn = gr.Button("开始出题")
gen_q_btn.click(placeholder_fn, inputs=[topic2, difficulty2, count2], outputs=topic2)
# --- 模块 D:答案批改(占位) ---
with gr.TabItem("答案批改"):
gr.Markdown("(批改模块,待开发)")
std_ans = gr.Textbox(label="标准答案", lines=5)
user_ans = gr.Textbox(label="您的作答", lines=5)
grade_btn = gr.Button("开始批改")
grade_btn.click(placeholder_fn, inputs=[user_ans, std_ans], outputs=user_ans)
gr.Markdown("---\n由 HuggingFace 提供支持 • 版本 2.0")
if __name__ == "__main__":
demo.launch()
|