File size: 5,081 Bytes
536e921
ac710d7
536e921
ac710d7
536e921
 
 
 
 
 
 
 
14517da
536e921
 
 
 
 
 
 
 
14517da
536e921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14517da
536e921
 
 
 
 
 
14517da
536e921
 
 
 
 
 
 
 
 
ac710d7
14517da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac710d7
 
 
14517da
ac710d7
536e921
ac710d7
 
14517da
ac710d7
14517da
536e921
 
 
 
 
 
 
 
ac710d7
14517da
 
 
 
ac710d7
14517da
 
 
536e921
14517da
ac710d7
14517da
ac710d7
536e921
ac710d7
536e921
 
 
 
ac710d7
14517da
ac710d7
536e921
 
 
 
 
 
 
ac710d7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# app.py
import gradio as gr
import logging

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline

logging.basicConfig(level=logging.INFO)

# === Load the vector store ===
# Multilingual sentence-transformer embeddings (handles Chinese queries).
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
# On-disk Chroma store. NOTE(review): assumes "vector_store/" was already
# built by a separate ingestion step — nothing here creates it if missing.
vector_store = Chroma(
    persist_directory="vector_store",
    embedding_function=embedding_model,
)

# === Load the LLM (openchat) ===
model_id = "openchat/openchat-3.5-0106"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",    # let transformers place layers on available devices
    torch_dtype="auto",   # take the dtype from the checkpoint config
)
# Sampled generation, up to 512 new tokens per completion.
gen_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
)
# Wrap the transformers pipeline so LangChain chains can drive it.
llm = HuggingFacePipeline(pipeline=gen_pipe)

# === Build the retrieval-QA chain ===
# "stuff" chain: the k=3 retrieved chunks are concatenated directly into
# the prompt sent to the LLM.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
)

# === Q&A function ===
def simple_qa(user_query: str) -> str:
    """Answer a study question through the retrieval-QA chain.

    Args:
        user_query: The learner's question. ``None``, empty, or
            whitespace-only input yields a usage hint instead of a query.

    Returns:
        The chain's answer string, or a Chinese fallback message when the
        chain raises (model/vector-store failures are logged, not raised).
    """
    # Guard against None as well: Gradio can hand over an empty component
    # value, and ``None.strip()`` would raise outside the try block below.
    if not user_query or not user_query.strip():
        return "⚠️ 请输入学习问题,例如:什么是定积分?"
    try:
        answer = qa_chain.run(user_query)
        return answer
    except Exception as e:
        # Lazy %-style args: formatting only happens if the record is emitted.
        logging.error("问答失败: %s", e)
        return "抱歉,暂时无法回答,请稍后再试。"

# === Outline-generation function ===
def generate_outline(topic: str) -> str:
    """Generate a structured study outline for *topic* from retrieved notes.

    Args:
        topic: Chapter or subject name. ``None``/blank input yields a usage
            hint instead of running retrieval.

    Returns:
        The LLM-generated outline, or a Chinese fallback message when
        retrieval/generation raises (the error is logged, not raised).
    """
    if not topic or not topic.strip():
        return "⚠️ 请输入章节或主题,例如:高等数学 第六章 定积分"
    try:
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        docs = retriever.get_relevant_documents(topic)
        snippet = "\n".join(doc.page_content for doc in docs)
        prompt = (
            f"根据以下内容,为“{topic}”生成大学本科层次的结构化学习大纲,格式如下:\n"
            f"一、章节标题\n  1. 节标题\n    (1)要点描述\n...\n\n"
            f"文档内容:\n{snippet}\n\n学习大纲:"
        )
        # BUG FIX: BaseLLM.generate expects a *list* of prompts; passing a
        # bare string raised ValueError (silently swallowed below), so this
        # function always returned the fallback message.
        result = llm.generate([prompt]).generations[0][0].text.strip()
        return result
    except Exception as e:
        # Lazy %-style args instead of an eagerly-formatted f-string.
        logging.error("大纲生成失败: %s", e)
        return "⚠️ 抱歉,生成失败,请稍后再试。"

# === Placeholder handler ===
def placeholder_fn(*args, **kwargs):
    """Stub callback for tabs whose backend is not implemented yet.

    Accepts and ignores any positional/keyword arguments so it can be
    wired to any Gradio event; always returns the same notice string.
    """
    notice = "功能尚未实现,请等待后续更新。"
    return notice

# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("# 智能学习助手 v2.0\n— 大学生专业课学习助手 —")

    with gr.Tabs():
        # --- Module A: interactive Q&A ---
        with gr.TabItem("智能问答"):
            gr.Markdown("> 示例:什么是函数的定义域?")
            chatbot = gr.Chatbot()
            user_msg = gr.Textbox(placeholder="输入您的学习问题,然后按回车或点击发送")
            send_btn = gr.Button("发送")

            def update_chat(message, chat_history):
                """Answer *message* via simple_qa, append the (question,
                answer) pair to the chat history, and return "" as the
                first output so the input box is cleared.

                NOTE(review): assumes `chat_history` arrives as a list —
                confirm the Chatbot's initial value is [] (not None) on
                this Gradio version, otherwise .append would fail.
                """
                reply = simple_qa(message)
                chat_history.append((message, reply))
                return "", chat_history

            # Both clicking the button and pressing Enter submit the question.
            send_btn.click(update_chat, inputs=[user_msg, chatbot], outputs=[user_msg, chatbot])
            user_msg.submit(update_chat, inputs=[user_msg, chatbot], outputs=[user_msg, chatbot])

        # --- Module B: study-outline generation ---
        with gr.TabItem("生成学习大纲"):
            gr.Markdown("> 示例:高等数学 第六章 定积分")
            topic_input = gr.Textbox(label="章节主题", placeholder="请输入章节名")
            outline_output = gr.Textbox(label="系统生成的大纲", lines=12)
            gen_outline_btn = gr.Button("生成大纲")
            gen_outline_btn.click(fn=generate_outline, inputs=topic_input, outputs=outline_output)

        # --- Module C: quiz generation (placeholder) ---
        with gr.TabItem("自动出题"):
            gr.Markdown("(出题模块,待开发)")
            topic2 = gr.Textbox(label="知识点/主题", placeholder="如:高数 第三章 多元函数")
            difficulty2 = gr.Dropdown(choices=["简单", "中等", "困难"], label="难度")
            count2 = gr.Slider(1, 10, step=1, label="题目数量")
            gen_q_btn = gr.Button("开始出题")
            # NOTE(review): the placeholder notice is written back into `topic2`,
            # overwriting the user's input — a dedicated output box seems intended.
            gen_q_btn.click(placeholder_fn, inputs=[topic2, difficulty2, count2], outputs=topic2)

        # --- Module D: answer grading (placeholder) ---
        with gr.TabItem("答案批改"):
            gr.Markdown("(批改模块,待开发)")
            std_ans = gr.Textbox(label="标准答案", lines=5)
            user_ans = gr.Textbox(label="您的作答", lines=5)
            grade_btn = gr.Button("开始批改")
            # NOTE(review): same pattern — the notice replaces the user's answer text.
            grade_btn.click(placeholder_fn, inputs=[user_ans, std_ans], outputs=user_ans)

    gr.Markdown("---\n由 HuggingFace 提供支持 • 版本 2.0")

# Launch the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()