superone001 commited on
Commit
7f49897
·
verified ·
1 Parent(s): 77444cc

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +204 -86
agent.py CHANGED
@@ -1,95 +1,213 @@
1
- from typing import Annotated, Sequence, TypedDict
2
- from langchain_community.llms import HuggingFaceHub,HuggingFaceEndpoint
3
- from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
4
- from langgraph.graph import StateGraph, END
5
- from langchain_core.agents import AgentAction, AgentFinish
6
- from langchain.agents import create_react_agent
7
- from langchain import hub
8
- from ai_tools import get_tools # 导入自定义工具集
9
- import operator
10
-
11
- class AgentState(TypedDict):
12
- messages: Annotated[Sequence[BaseMessage], operator.add]
13
- intermediate_steps: Annotated[list, operator.add]
14
-
15
- def build_graph():
16
- # 1. 初始化模型 - 使用HuggingFace免费接口
17
- llm = HuggingFaceEndpoint(
18
- endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
19
- max_new_tokens=500, # 直接指定
20
- temperature=0.1, # 直接指定
21
- repetition_penalty=1.2, # 直接指定
22
- top_p=0.9, # 可选参数
23
- # 其他参数也可以直接在这里指定
24
- )
 
 
 
 
 
 
 
 
25
 
26
- # 2. 创建ReAct代理
27
- prompt = hub.pull("hwchase17/react")
28
- tools = get_tools()
29
- agent = create_react_agent(llm, tools, prompt)
 
 
 
 
 
30
 
31
- # 3. 定义节点行为
32
- def agent_node(state: AgentState):
33
- input = state["messages"][-1].content
34
- result = agent.invoke({
35
- "input": input,
36
- "intermediate_steps": state["intermediate_steps"]
37
- })
38
- return {"intermediate_steps": [result]}
 
39
 
40
- def tool_node(state: AgentState):
41
- last_step = state["intermediate_steps"][-1]
42
- action = last_step[0] if isinstance(last_step, list) else last_step
43
-
44
- if not isinstance(action, AgentAction):
45
- return {"messages": [AIMessage(content="Invalid action format")]}
46
-
47
- # 执行工具调用
48
- tool = next((t for t in tools if t.name == action.tool), None)
49
- if not tool:
50
- return {"messages": [AIMessage(content=f"Tool {action.tool} not found")]}
51
-
52
- observation = tool.run(action.tool_input)
53
- return {"messages": [AIMessage(content=observation)]}
54
 
55
- # 4. 构建状态图
56
- workflow = StateGraph(AgentState)
57
- workflow.add_node("agent", agent_node)
58
- workflow.add_node("tool", tool_node)
 
 
 
 
 
59
 
60
- # 5. 定义边和条件
61
- def route_action(state: AgentState):
62
- last_step = state["intermediate_steps"][-1]
63
- action = last_step[0] if isinstance(last_step, list) else last_step
64
-
65
- if isinstance(action, AgentFinish):
66
- return END
67
- return "tool"
 
 
 
 
 
68
 
69
- workflow.set_entry_point("agent")
70
- workflow.add_conditional_edges(
71
- "agent",
72
- route_action,
73
- {"tool": "tool", END: END}
74
- )
75
- workflow.add_edge("tool", "agent")
 
 
 
 
 
 
76
 
77
- return workflow.compile()
 
 
 
 
 
 
 
 
 
78
 
79
- class BasicAgent:
80
- """LangGraph智能体封装"""
81
- def __init__(self):
82
- print("BasicAgent initialized.")
83
- self.graph = build_graph()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- def __call__(self, question: str) -> str:
86
- print(f"Agent received question: {question[:50]}...")
87
- messages = [HumanMessage(content=question)]
88
- result = self.graph.invoke({
89
- "messages": messages,
90
- "intermediate_steps": []
91
- })
92
-
93
- # 提取最终答案
94
- final_message = result["messages"][-1].content
95
- return final_message.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
18
+
19
+ load_dotenv()
20
+
21
+ @tool
22
+ def multiply(a: int, b: int) -> int:
23
+ """Multiply two numbers.
24
+ Args:
25
+ a: first int
26
+ b: second int
27
+ """
28
+ return a * b
29
+
30
+ @tool
31
+ def add(a: int, b: int) -> int:
32
+ """Add two numbers.
33
 
34
+ Args:
35
+ a: first int
36
+ b: second int
37
+ """
38
+ return a + b
39
+
40
+ @tool
41
+ def subtract(a: int, b: int) -> int:
42
+ """Subtract two numbers.
43
 
44
+ Args:
45
+ a: first int
46
+ b: second int
47
+ """
48
+ return a - b
49
+
50
+ @tool
51
+ def divide(a: int, b: int) -> int:
52
+ """Divide two numbers.
53
 
54
+ Args:
55
+ a: first int
56
+ b: second int
57
+ """
58
+ if b == 0:
59
+ raise ValueError("Cannot divide by zero.")
60
+ return a / b
61
+
62
+ @tool
63
+ def modulus(a: int, b: int) -> int:
64
+ """Get the modulus of two numbers.
 
 
 
65
 
66
+ Args:
67
+ a: first int
68
+ b: second int
69
+ """
70
+ return a % b
71
+
72
+ @tool
73
+ def wiki_search(query: str) -> str:
74
+ """Search Wikipedia for a query and return maximum 2 results.
75
 
76
+ Args:
77
+ query: The search query."""
78
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
+ formatted_search_docs = "\n\n---\n\n".join(
80
+ [
81
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
+ for doc in search_docs
83
+ ])
84
+ return {"wiki_results": formatted_search_docs}
85
+
86
+ @tool
87
+ def web_search(query: str) -> str:
88
+ """Search Tavily for a query and return maximum 3 results.
89
 
90
+ Args:
91
+ query: The search query."""
92
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
93
+ formatted_search_docs = "\n\n---\n\n".join(
94
+ [
95
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
+ for doc in search_docs
97
+ ])
98
+ return {"web_results": formatted_search_docs}
99
+
100
+ @tool
101
+ def arvix_search(query: str) -> str:
102
+ """Search Arxiv for a query and return maximum 3 result.
103
 
104
+ Args:
105
+ query: The search query."""
106
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
+ formatted_search_docs = "\n\n---\n\n".join(
108
+ [
109
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
+ for doc in search_docs
111
+ ])
112
+ return {"arvix_results": formatted_search_docs}
113
+
114
 
115
+
116
+ # load the system prompt from the file
117
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
+ system_prompt = f.read()
119
+
120
+ # System message
121
+ sys_msg = SystemMessage(content=system_prompt)
122
+
123
+ # build a retriever
124
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
+ supabase: Client = create_client(
126
+ os.environ.get("SUPABASE_URL"),
127
+ os.environ.get("SUPABASE_SERVICE_KEY"))
128
+ vector_store = SupabaseVectorStore(
129
+ client=supabase,
130
+ embedding= embeddings,
131
+ table_name="documents",
132
+ query_name="match_documents_langchain",
133
+ )
134
+ create_retriever_tool = create_retriever_tool(
135
+ retriever=vector_store.as_retriever(),
136
+ name="Question Search",
137
+ description="A tool to retrieve similar questions from a vector store.",
138
+ )
139
+
140
+
141
+
142
+ tools = [
143
+ multiply,
144
+ add,
145
+ subtract,
146
+ divide,
147
+ modulus,
148
+ wiki_search,
149
+ web_search,
150
+ arvix_search,
151
+ ]
152
+
153
+ # Build graph function
154
+ def build_graph(provider: str = "groq"):
155
+ """Build the graph"""
156
+ # Load environment variables from .env file
157
+ if provider == "google":
158
+ # Google Gemini
159
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
160
+ elif provider == "groq":
161
+ # Groq https://console.groq.com/docs/models
162
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
163
+ elif provider == "huggingface":
164
+ # TODO: Add huggingface endpoint
165
+ llm = ChatHuggingFace(
166
+ llm=HuggingFaceEndpoint(
167
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
168
+ temperature=0,
169
+ ),
170
+ )
171
+ else:
172
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
173
+ # Bind tools to LLM
174
+ llm_with_tools = llm.bind_tools(tools)
175
+
176
+ # Node
177
+ def assistant(state: MessagesState):
178
+ """Assistant node"""
179
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
 
181
+ def retriever(state: MessagesState):
182
+ """Retriever node"""
183
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
184
+ example_msg = HumanMessage(
185
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
+ )
187
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
+
189
+ builder = StateGraph(MessagesState)
190
+ builder.add_node("retriever", retriever)
191
+ builder.add_node("assistant", assistant)
192
+ builder.add_node("tools", ToolNode(tools))
193
+ builder.add_edge(START, "retriever")
194
+ builder.add_edge("retriever", "assistant")
195
+ builder.add_conditional_edges(
196
+ "assistant",
197
+ tools_condition,
198
+ )
199
+ builder.add_edge("tools", "assistant")
200
+
201
+ # Compile graph
202
+ return builder.compile()
203
+
204
+ # test
205
+ if __name__ == "__main__":
206
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
207
+ # Build the graph
208
+ graph = build_graph(provider="groq")
209
+ # Run the graph
210
+ messages = [HumanMessage(content=question)]
211
+ messages = graph.invoke({"messages": messages})
212
+ for m in messages["messages"]:
213
+ m.pretty_print()