superone001 committed on
Commit d71a902 · verified · 1 Parent(s): cb66876

Update agent.py

Files changed (1)
  1. agent.py +186 -133
agent.py CHANGED
@@ -1,213 +1,266 @@
- """LangGraph Agent"""
- import os
  from dotenv import load_dotenv
- from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import tools_condition
- from langgraph.prebuilt import ToolNode
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_groq import ChatGroq
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
- from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader
  from langchain_community.document_loaders import ArxivLoader
- from langchain_community.vectorstores import SupabaseVectorStore
  from langchain_core.messages import SystemMessage, HumanMessage
- from langchain_core.tools import tool
- from langchain.tools.retriever import create_retriever_tool
- from supabase.client import Client, create_client

  load_dotenv()

  @tool
- def multiply(a: int, b: int) -> int:
-     """Multiply two numbers.

      Args:
-         a: first int
-         b: second int
      """
-     return a * b

  @tool
- def add(a: int, b: int) -> int:
-     """Add two numbers.

      Args:
-         a: first int
-         b: second int
      """
-     return a + b

  @tool
- def subtract(a: int, b: int) -> int:
-     """Subtract two numbers.

      Args:
-         a: first int
-         b: second int
      """
-     return a - b

  @tool
  def divide(a: int, b: int) -> int:
-     """Divide two numbers.

      Args:
-         a: first int
-         b: second int
      """
      if b == 0:
-         raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
- def modulus(a: int, b: int) -> int:
-     """Get the modulus of two numbers.

      Args:
-         a: first int
-         b: second int
      """
      return a % b

  @tool
  def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.

      Args:
-         query: The search query."""
-     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted_search_docs = "\n\n---\n\n".join(
          [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
              for doc in search_docs
          ])
      return {"wiki_results": formatted_search_docs}

  @tool
- def web_search(query: str) -> str:
-     """Search Tavily for a query and return maximum 3 results.

      Args:
-         query: The search query."""
-     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-     formatted_search_docs = "\n\n---\n\n".join(
          [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
              for doc in search_docs
          ])
-     return {"web_results": formatted_search_docs}

  @tool
- def arvix_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 result.

      Args:
-         query: The search query."""
-     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-     formatted_search_docs = "\n\n---\n\n".join(
          [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
              for doc in search_docs
          ])
-     return {"arvix_results": formatted_search_docs}
-
-
- # load the system prompt from the file
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     system_prompt = f.read()
-
- # System message
- sys_msg = SystemMessage(content=system_prompt)

- # build a retriever
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
- supabase: Client = create_client(
-     os.environ.get("SUPABASE_URL"),
-     os.environ.get("SUPABASE_SERVICE_KEY"))
- vector_store = SupabaseVectorStore(
-     client=supabase,
-     embedding=embeddings,
-     table_name="documents",
-     query_name="match_documents_langchain",
- )
- create_retriever_tool = create_retriever_tool(
-     retriever=vector_store.as_retriever(),
-     name="Question Search",
-     description="A tool to retrieve similar questions from a vector store.",
- )

  tools = [
-     multiply,
      add,
-     subtract,
      divide,
-     modulus,
      wiki_search,
-     web_search,
      arvix_search,
  ]

- # Build graph function
- def build_graph(provider: str = "groq"):
-     """Build the graph"""
-     # Load environment variables from .env file
-     if provider == "google":
-         # Google Gemini
-         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-     elif provider == "groq":
-         # Groq https://console.groq.com/docs/models
-         llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
-     elif provider == "huggingface":
-         # TODO: Add huggingface endpoint
-         llm = ChatHuggingFace(
-             llm=HuggingFaceEndpoint(
-                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
-                 temperature=0,
-             ),
-         )
-     else:
-         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-     # Bind tools to LLM
-     llm_with_tools = llm.bind_tools(tools)

-     # Node
      def assistant(state: MessagesState):
-         """Assistant node"""
-         return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-     def retriever(state: MessagesState):
-         """Retriever node"""
-         similar_question = vector_store.similarity_search(state["messages"][0].content)
-         example_msg = HumanMessage(
-             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-         )
-         return {"messages": [sys_msg] + state["messages"] + [example_msg]}

      builder = StateGraph(MessagesState)
-     builder.add_node("retriever", retriever)
      builder.add_node("assistant", assistant)
      builder.add_node("tools", ToolNode(tools))
-     builder.add_edge(START, "retriever")
-     builder.add_edge("retriever", "assistant")
-     builder.add_conditional_edges(
-         "assistant",
-         tools_condition,
-     )
      builder.add_edge("tools", "assistant")

      # Compile graph
      return builder.compile()

- # test
  if __name__ == "__main__":
-     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     # Build the graph
-     graph = build_graph(provider="groq")
-     # Run the graph
-     messages = [HumanMessage(content=question)]
-     messages = graph.invoke({"messages": messages})
-     for m in messages["messages"]:
-         m.pretty_print()

  from dotenv import load_dotenv
+
+ from langchain_openai import ChatOpenAI
+ from langchain_core.tools import tool
  from langchain_community.document_loaders import WikipediaLoader
  from langchain_community.document_loaders import ArxivLoader
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_tavily import TavilyExtract
+ from youtube_transcript_api import YouTubeTranscriptApi
+
  from langchain_core.messages import SystemMessage, HumanMessage
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import ToolNode
+ from langgraph.prebuilt import tools_condition
+ import base64
+ import httpx
+

  load_dotenv()

  @tool
+ def add(a: int, b: int) -> int:
+     """
+     Add b to a.
+
      Args:
+         a: first int number
+         b: second int number
      """
+     return a + b

  @tool
+ def substract(a: int, b: int) -> int:
+     """
+     Subtract b from a.

      Args:
+         a: first int number
+         b: second int number
      """
+     return a - b

  @tool
+ def multiply(a: int, b: int) -> int:
+     """
+     Multiply a by b.

      Args:
+         a: first int number
+         b: second int number
      """
+     return a * b

  @tool
  def divide(a: int, b: int) -> int:
+     """
+     Divide a by b.

      Args:
+         a: first int number
+         b: second int number
      """
      if b == 0:
+         raise ValueError("Can't divide by zero.")
      return a / b

  @tool
+ def mod(a: int, b: int) -> int:
+     """
+     Remainder of a divided by b.

      Args:
+         a: first int number
+         b: second int number
      """
      return a % b

  @tool
  def wiki_search(query: str) -> str:
+     """
+     Search Wikipedia.

      Args:
+         query: what to search for
+     """
+     search_docs = WikipediaLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "".join(
          [
+             f'<START source="{doc.metadata["source"]}">{doc.page_content[:1000]}<END>'
              for doc in search_docs
          ])
      return {"wiki_results": formatted_search_docs}

  @tool
+ def arvix_search(query: str) -> str:
+     """
+     Search arXiv, an online archive of preprint and postprint manuscripts
+     covering many fields of science.

      Args:
+         query: what to search for
+     """
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "".join(
          [
+             f'<START source="{doc.metadata["source"]}">{doc.page_content[:1000]}<END>'
              for doc in search_docs
          ])
+     return {"arvix_results": formatted_search_docs}

  @tool
+ def web_search(query: str) -> str:
+     """
+     Search the web.

      Args:
+         query: what to search for
+     """
+     search_docs = TavilySearchResults(max_results=3, include_answer=True).invoke({"query": query})
+     formatted_search_docs = "".join(
          [
+             f'<START source="{doc["url"]}">{doc["content"][:1000]}<END>'
              for doc in search_docs
          ])
+     return {"web_results": formatted_search_docs}

+ @tool
+ def open_web_page(url: str) -> str:
+     """
+     Open a web page and get its content.
+
+     Args:
+         url: web page url in ""
+     """
+     search_docs = TavilyExtract().invoke({"urls": [url]})
+     formatted_search_docs = f'<START source="{search_docs["results"][0]["url"]}">{search_docs["results"][0]["raw_content"][:1000]}<END>'
+     return {"web_page_content": formatted_search_docs}

+ @tool
+ def youtube_transcript(url: str) -> str:
+     """
+     Get the transcript of a YouTube video.
+     Args:
+         url: YouTube video url in ""
+     """
+     video_id = url.partition("https://www.youtube.com/watch?v=")[2]
+     transcript = YouTubeTranscriptApi.get_transcript(video_id)
+     transcript_text = " ".join([item["text"] for item in transcript])
+     return {"youtube_transcript": transcript_text}

  tools = [
      add,
+     substract,
+     multiply,
      divide,
+     mod,
      wiki_search,
      arvix_search,
+     web_search,
+     open_web_page,
+     youtube_transcript,
  ]

+ # System prompt
+ system_prompt = f"""
+ You are a general AI assistant. I will ask you a question.
+ First, provide a step-by-step explanation of your reasoning to arrive at the answer.
+ Then, respond with your final answer in a single line, formatted as follows: "FINAL ANSWER: [YOUR FINAL ANSWER]".
+ [YOUR FINAL ANSWER] should be a number, a string, or a comma-separated list of numbers and/or strings, depending on the question.
+ If the answer is a number, do not use commas or units (e.g., $, %) unless specified.
+ If the answer is a string, do not use articles or abbreviations (e.g., for cities), and write digits in plain text unless specified.
+ If the answer is a comma-separated list, apply the above rules for each element based on whether it is a number or a string.
+ """
+ system_message = SystemMessage(content=system_prompt)
+
+ # Build graph
+ def build_graph():
+     """Build the LangGraph graph of the agent."""

+     # Language model and tools
+     llm = ChatOpenAI(
+         model="gpt-4.1",
+         temperature=0,
+         max_retries=2
+     )
+     llm_with_tools = llm.bind_tools(tools, strict=True)
+
+     # Nodes
      def assistant(state: MessagesState):
+         """Assistant node."""
+         return {"messages": [llm_with_tools.invoke([system_message] + state["messages"])]}

+     # Graph
      builder = StateGraph(MessagesState)
      builder.add_node("assistant", assistant)
      builder.add_node("tools", ToolNode(tools))
+     builder.add_edge(START, "assistant")
+     builder.add_conditional_edges("assistant", tools_condition)
      builder.add_edge("tools", "assistant")

      # Compile graph
      return builder.compile()

+
+ # Testing and solving particular tasks
  if __name__ == "__main__":
+
+     agent = build_graph()
+
+     question = """
+     Review the chess position provided in the image. It is black's turn.
+     Provide the correct next move for black which guarantees a win.
+     Please provide your response in algebraic notation.
+     """
+     content_urls = {
+         "image": "https://agents-course-unit4-scoring.hf.space/files/cca530fc-4052-43b2-b130-b30968d8aa44",
+         "audio": None
+     }
+
+     # Define user message and add all the content
+     content = [
+         {
+             "type": "text",
+             "text": question
+         }
+     ]
+     if content_urls["image"]:
+         image_data = base64.b64encode(httpx.get(content_urls["image"]).content).decode("utf-8")
+         content.append(
+             {
+                 "type": "image",
+                 "source_type": "base64",
+                 "data": image_data,
+                 "mime_type": "image/jpeg"
+             }
+         )
+     if content_urls["audio"]:
+         audio_data = base64.b64encode(httpx.get(content_urls["audio"]).content).decode("utf-8")
+         content.append(
+             {
+                 "type": "audio",
+                 "source_type": "base64",
+                 "data": audio_data,
+                 "mime_type": "audio/wav"
+             }
+         )
+     messages = {
+         "role": "user",
+         "content": content
+     }
+
+     # Run agent on the question
+     messages = agent.invoke({"messages": messages})
+     for message in messages["messages"]:
+         message.pretty_print()
+
+     answer = messages["messages"][-1].content
+     index = answer.find("FINAL ANSWER: ")
+
+     print("\n")
+     print("="*30)
+     if index == -1:
+         print(answer)
+     else:
+         print(answer[index + 14:])
+     print("="*30)