superone001 committed on
Commit
76cd2ec
·
verified ·
1 Parent(s): 8d69b80

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +131 -189
agent.py CHANGED
@@ -1,269 +1,211 @@
 
 
1
  from dotenv import load_dotenv
2
-
3
- from langchain_openai import ChatOpenAI
4
- from langchain_core.tools import tool
 
 
 
 
5
  from langchain_community.document_loaders import WikipediaLoader
6
  from langchain_community.document_loaders import ArxivLoader
7
- from langchain_community.tools.tavily_search import TavilySearchResults
8
- from langchain_tavily import TavilyExtract
9
- from youtube_transcript_api import YouTubeTranscriptApi
10
-
11
  from langchain_core.messages import SystemMessage, HumanMessage
12
- from langgraph.graph import START, StateGraph, MessagesState
13
- from langgraph.prebuilt import ToolNode
14
- from langgraph.prebuilt import tools_condition
15
- import base64
16
- import httpx
17
-
18
 
19
  load_dotenv()
20
 
@tool
def add(a: int, b: int) -> int:
    """
    Add b to a.

    Args:
        a: first int number
        b: second int number
    """
    # Note: the docstring above is what the tool decorator exposes as the
    # tool description, so its wording is kept deliberately short.
    total = a + b
    return total
31
 
@tool
def substract(a: int, b: int) -> int:
    """
    Subtract b from a.

    Args:
        a: first int number
        b: second int number
    """
    # The (misspelled) function name is the public tool name; it is kept as is.
    difference = a - b
    return difference
42
 
@tool
def multiply(a: int, b: int) -> int:
    """
    Multiply a by b.

    Args:
        a: first int number
        b: second int number
    """
    product = a * b
    return product
53
 
@tool
def divide(a: int, b: int) -> float:
    """
    Divide a by b.

    Args:
        a: first int number
        b: second int number

    Returns:
        The quotient a / b. True division is used, so the result is a float
        even when both inputs are ints.

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Can't divide by zero.")
    return a / b
66
 
@tool
def mod(a: int, b: int) -> int:
    """
    Remainder of a divided by b.

    Args:
        a: first int number
        b: second int number
    """
    # No explicit zero check here: b == 0 propagates Python's own
    # ZeroDivisionError from the % operator.
    remainder = a % b
    return remainder
77
 
@tool
def wiki_search(query: str) -> dict:
    """
    Search Wikipedia.

    Args:
        query: what to search for

    Returns:
        A dict with the single key "wiki_results" whose value is the
        concatenation of up to three documents, each wrapped as
        <START source="...">first 1000 characters<END>.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = "".join(
        # .get() keeps one document with unexpected metadata from failing
        # the whole search.
        f'<START source="{doc.metadata.get("source", "")}">{doc.page_content[:1000]}<END>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
93
 
@tool
def arvix_search(query: str) -> dict:
    """
    Search arXiv which is online archive of preprint and postprint manuscripts
    for different fields of science.

    Args:
        query: what to search for

    Returns:
        A dict with the single key "arvix_results" whose value is the
        concatenation of up to three documents, each wrapped as
        <START source="...">first 1000 characters<END>.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    chunks = []
    for doc in search_docs:
        # arXiv documents are not guaranteed to carry a "source" metadata
        # entry; fall back to the paper title instead of raising KeyError.
        source = doc.metadata.get("source") or doc.metadata.get("Title", "")
        chunks.append(f'<START source="{source}">{doc.page_content[:1000]}<END>')
    formatted_search_docs = "".join(chunks)
    return {"arvix_results": formatted_search_docs}
110
 
@tool
def web_search(query: str) -> dict:
    """
    Search WEB.

    Args:
        query: what to search for

    Returns:
        A dict with the single key "web_results". On success the value is
        the concatenation of up to three hits, each wrapped as
        <START source="url">first 1000 characters<END>; if the search
        backend reports an error, the value is that error text.
    """
    search_docs = TavilySearchResults(max_results=3, include_answer=True).invoke({"query": query})
    if isinstance(search_docs, str):
        # The Tavily tool reports failures as a plain string instead of a
        # list of hits; pass it through rather than indexing into it.
        return {"web_results": search_docs}
    formatted_search_docs = "".join(
        f'<START source="{doc["url"]}">{doc["content"][:1000]}<END>'
        for doc in search_docs
    )
    return {"web_results": formatted_search_docs}
126
 
@tool
def open_web_page(url: str) -> dict:
    """
    Open web page and get its content.

    Args:
        url: web page url in ""

    Returns:
        A dict with the single key "web_page_content". The value is the
        first 1000 characters of the page wrapped as
        <START source="url">...<END>, or a short failure note when nothing
        could be extracted.
    """
    search_docs = TavilyExtract().invoke({"urls": [url]})
    results = search_docs.get("results") if isinstance(search_docs, dict) else None
    if not results:
        # Extraction can fail for a URL, leaving "results" empty; return a
        # readable note instead of raising IndexError on results[0].
        return {"web_page_content": f"No content could be extracted from {url}"}
    page = results[0]
    raw_content = page.get("raw_content") or ""
    formatted_search_docs = f'<START source="{page.get("url", url)}">{raw_content[:1000]}<END>'
    return {"web_page_content": formatted_search_docs}
138
 
@tool
def youtube_transcript(url: str) -> dict:
    """
    Get transcript of YouTube video.

    Args:
        url: YouTube video url in ""

    Returns:
        A dict with the single key "youtube_transcript" holding the
        transcript text joined with spaces.

    Raises:
        ValueError: if no video id can be extracted from the url.
    """
    from urllib.parse import parse_qs, urlparse

    # Parse the URL properly: plain prefix matching kept trailing query
    # parameters (e.g. "&t=10s") in the id and could not handle youtu.be links.
    parsed = urlparse(url.strip().strip('"'))
    host = (parsed.hostname or "").lower()
    if host.endswith("youtu.be"):
        video_id = parsed.path.lstrip("/").split("/")[0]
    else:
        video_id = parse_qs(parsed.query).get("v", [""])[0]
    if not video_id:
        raise ValueError(f"Could not extract a YouTube video id from: {url}")

    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    transcript_text = " ".join(item["text"] for item in transcript)
    return {"youtube_transcript": transcript_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
 
# Every tool the assistant is allowed to call, grouped by purpose.
tools = [
    # arithmetic
    add, substract, multiply, divide, mod,
    # knowledge and web lookup
    wiki_search, arvix_search, web_search,
    # page and media content
    open_web_page, youtube_transcript,
]
164
 
# System prompt: fixes the answer format the agent must follow.
# A plain (non-f) string is enough because the text has no placeholders.
system_prompt = """
You are a general AI assistant. I will ask you a question.
First, provide a step-by-step explanation of your reasoning to arrive at the answer.
Then, respond with your final answer in a single line, formatted as follows: "FINAL ANSWER: [YOUR FINAL ANSWER]".
[YOUR FINAL ANSWER] should be a number, a string, or a comma-separated list of numbers and/or strings, depending on the question.
If the answer is a number, do not use commas or units (e.g., $, %) unless specified.
If the answer is a string, do not use articles or abbreviations (e.g., for cities), and write digits in plain text unless specified.
If the answer is a comma-separated list, apply the above rules for each element based on whether it is a number or a string.
"""
# Wrapped once so the assistant node can prepend it on every model call.
system_message = SystemMessage(content=system_prompt)
176
-
# Build graph
def build_graph():
    """Build the LangGraph graph of the agent.

    The graph has two nodes: "assistant" calls the tool-enabled model, and
    "tools" executes whatever tool calls the model requested. Control loops
    between them until the model answers without requesting a tool.

    Returns:
        The compiled graph, ready for ``invoke({"messages": ...})``.
    """
    # Imported locally because this module's top-level imports do not
    # provide these names.
    from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

    # Language model and tools.
    # A bare HuggingFaceEndpoint is a text-completion LLM and has no
    # bind_tools(); wrapping it in ChatHuggingFace provides the chat-model
    # interface that tool binding needs.
    llm = ChatHuggingFace(
        llm=HuggingFaceEndpoint(
            repo_id="mistralai/Mistral-7B-Instruct-v0.2",
            max_new_tokens=500,      # set directly
            temperature=0.1,         # set directly
            repetition_penalty=1.2,  # set directly
            top_p=0.9,               # optional parameter
            # further generation parameters can also be given here
        ),
    )
    llm_with_tools = llm.bind_tools(tools)

    # Nodes
    def assistant(state: MessagesState):
        """Assistant node: run the model on the system prompt plus history."""
        return {"messages": [llm_with_tools.invoke([system_message] + state["messages"])]}

    # Graph
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Route to "tools" when the last message has tool calls, else finish.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
207
 
208
-
# Testing and solving particular tasks
if __name__ == "__main__":

    agent = build_graph()

    question = """
    Review the chess position provided in the image. It is black's turn.
    Provide the correct next move for black which guarantees a win.
    Please provide your response in algebraic notation.
    """
    content_urls = {
        "image": "https://agents-course-unit4-scoring.hf.space/files/cca530fc-4052-43b2-b130-b30968d8aa44",
        "audio": None
    }

    # Define user message and add all the content
    content = [
        {
            "type": "text",
            "text": question
        }
    ]
    # Attachment kind -> MIME type declared for it in the message.
    mime_types = {"image": "image/jpeg", "audio": "audio/wav"}
    for kind, mime_type in mime_types.items():
        attachment_url = content_urls[kind]
        if not attachment_url:
            continue
        response = httpx.get(attachment_url)
        # Fail loudly on an HTTP error instead of base64-encoding an error page.
        response.raise_for_status()
        content.append(
            {
                "type": kind,
                "source_type": "base64",
                "data": base64.b64encode(response.content).decode("utf-8"),
                "mime_type": mime_type
            }
        )
    messages = {
        "role": "user",
        "content": content
    }

    # Run agent on the question
    messages = agent.invoke({"messages": messages})
    for message in messages["messages"]:
        message.pretty_print()

    answer = messages["messages"][-1].content
    marker = "FINAL ANSWER: "
    index = answer.find(marker)

    print("\n")
    print("="*30)
    if index == -1:
        # No marker found: show the whole reply exactly once.
        print(answer)
    else:
        print(answer[index + len(marker):])
    print("="*30)
 
1
+ """LangGraph Agent"""
2
+ import os
3
  from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
 
 
 
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
 
 
 
18
 
19
  load_dotenv()
20
 
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
29
 
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
39
 
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int (the value subtracted from)
        b: second int (the value subtracted)
    """
    difference = a - b
    return difference
49
 
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int (the dividend)
        b: second int (the divisor)

    Returns:
        The quotient a / b. True division is used, so the result is a float
        even when both inputs are ints.

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
61
 
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    # No explicit zero check here: b == 0 propagates Python's own
    # ZeroDivisionError from the % operator.
    remainder = a % b
    return remainder
71
 
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "wiki_results" whose value is the found
        documents, each wrapped in a <Document> element and separated by
        "---" lines.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        # The opening tag is a normal start tag (not self-closing) so that it
        # pairs with the closing </Document>.
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
85
 
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "web_results". On success the value is
        the hits, each wrapped in a <Document> element and separated by
        "---" lines; if the search backend reports an error, the value is
        that error text.
    """
    # invoke() takes the tool input as its first positional argument; passing
    # query= as a keyword leaves that argument missing.
    search_docs = TavilySearchResults(max_results=3).invoke({"query": query})
    if isinstance(search_docs, str):
        # Failures come back as a plain string rather than a list of hits.
        return {"web_results": search_docs}
    formatted_search_docs = "\n\n---\n\n".join(
        # Each Tavily hit is a plain dict with "url" and "content" keys, not
        # a Document object with .metadata / .page_content.
        f'<Document source="{doc.get("url", "")}">\n{doc.get("content", "")}\n</Document>'
        for doc in search_docs
    )
    return {"web_results": formatted_search_docs}
99
 
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "arvix_results" whose value is the found
        documents (first 1000 characters each), wrapped in <Document>
        elements and separated by "---" lines.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    chunks = []
    for doc in search_docs:
        # arXiv documents are not guaranteed to carry a "source" metadata
        # entry; fall back to the paper title instead of raising KeyError.
        source = doc.metadata.get("source") or doc.metadata.get("Title", "")
        page = doc.metadata.get("page", "")
        chunks.append(
            f'<Document source="{source}" page="{page}">\n{doc.page_content[:1000]}\n</Document>'
        )
    formatted_search_docs = "\n\n---\n\n".join(chunks)
    return {"arvix_results": formatted_search_docs}
113
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+
# Load the system prompt from the file (path is relative to the working directory).
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message
sys_msg = SystemMessage(content=system_prompt)

# Build a retriever backed by a Supabase vector store.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"),
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# Stored under its own name so the imported create_retriever_tool factory is
# not overwritten by the tool it creates.
# NOTE(review): the tool name contains a space; some function-calling APIs
# only accept letters, digits, "_" and "-" - confirm before binding this tool
# to a model.
question_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
139
+
140
 
141
 
# Every tool the assistant is allowed to call, grouped by purpose.
tools = [
    # arithmetic
    multiply, add, subtract, divide, modulus,
    # knowledge and web lookup
    wiki_search, web_search, arvix_search,
]
152
 
# Build graph function
def build_graph(provider: str = "huggingface"):
    """Build and compile the agent graph for the chosen LLM provider.

    The graph runs "retriever" first (adds a similar solved question as a
    reference), then loops between "assistant" (tool-enabled model) and
    "tools" (executes requested tool calls) until the model answers without
    requesting a tool.

    Args:
        provider: which chat model backend to use: "google", "groq" or
            "huggingface".

    Returns:
        The compiled graph, ready for ``invoke({"messages": [...]})``.

    Raises:
        ValueError: if provider is not one of the supported names.
    """
    # Select the chat model for the requested provider.
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: run the model on the system prompt plus history."""
        # The system message is prepended at call time so it is always the
        # first message the model sees. Adding it through the state instead
        # would append it after the user's question.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: add the most similar stored question as a reference."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        if not similar_question:
            # Nothing comparable stored: leave the conversation unchanged
            # rather than failing on similar_question[0].
            return {"messages": []}
        example_msg = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
        # Only the new message is returned; the state's reducer appends it to
        # the existing history.
        return {"messages": [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the last message has tool calls, else finish.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
201
 
# test
if __name__ == "__main__":
    # Smoke test: run one sample question through the Groq-backed graph and
    # print every message of the resulting conversation.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph(provider="groq")
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()