Coool2 committed on
Commit
e61eb95
·
1 Parent(s): 4c11be0

Create agent2.py

Browse files
Files changed (1) hide show
  1. agent2.py +317 -0
agent2.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from typing import Dict, Any, List
4
+ from langchain.docstore.document import Document
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.retrievers import BM25Retriever
7
+ from smolagents import CodeAgent, OpenAIServerModel, tool, Tool
8
+ from smolagents.vision_web_browser import initialize_driver, save_screenshot, helium_instructions
9
+ from smolagents.agents import ActionStep
10
+ from selenium import webdriver
11
+ from selenium.webdriver.common.by import By
12
+ from selenium.webdriver.common.keys import Keys
13
+ import helium
14
+ from PIL import Image
15
+ from io import BytesIO
16
+ from time import sleep
17
+
18
+
19
class BM25RetrieverTool(Tool):
    """
    Tool that runs BM25 keyword search over a list of pre-loaded documents.

    The tool can be built without documents; in that case no index is
    created and every query is answered with a "no documents" notice.
    """

    name = "bm25_retriever"
    description = "Uses BM25 search to retrieve relevant parts of uploaded documents. Use this when the question references an attached file or document."
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find relevant document sections.",
        }
    }
    output_type = "string"

    def __init__(self, docs=None, **kwargs):
        """Store ``docs`` and build the BM25 index when any are given.

        Args:
            docs: Documents to index; ``None`` or an empty list means
                "nothing loaded yet".
            **kwargs: Forwarded unchanged to the ``Tool`` constructor.
        """
        super().__init__(**kwargs)
        self.docs = docs if docs else []
        # Only build an index when there is something to index.
        self.retriever = (
            BM25Retriever.from_documents(self.docs, k=5) if self.docs else None
        )

    def forward(self, query: str) -> str:
        """Return the retrieved chunks for ``query`` as one formatted string."""
        if not self.retriever:
            return "No documents loaded for retrieval."

        assert isinstance(query, str), "Your search query must be a string"

        hits = self.retriever.invoke(query)
        sections = []
        for position, hit in enumerate(hits):
            sections.append(f"\n\n===== Document {position} =====\n" + hit.page_content)
        return "\nRetrieved documents:\n" + "".join(sections)
51
+
52
+
53
@tool
def search_item_ctrl_f(text: str, nth_result: int = 1) -> str:
    """
    Searches for text on the current page via Ctrl + F and jumps to the nth occurrence.

    Args:
        text: The text to search for on the current page.
        nth_result: Which occurrence to jump to, counted from 1 (default: 1).
    """
    # NOTE: the ``Args:`` section above is not optional decoration -- the
    # smolagents ``@tool`` decorator builds the tool's input schema from it.
    try:
        # Occurrences are 1-based; 0 or a negative value would otherwise
        # silently index from the end of the match list.
        if nth_result < 1:
            return f"Match n°{nth_result} not found (nth_result must be 1 or greater)"

        # XPath 1.0 string literals have no escape syntax, so choose a quote
        # style the text does not contain, or stitch the pieces with concat().
        if "'" not in text:
            literal = f"'{text}'"
        elif '"' not in text:
            literal = f'"{text}"'
        else:
            pieces = ", \"'\", ".join(f"'{piece}'" for piece in text.split("'"))
            literal = f"concat({pieces})"

        driver = helium.get_driver()
        elements = driver.find_elements(By.XPATH, f"//*[contains(text(), {literal})]")
        if nth_result > len(elements):
            return f"Match n°{nth_result} not found (only {len(elements)} matches found)"
        result = f"Found {len(elements)} matches for '{text}'."
        elem = elements[nth_result - 1]
        # Bring the chosen match into the viewport.
        driver.execute_script("arguments[0].scrollIntoView(true);", elem)
        result += f"Focused on element {nth_result} of {len(elements)}"
        return result
    except Exception as e:
        # Tools report failures as text so the agent can react to them.
        return f"Error searching for text: {e}"
70
+
71
+
72
@tool
def go_back() -> str:
    """Goes back to previous page."""
    # Any failure (including a missing driver) is reported as text rather
    # than raised.
    try:
        helium.get_driver().back()
    except Exception as e:
        return f"Error going back: {e}"
    return "Navigated back to previous page"
81
+
82
+
83
@tool
def close_popups() -> str:
    """
    Closes any visible modal or pop-up on the page. Use this to dismiss pop-up windows!
    """
    # Sends a single ESCAPE key press to the page; the result string only
    # says the attempt was made, not that a pop-up was actually closed.
    try:
        browser = helium.get_driver()
        escape_press = webdriver.ActionChains(browser).send_keys(Keys.ESCAPE)
        escape_press.perform()
    except Exception as e:
        return f"Error closing popups: {e}"
    return "Attempted to close popups"
94
+
95
+
96
def save_screenshot_callback(memory_step: ActionStep, agent: CodeAgent) -> None:
    """Attach a browser screenshot and the current URL to ``memory_step``.

    Screenshots stored on action steps that are two or more steps older than
    ``memory_step`` are dropped. Does nothing when helium reports no driver.
    Any error is printed instead of being raised.
    """
    try:
        # Fixed one-second pause before capturing.
        sleep(1.0)
        browser = helium.get_driver()
        if browser is None:
            return

        # Drop screenshots held by older action steps.
        for earlier_step in agent.memory.steps:
            if not isinstance(earlier_step, ActionStep):
                continue
            if earlier_step.step_number <= memory_step.step_number - 2:
                earlier_step.observations_images = None

        screenshot = Image.open(BytesIO(browser.get_screenshot_as_png()))
        memory_step.observations_images = [screenshot.copy()]

        # Append the current URL to whatever the step already observed.
        url_info = f"Current url: {browser.current_url}"
        if memory_step.observations is None:
            memory_step.observations = url_info
        else:
            memory_step.observations = memory_step.observations + "\n" + url_info
    except Exception as e:
        print(f"Error in screenshot callback: {e}")
119
+
120
+
121
class GAIAAgent:
    """
    Simplified GAIA agent using smolagents with Gemini 2.0 Flash.

    Wraps a ``CodeAgent`` together with a BM25 retriever tool, a few browser
    helper tools and an optional helium-driven Chrome instance that is
    started on demand and shut down after each question.
    """

    def __init__(self):
        """Initialize the agent with Gemini 2.0 Flash and tools.

        Raises:
            ValueError: If the ``GEMINI_API_KEY`` environment variable is
                missing or empty.
        """
        # Get Gemini API key
        gemini_api_key = os.environ.get("GEMINI_API_KEY")
        if not gemini_api_key:
            raise ValueError("GEMINI_API_KEY environment variable not found")

        # Gemini is reached through its OpenAI-compatible endpoint.
        self.model = OpenAIServerModel(
            model_id="gemini-2.0-flash",
            api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
            api_key=gemini_api_key,
        )

        # GAIA system prompt from the leaderboard
        self.system_prompt = """You are a general AI assistant. I will ask you a question. Report your thoughts and reasoning process clearly. You should use the available tools to gather information and solve problems step by step.

When using web browser automation:
- Use helium commands like go_to(), click(), scroll_down()
- Take screenshots to see what's happening
- Handle popups and forms appropriately
- Be patient with page loading

For document retrieval:
- Use the BM25 retriever when there are text documents attached
- Search with relevant keywords from the question

Your final answer should be as few words as possible, a number, or a comma-separated list. Don't use articles, abbreviations, or units unless specified."""

        # Starts empty; replaced by load_documents_from_file().
        self.retriever_tool = BM25RetrieverTool()

        # No browser until a question needs one (see initialize_browser()).
        self.driver = None

        # Create the agent
        self.agent = None
        self._create_agent()

    def _create_agent(self):
        """(Re)build ``self.agent`` from the current model, retriever tool and driver.

        Called again whenever the retriever tool or the browser changes, so
        the new tool instance and the screenshot callback are picked up.
        """
        base_tools = [self.retriever_tool, search_item_ctrl_f, go_back, close_popups]

        # NOTE(review): whether CodeAgent accepts ``system_prompt`` depends on
        # the installed smolagents version -- confirm against the pinned one.
        self.agent = CodeAgent(
            tools=base_tools,
            model=self.model,
            add_base_tools=True,  # Adds web search, python execution, etc.
            planning_interval=5,  # Plan every 5 steps
            # "bs4" is the importable module that provides BeautifulSoup; the
            # original "BeautifulSoup" entry is kept for backward compatibility.
            additional_authorized_imports=["helium", "requests", "BeautifulSoup", "bs4", "json"],
            # Screenshots only make sense while a browser is running.
            step_callbacks=[save_screenshot_callback] if self.driver else [],
            max_steps=20,
            system_prompt=self.system_prompt,
            verbosity_level=2,
        )

    def initialize_browser(self):
        """Start Chrome through helium and rebuild the agent around it.

        Returns:
            True on success, False if anything failed (the error is printed).
        """
        try:
            chrome_options = webdriver.ChromeOptions()
            chrome_options.add_argument("--force-device-scale-factor=1")
            chrome_options.add_argument("--window-size=1000,1350")
            chrome_options.add_argument("--disable-pdf-viewer")
            chrome_options.add_argument("--window-position=0,0")

            self.driver = helium.start_chrome(headless=False, options=chrome_options)

            # Recreate agent so the screenshot callback gets registered.
            self._create_agent()

            # Make helium's commands available inside the agent's executor.
            self.agent.python_executor("from helium import *")

            return True
        except Exception as e:
            print(f"Failed to initialize browser: {e}")
            return False

    def load_documents_from_file(self, file_path: str) -> bool:
        """Load a UTF-8 text file, chunk it and index it for BM25 retrieval.

        Args:
            file_path: Path of the text file to load.

        Returns:
            True when the file was indexed, False on any error (for example
            a file that is not valid UTF-8 text); the error is printed.
        """
        try:
            # Read file content
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Split into chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
                separators=["\n\n", "\n", ".", " ", ""]
            )

            # Create documents
            chunks = text_splitter.split_text(content)
            docs = [Document(page_content=chunk, metadata={"source": file_path})
                    for chunk in chunks]

            # Update retriever tool
            self.retriever_tool = BM25RetrieverTool(docs)

            # Recreate agent with updated retriever
            self._create_agent()

            print(f"Loaded {len(docs)} document chunks from {file_path}")
            return True

        except Exception as e:
            print(f"Error loading documents from {file_path}: {e}")
            return False

    def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> "str | None":
        """Download the file associated with a GAIA ``task_id``.

        Args:
            task_id: Identifier of the task whose attachment is fetched.
            api_url: Base URL of the service exposing ``/files/<task_id>``.

        Returns:
            The local filename the payload was written to, or ``None`` when
            the download failed (the error is printed).
        """
        try:
            response = requests.get(f"{api_url}/files/{task_id}", timeout=30)
            response.raise_for_status()

            # The payload is stored verbatim under a ".txt" name, whatever
            # its real type is.
            filename = f"task_{task_id}_file.txt"
            with open(filename, 'wb') as f:
                f.write(response.content)

            return filename
        except Exception as e:
            print(f"Failed to download file for task {task_id}: {e}")
            return None

    def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
        """
        Solve a GAIA question.

        Args:
            question_data: Mapping holding the question text under
                ``"Question"`` (a lowercase ``"question"`` key is accepted as
                a fallback) and optionally a ``"task_id"`` whose attachment
                should be downloaded and indexed.

        Returns:
            The agent's answer converted to ``str``, or an
            ``"Error processing question: ..."`` message if the run raised.
        """
        question = question_data.get("Question") or question_data.get("question") or ""
        task_id = question_data.get("task_id") or ""

        # Download and load file if task_id provided. Track whether it
        # actually worked so the prompt does not claim a file that is absent.
        file_loaded = False
        if task_id:
            file_path = self.download_gaia_file(task_id)
            if file_path and self.load_documents_from_file(file_path):
                file_loaded = True
                print(f"Loaded file for task {task_id}")

        # Keyword heuristic: does the question look like a browsing task?
        web_indicators = ["navigate", "browser", "website", "webpage", "url", "click", "search on"]
        lowered_question = question.lower()
        needs_browser = any(indicator in lowered_question for indicator in web_indicators)

        if needs_browser and not self.driver:
            print("Initializing browser for web automation...")
            self.initialize_browser()

        # Prepare the prompt
        task_line = f"Task ID: {task_id}" if task_id else ""
        prompt = (
            "\n"
            f"Question: {question}\n"
            f"{task_line}\n"
            f"File loaded: {'Yes' if file_loaded else 'No'}\n"
            "\n"
            "Solve this step by step. Use the available tools to gather information and provide a precise answer.\n"
        )

        if needs_browser:
            prompt += "\n" + helium_instructions

        try:
            print("=== AGENT REASONING ===")
            result = self.agent.run(prompt)
            print("=== END REASONING ===")

            return str(result)

        except Exception as e:
            error_msg = f"Error processing question: {str(e)}"
            print(error_msg)
            return error_msg
        finally:
            # Clean up browser if initialized
            if self.driver:
                try:
                    helium.kill_browser()
                except Exception as e:
                    # Best effort: report the failure but never mask the answer.
                    print(f"Error closing browser: {e}")
                finally:
                    # Forget the dead driver so the next browsing question
                    # starts a fresh browser instead of reusing a closed one.
                    self.driver = None
303
+
304
+
305
+ # Example usage
306
if __name__ == "__main__":
    # Manual smoke test: build the agent and run one sample question that
    # has no attached file (empty task_id).
    gaia_agent = GAIAAgent()

    sample_question = dict(
        Question="How many studio albums Mercedes Sosa has published between 2000-2009 ?",
        task_id="",
    )

    final_answer = gaia_agent.solve_gaia_question(sample_question)
    print(f"Answer: {final_answer}")