Yongkang ZOU committed
Commit fc07371 · 1 Parent(s): 7cbe9e1

update agent

Files changed (1)
  1. agent.py +53 -146
agent.py CHANGED
@@ -1,213 +1,120 @@
- """LangGraph Agent"""
  import os
  from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import tools_condition
- from langgraph.prebuilt import ToolNode
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
  from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_community.document_loaders import WikipediaLoader
- from langchain_community.document_loaders import ArxivLoader
- from langchain_community.vectorstores import SupabaseVectorStore
  from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool
- from langchain.tools.retriever import create_retriever_tool
- # from supabase.client import Client, create_client

  load_dotenv()

  @tool
  def multiply(a: int, b: int) -> int:
-     """Multiply two numbers.
-     Args:
-         a: first int
-         b: second int
-     """
      return a * b

  @tool
  def add(a: int, b: int) -> int:
-     """Add two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a + b

  @tool
  def subtract(a: int, b: int) -> int:
-     """Subtract two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a - b

  @tool
- def divide(a: int, b: int) -> int:
-     """Divide two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
  def modulus(a: int, b: int) -> int:
-     """Get the modulus of two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a % b

  @tool
  def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.
-
-     Args:
-         query: The search query."""
-     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"wiki_results": formatted_search_docs}

  @tool
  def web_search(query: str) -> str:
-     """Search Tavily for a query and return maximum 3 results.
-
-     Args:
-         query: The search query."""
-     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"web_results": formatted_search_docs}

  @tool
  def arvix_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 result.
-
-     Args:
-         query: The search query."""
-     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"arvix_results": formatted_search_docs}
-
- # load the system prompt from the file
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     system_prompt = f.read()
-
- # System message
  sys_msg = SystemMessage(content=system_prompt)

- # # build a retriever
- # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
- # supabase: Client = create_client(
- #     os.environ.get("SUPABASE_URL"),
- #     os.environ.get("SUPABASE_SERVICE_KEY"))
- # vector_store = SupabaseVectorStore(
- #     client=supabase,
- #     embedding=embeddings,
- #     table_name="documents",
- #     query_name="match_documents_langchain",
- # )
- # create_retriever_tool = create_retriever_tool(
- #     retriever=vector_store.as_retriever(),
- #     name="Question Search",
- #     description="A tool to retrieve similar questions from a vector store.",
- # )
-
- tools = [
-     multiply,
-     add,
-     subtract,
-     divide,
-     modulus,
-     wiki_search,
-     web_search,
-     arvix_search,
- ]
-
- # Build graph function
  def build_graph(provider: str = "groq"):
-     """Build the graph"""
-     # Load environment variables from .env file
      if provider == "google":
-         # Google Gemini
          llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
      elif provider == "groq":
-         # Groq https://console.groq.com/docs/models
-         llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
      elif provider == "huggingface":
-         # TODO: Add huggingface endpoint
          llm = ChatHuggingFace(
              llm=HuggingFaceEndpoint(
                  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
-                 temperature=0,
-             ),
          )
      else:
-         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-     # Bind tools to LLM
      llm_with_tools = llm.bind_tools(tools)

-     # Node
      def assistant(state: MessagesState):
-         """Assistant node"""
-         return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-     # def retriever(state: MessagesState):
-     #     """Retriever node"""
-     #     similar_question = vector_store.similarity_search(state["messages"][0].content)
-     #     example_msg = HumanMessage(
-     #         content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-     #     )
-     #     return {"messages": [sys_msg] + state["messages"] + [example_msg]}

      builder = StateGraph(MessagesState)
-     # builder.add_node("retriever", retriever)
      builder.add_node("assistant", assistant)
      builder.add_node("tools", ToolNode(tools))
-     builder.add_edge(START, "retriever")
-     builder.add_edge("retriever", "assistant")
-     builder.add_conditional_edges(
-         "assistant",
-         tools_condition,
-     )
      builder.add_edge("tools", "assistant")

-     # Compile graph
      return builder.compile()

- # test
  if __name__ == "__main__":
-     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     # Build the graph
      graph = build_graph(provider="groq")
-     # Run the graph
-     messages = [HumanMessage(content=question)]
-     messages = graph.invoke({"messages": messages})
      for m in messages["messages"]:
          m.pretty_print()
 
 
  import os
  from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition, ToolNode
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
  from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
  from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool

  load_dotenv()

+ # ------------------- TOOL DEFINITIONS -------------------
+
  @tool
  def multiply(a: int, b: int) -> int:
+     """Multiply two numbers."""
      return a * b

  @tool
  def add(a: int, b: int) -> int:
+     """Add two numbers."""
      return a + b

  @tool
  def subtract(a: int, b: int) -> int:
+     """Subtract two numbers."""
      return a - b

  @tool
+ def divide(a: int, b: int) -> float:
+     """Divide two numbers."""
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
  def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers."""
      return a % b

  @tool
  def wiki_search(query: str) -> str:
+     """Search Wikipedia for a query (max 2 results)."""
+     docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     return "\n\n".join([doc.page_content for doc in docs])
  @tool
  def web_search(query: str) -> str:
+     """Search the web using Tavily (max 3 results)."""
+     docs = TavilySearchResults(max_results=3).invoke(query)
+     # TavilySearchResults returns a list of dicts, not Documents,
+     # so read the "content" key rather than .page_content
+     return "\n\n".join([d["content"] for d in docs])
  @tool
  def arvix_search(query: str) -> str:
+     """Search Arxiv for academic papers (max 3)."""
+     docs = ArxivLoader(query=query, load_max_docs=3).load()
+     return "\n\n".join([doc.page_content[:1000] for doc in docs])
+
+ tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
+
+ # ------------------- SYSTEM PROMPT -------------------
+
+ system_prompt_path = "system_prompt.txt"
+ if os.path.exists(system_prompt_path):
+     with open(system_prompt_path, "r", encoding="utf-8") as f:
+         system_prompt = f.read()
+ else:
+     system_prompt = (
+         "You are an intelligent AI agent who can solve math, science, factual, and research-based problems. "
+         "You can use tools like Wikipedia, Web search, or Arxiv when needed. Always give precise and helpful answers."
+     )

  sys_msg = SystemMessage(content=system_prompt)

+ # ------------------- GRAPH CONSTRUCTION -------------------
+
  def build_graph(provider: str = "groq"):
+     """Build the LangGraph with tool-use."""
+     # Select LLM provider
      if provider == "google":
          llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
      elif provider == "groq":
+         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
      elif provider == "huggingface":
          llm = ChatHuggingFace(
              llm=HuggingFaceEndpoint(
                  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                 temperature=0
+             )
          )
      else:
+         raise ValueError("Invalid provider")
+
      llm_with_tools = llm.bind_tools(tools)
      def assistant(state: MessagesState):
+         # Prepend the system prompt to the conversation for each model call,
+         # rather than appending sys_msg to the state on every turn
+         return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
+     # Build the graph with assistant and tools
      builder = StateGraph(MessagesState)
      builder.add_node("assistant", assistant)
      builder.add_node("tools", ToolNode(tools))
+
+     builder.add_edge(START, "assistant")
+     builder.add_conditional_edges("assistant", tools_condition)
      builder.add_edge("tools", "assistant")

      return builder.compile()

+ # ------------------- LOCAL TEST -------------------
+
  if __name__ == "__main__":
+     question = "What is 17 * 23?"
      graph = build_graph(provider="groq")
+     messages = graph.invoke({"messages": [HumanMessage(content=question)]})
+     print("=== AI Agent Response ===")
      for m in messages["messages"]:
          m.pretty_print()
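
For quick local verification of the updated agent, here is a minimal sketch, assuming the file is importable as agent.py and that GROQ_API_KEY (plus TAVILY_API_KEY for web_search) is set in the environment; the ask helper is illustrative and not part of the commit:

from langchain_core.messages import HumanMessage
from agent import build_graph

def ask(question: str, provider: str = "groq") -> str:
    # Compile the graph and run a single question through it
    graph = build_graph(provider=provider)
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    # The last message in the accumulated state is the model's final answer
    return result["messages"][-1].content

if __name__ == "__main__":
    print(ask("What is 17 * 23?"))  # expect 391, computed via the multiply tool

Because the prebuilt tools_condition routes the assistant's output to the "tools" node whenever the model emits tool calls and to END otherwise, the compiled graph loops assistant → tools → assistant until the model produces a plain answer.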