GoldsWolf committed on
Commit
1cd6191
·
1 Parent(s): 81917a3

Updated agent.py

Browse files
Files changed (1) hide show
  1. agent.py +373 -0
agent.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ import tempfile
4
+ import cmath
5
+ import pandas as pd
6
+ from dotenv import load_dotenv
7
+ from langgraph.graph import START, StateGraph, MessagesState
8
+ from langgraph.prebuilt import tools_condition
9
+ from langgraph.prebuilt import ToolNode
10
+ from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
11
+ from langchain_groq import ChatGroq
12
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
13
+ from langchain_community.tools.tavily_search import TavilySearchResults
14
+ from langchain_community.document_loaders import WikipediaLoader
15
+ from langchain_community.document_loaders import ArxivLoader
16
+ from langchain_community.vectorstores import SupabaseVectorStore
17
+ from langchain_core.messages import SystemMessage, HumanMessage
18
+ from langchain_core.tools import tool
19
+ from langchain.tools.retriever import create_retriever_tool
20
+ from supabase.client import Client, create_client
21
+ from typing import List, Dict, Any, Optional
22
+
23
+ load_dotenv()
24
+
25
@tool
def multiply(a: int, b: int) -> int:
    """
    Multiply two integers.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        int: The product of a and b.
    """
    # The docstring is left untouched: the @tool decorator exposes it to the
    # model as the tool description.
    product = a * b
    return product
38
+
39
@tool
def add(a: int, b: int) -> int:
    """
    Add two integers.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        int: The sum of a and b.
    """
    # The docstring is left untouched: the @tool decorator exposes it to the
    # model as the tool description.
    total = a + b
    return total
52
+
53
@tool
def subtract(a: int, b: int) -> int:
    """
    Subtract one integer from another.

    Args:
        a (int): The integer to subtract from.
        b (int): The integer to subtract.

    Returns:
        int: The result of a minus b.
    """
    # The docstring is left untouched: the @tool decorator exposes it to the
    # model as the tool description.
    difference = a - b
    return difference
66
+
67
@tool
def divide(a: int, b: int) -> float:
    """
    Divide one integer by another.

    Args:
        a (int): The numerator.
        b (int): The denominator. Must not be zero.

    Returns:
        float: The result of a divided by b.

    Raises:
        ValueError: If b is zero.
    """
    # Happy path first; a zero denominator falls through to the explicit error.
    if b != 0:
        quotient = a / b
        return quotient
    raise ValueError("Cannot divide by zero.")
85
+
86
@tool
def modulus(a: int, b: int) -> int:
    """
    Compute the modulus (remainder) of two integers.

    Args:
        a (int): The dividend.
        b (int): The divisor.

    Returns:
        int: The remainder after dividing a by b.
    """
    # No zero check here: a zero divisor surfaces Python's own
    # ZeroDivisionError from the % operator.
    remainder = a % b
    return remainder
99
+
100
@tool
def power(a: float, b: float) -> float:
    """
    Raise a number to the power of another number.

    Args:
        a (float): The base number.
        b (float): The exponent.

    Returns:
        float: The result of a raised to the power of b.
    """
    # The docstring is left untouched: the @tool decorator exposes it to the
    # model as the tool description.
    raised = a**b
    return raised
113
+
114
@tool
def square_root(a: float) -> float | complex:
    """
    Compute the square root of a number. Returns a complex number if input is negative.

    Args:
        a (float): The number to compute the square root of.

    Returns:
        float or complex: The square root of a. Complex if a < 0.
    """
    # Real inputs >= 0 use the float power operator; everything else
    # is delegated to cmath so the result is a complex number.
    return a**0.5 if a >= 0 else cmath.sqrt(a)
128
+
129
+ ### =============== DOCUMENT PROCESSING TOOLS =============== ###
130
+
131
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save text content to a file and return the file path.

    Args:
        content (str): The text content to save.
        filename (str, optional): The name of the file. If not provided, a random name is generated.
            Only the final path component is used, so the file is always written inside the
            system temporary directory.

    Returns:
        str: A message containing the file path where the content was saved.
    """
    temp_dir = tempfile.gettempdir()

    # The filename is supplied by the model, so drop any directory components:
    # an absolute path or "../" segments must not escape the temp directory.
    safe_name = os.path.basename(filename) if filename else ""

    if safe_name in ("", ".", ".."):
        # No usable name: reserve a unique path, and close the handle right
        # away so it is not leaked (and so the re-open below works on Windows).
        with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as temp_file:
            filepath = temp_file.name
    else:
        filepath = os.path.join(temp_dir, safe_name)

    # Explicit UTF-8 so non-ASCII content does not depend on the platform locale.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
154
+
155
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file and answer a question about its data.

    Args:
        file_path (str): The path to the CSV file.
        query (str): The question to answer about the data.

    Returns:
        str: The analysis result or error message.
    """
    # NOTE: `query` is accepted but not used below; the report is always the
    # same shape summary plus DataFrame.describe() output.
    try:
        frame = pd.read_csv(file_path)
        report_parts = [
            f"CSV file loaded with {len(frame)} rows and {len(frame.columns)} columns.\n",
            f"Columns: {', '.join(frame.columns)}\n\n",
            "Summary statistics:\n",
            str(frame.describe()),
        ]
        return "".join(report_parts)
    except Exception as e:
        # Any failure is reported back as text instead of being raised.
        return f"Error analyzing CSV file: {str(e)}"
176
+
177
@tool
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Analyze an Excel file and answer a question about its data.

    Args:
        file_path (str): The path to the Excel file.
        query (str): The question to answer about the data.

    Returns:
        str: The analysis result or error message.
    """
    # NOTE: `query` is accepted but not used below; the report is always the
    # same shape summary plus DataFrame.describe() output.
    try:
        df = pd.read_excel(file_path)
        # Header cells in a workbook may be numbers or dates, so column labels
        # are converted to text before joining (str.join rejects non-strings).
        column_names = ", ".join(str(column) for column in df.columns)
        result = (
            f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        )
        result += f"Columns: {column_names}\n\n"
        result += "Summary statistics:\n"
        result += str(df.describe())
        return result
    except Exception as e:
        # Best-effort tool: any failure is reported back as text, not raised.
        return f"Error analyzing Excel file: {str(e)}"
200
+
201
@tool
def wiki_search(input: str) -> dict[str, str]:
    """
    Search Wikipedia for a query and return up to 2 results.

    Args:
        input (str): The search query string.

    Returns:
        dict: A dictionary whose "wiki_results" value is a formatted string containing up to 2 Wikipedia search results.
    """
    search_docs = WikipediaLoader(query=input, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        # .get() keeps one document with missing metadata from failing the
        # whole search with a KeyError.
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    # The return shape (a dict, not a bare string) is unchanged; only the
    # annotation and docstring were corrected to describe it.
    return {"wiki_results": formatted_search_docs}
219
+
220
@tool
def web_search(input: str) -> dict[str, str]:
    """
    Search the web using Tavily and return up to 5 results.

    Args:
        input (str): The search query string.

    Returns:
        dict: A dictionary whose "web_results" value is a formatted string containing up to 5 web search results.
    """
    search_docs = TavilySearchResults(max_results=5).invoke(input)

    # A plain string is not a result list (presumably an error message from
    # the tool - TODO confirm); pass it through instead of iterating over it
    # character by character.
    if isinstance(search_docs, str):
        return {"web_results": search_docs}

    def _render(doc) -> str:
        """Format one search hit, given either a Document-like object or a dict."""
        if hasattr(doc, "metadata") and hasattr(doc, "page_content"):
            source = doc.metadata.get("source", "")
            page = doc.metadata.get("page", "")
            body = doc.page_content
        else:
            # Tavily result dicts carry the link under "url"; "source" is kept
            # as a fallback for any other dict-shaped result.
            source = doc.get("url", doc.get("source", ""))
            page = doc.get("page", "")
            body = doc.get("content", doc.get("page_content", ""))
        return f'<Document source="{source}" page="{page}"/>\n{body}\n</Document>'

    formatted_search_docs = "\n\n---\n\n".join(_render(doc) for doc in search_docs)
    return {"web_results": formatted_search_docs}
244
+
245
@tool
def arvix_search(input: str) -> dict[str, str]:
    """
    Search Arxiv for a query and return up to 3 results.

    Args:
        input (str): The search query string.

    Returns:
        dict: A dictionary whose "arvix_results" value is a formatted string containing up to 3 Arxiv search results (each truncated to 1000 characters).
    """
    search_docs = ArxivLoader(query=input, load_max_docs=3).load()

    def _source_of(doc) -> str:
        """Pick the best available identifier for an Arxiv document."""
        meta = doc.metadata
        # NOTE(review): ArxivLoader's default metadata appears to have no
        # "source" key (it exposes fields such as "Title"), so indexing
        # ["source"] raised KeyError - confirm against the loader docs.
        return meta.get("source") or meta.get("entry_id") or meta.get("Title", "")

    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{_source_of(doc)}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
263
+
264
# Read the system prompt text from disk (path is relative to the working directory).
with open("system_prompt.txt", mode="r", encoding="utf-8") as f:
    system_prompt = f.read()

# Wrap the prompt text as the system message used by the graph.
sys_msg = SystemMessage(content=system_prompt)

# Embeddings + Supabase-backed vector store used for similar-question lookup.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)  # dim=768
# embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-exp-03-07")
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"),
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# NOTE(review): this assignment rebinds the imported `create_retriever_tool`
# factory to the tool instance it returns, so the factory cannot be called
# again after this line. The name is kept as-is for compatibility.
create_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
288
+
289
# Tools bound to the LLM and served by the graph's tool node.
tools = [
    # arithmetic
    multiply, add, subtract, divide, modulus, power, square_root,
    # search
    wiki_search, web_search, arvix_search,
    # file handling
    save_and_read_file, analyze_csv_file, analyze_excel_file,
    # create_retriever_tool
]
305
+
306
+ # Build graph function
307
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph for the chosen LLM provider.

    Args:
        provider: One of "google", "groq" or "huggingface".

    Returns:
        The compiled graph: START -> retriever -> assistant, with the
        assistant looping through the tools node via ``tools_condition``.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "huggingface":
        # NOTE(review): verify that `url` is the keyword HuggingFaceEndpoint
        # expects and that this model path exists.
        endpoint = HuggingFaceEndpoint(
            url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            temperature=0,
        )
        llm = ChatHuggingFace(llm=endpoint)
    elif provider == "groq":
        # Groq model list: https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Give the model access to the module-level tool list.
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: run the tool-enabled LLM on the conversation so far."""
        reply = llm_with_tools.invoke(state["messages"])
        return {"messages": [reply]}

    def retriever(state: MessagesState):
        """Retriever node: add the system prompt and a similar stored question."""
        first_message = state["messages"][0]
        matches = vector_store.similarity_search(first_message.content)
        if not matches:
            reference_text = "No similar questions found in the database."
        else:
            reference_text = f"Here I provide a similar question and answer for reference: \n\n{matches[0].page_content}"
        example_msg = HumanMessage(content=reference_text)
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    graph_builder = StateGraph(MessagesState)

    # Nodes
    graph_builder.add_node("retriever", retriever)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", ToolNode(tools))

    # Edges
    graph_builder.add_edge(START, "retriever")
    graph_builder.add_edge("retriever", "assistant")
    graph_builder.add_conditional_edges("assistant", tools_condition)
    graph_builder.add_edge("tools", "assistant")

    return graph_builder.compile()
362
+
363
+ # test
364
if __name__ == "__main__":
    # Manual smoke test: run one sample question through the Google-backed graph.
    # question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    question = "What is the surname of the equine veterinarian mentioned in 1.E Exercises from the chemistry materials licensed by Marisa Alviar-Agnew & Henry Agnew under the CK-12 license in LibreText's Introductory Chemistry materials as compiled 08/21/2023?"
    graph = build_graph(provider="google")
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    # Print every message of the final state in order.
    for message in result["messages"]:
        message.pretty_print()