superone001 commited on
Commit
420e08f
·
verified ·
1 Parent(s): 81917a3

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +242 -0
agent.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LangGraph Agent - 多工具智能代理系统
3
+ 结合数学计算、网络搜索、学术检索和向量数据库增强能力
4
+ 支持多种AI模型提供商(Google Gemini, Groq, HuggingFace)
5
+ """
6
+
7
+ import os
8
+ from dotenv import load_dotenv
9
+ from langgraph.graph import START, StateGraph, MessagesState
10
+ from langgraph.prebuilt import tools_condition
11
+ from langgraph.prebuilt import ToolNode
12
+ from langchain_google_genai import ChatGoogleGenerativeAI
13
+ from langchain_groq import ChatGroq
14
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
15
+ from langchain_community.tools.tavily_search import TavilySearchResults
16
+ from langchain_community.document_loaders import WikipediaLoader
17
+ from langchain_community.document_loaders import ArxivLoader
18
+ from langchain_community.vectorstores import SupabaseVectorStore
19
+ from langchain_core.messages import SystemMessage, HumanMessage
20
+ from langchain_core.tools import tool
21
+ from langchain.tools.retriever import create_retriever_tool
22
+ from supabase.client import Client, create_client
23
+
24
+ # 加载环境变量(API密钥、数据库连接等)
25
+ load_dotenv()
26
+
27
+ # ======================
28
+ # 工具定义部分
29
+ # ======================
30
+
31
+ @tool
32
+ def multiply(a: int, b: int) -> int:
33
+ """乘法运算: 返回两个整数的乘积"""
34
+ return a * b
35
+
36
+ @tool
37
+ def add(a: int, b: int) -> int:
38
+ """加法运算: 返回两个整数的和"""
39
+ return a + b
40
+
41
+ @tool
42
+ def subtract(a: int, b: int) -> int:
43
+ """减法运算: 返回两个整数的差"""
44
+ return a - b
45
+
46
+ @tool
47
+ def divide(a: int, b: int) -> int:
48
+ """除法运算: 返回两个整数的商"""
49
+ if b == 0:
50
+ raise ValueError("Cannot divide by zero.")
51
+ return a / b
52
+
53
+ @tool
54
+ def modulus(a: int, b: int) -> int:
55
+ """取模运算: 返回两个整数的模"""
56
+ return a % b
57
+
58
+ @tool
59
+ def wiki_search(query: str) -> str:
60
+ """维基百科搜索: 返回最多2个相关结果"""
61
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
62
+ # 格式化搜索结果
63
+ formatted_search_docs = "\n\n---\n\n".join(
64
+ [
65
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
66
+ for doc in search_docs
67
+ ])
68
+ return {"wiki_results": formatted_search_docs}
69
+
70
+ @tool
71
+ def web_search(query: str) -> str:
72
+ """网络搜索(Tavily): 返回最多3个相关结果"""
73
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
74
+ # 格式化搜索结果
75
+ formatted_search_docs = "\n\n---\n\n".join(
76
+ [
77
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
78
+ for doc in search_docs
79
+ ])
80
+ return {"web_results": formatted_search_docs}
81
+
82
+ @tool
83
+ def arvix_search(query: str) -> str:
84
+ """学术论文搜索(Arxiv): 返回最多3个相关结果"""
85
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
86
+ # 格式化搜索结果(截取前1000字符)
87
+ formatted_search_docs = "\n\n---\n\n".join(
88
+ [
89
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
90
+ for doc in search_docs
91
+ ])
92
+ return {"arvix_results": formatted_search_docs}
93
+
94
+ # ======================
95
+ # 系统初始化和配置
96
+ # ======================
97
+
98
+ # 加载系统提示(定义AI行为准则)
99
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
100
+ system_prompt = f.read()
101
+ sys_msg = SystemMessage(content=system_prompt)
102
+
103
+ # 构建向量数据库检索工具
104
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
105
+ supabase: Client = create_client(
106
+ os.environ.get("SUPABASE_URL"),
107
+ os.environ.get("SUPABASE_SERVICE_KEY"))
108
+ vector_store = SupabaseVectorStore(
109
+ client=supabase,
110
+ embedding=embeddings,
111
+ table_name="documents",
112
+ query_name="match_documents_langchain",
113
+ )
114
+ retriever_tool = create_retriever_tool(
115
+ retriever=vector_store.as_retriever(),
116
+ name="Question Search",
117
+ description="从向量数据库中检索相似问题",
118
+ )
119
+
120
+ # 所有可用工具列表
121
+ tools = [
122
+ multiply,
123
+ add,
124
+ subtract,
125
+ divide,
126
+ modulus,
127
+ wiki_search,
128
+ web_search,
129
+ arvix_search,
130
+ retriever_tool, # 添加向量检索工具
131
+ ]
132
+
133
+ # ======================
134
+ # 图构建函数(核心逻辑)
135
+ # ======================
136
+
137
+ def build_graph(provider: str = "groq"):
138
+ """
139
+ 构建LangGraph工作流
140
+
141
+ 参数:
142
+ provider: AI模型提供商 ("google", "groq", "huggingface")
143
+
144
+ 返回:
145
+ 编译好的LangGraph对象
146
+ """
147
+
148
+ # 1. 选择AI模型
149
+ if provider == "google":
150
+ # Google Gemini模型
151
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
152
+ elif provider == "groq":
153
+ # Groq高速推理模型
154
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # 可选模型: qwen-qwq-32b, gemma2-9b-it
155
+ elif provider == "huggingface":
156
+ # HuggingFace端点模型
157
+ llm = ChatHuggingFace(
158
+ llm=HuggingFaceEndpoint(
159
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
160
+ temperature=0,
161
+ ),
162
+ )
163
+ else:
164
+ raise ValueError("无效的提供商。请选择 'google', 'groq' 或 'huggingface'")
165
+
166
+ # 2. 将工具绑定到AI模型
167
+ llm_with_tools = llm.bind_tools(tools)
168
+
169
+ # 3. 定义图节点
170
+
171
+ def retriever_node(state: MessagesState):
172
+ """检索节点:从向量数据库查找相似问题"""
173
+ # 获取最新用户消息
174
+ user_query = state["messages"][-1].content
175
+
176
+ # 从向量数据库检索相似问题
177
+ similar_question = vector_store.similarity_search(user_query)
178
+
179
+ # 构建参考消息
180
+ reference_msg = HumanMessage(
181
+ content=f"参考类似问题及解答:\n\n{similar_question[0].page_content}",
182
+ )
183
+
184
+ # 返回增强后的消息流:系统提示 + 原始消息 + 参考消息
185
+ return {"messages": [sys_msg] + state["messages"] + [reference_msg]}
186
+
187
+ def assistant_node(state: MessagesState):
188
+ """AI节点:处理消息并决定下一步动作"""
189
+ # 调用AI模型处理当前消息状态
190
+ response = llm_with_tools.invoke(state["messages"])
191
+ return {"messages": [response]}
192
+
193
+ # 4. 构建图结构
194
+ builder = StateGraph(MessagesState)
195
+
196
+ # 添加节点
197
+ builder.add_node("retriever", retriever_node) # 检索节点
198
+ builder.add_node("assistant", assistant_node) # AI处理节点
199
+ builder.add_node("tools", ToolNode(tools)) # 工具执行节点
200
+
201
+ # 设置节点间关系
202
+ builder.add_edge(START, "retriever") # 开始 -> 检索
203
+ builder.add_edge("retriever", "assistant") # 检索 -> AI处理
204
+
205
+ # 条件边:AI处理后判断是否需要调用工具
206
+ builder.add_conditional_edges(
207
+ "assistant",
208
+ tools_condition, # 内置工具判断条件
209
+ )
210
+
211
+ # 工具执行后返回AI节点
212
+ builder.add_edge("tools", "assistant")
213
+
214
+ # 5. 编译图
215
+ return builder.compile()
216
+
217
+ # ======================
218
+ # 测试执行
219
+ # ======================
220
+
221
+ if __name__ == "__main__":
222
+ # 测试问题
223
+ question = "托马斯·阿奎纳斯的图片是什么时候首次添加到双重效应原则的维基百科页面的?"
224
+
225
+ # 构建图(使用Groq提供商)
226
+ agent_graph = build_graph(provider="groq")
227
+
228
+ # 初始化消息
229
+ messages = [HumanMessage(content=question)]
230
+
231
+ # 执行图工作流
232
+ result = agent_graph.invoke({"messages": messages})
233
+
234
+ # 打印所有消息
235
+ print("\n===== 完整对话记录 =====")
236
+ for msg in result["messages"]:
237
+ print(f"[{msg.type}]: {msg.content[:200]}...")
238
+
239
+ # 提取最终回答
240
+ final_answer = result["messages"][-1].content
241
+ print("\n===== 最终回答 =====")
242
+ print(final_answer)