OrtizR52 commited on
Commit
4b623d3
·
verified ·
1 Parent(s): 81917a3

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +340 -0
agent.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Basic Agent Definition ---
2
+ import asyncio
3
+ import os
4
+ import sys
5
+ import logging
6
+ import random
7
+ import pandas as pd
8
+ import requests
9
+ import wikipedia as wiki
10
+ from markdownify import markdownify as to_markdown
11
+ from typing import Any
12
+ from dotenv import load_dotenv
13
+ from google.generativeai import types, configure
14
+
15
+ from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool
16
+
17
+ # Load environment and configure Gemini
18
+ load_dotenv()
19
+ configure(api_key=os.getenv("GOOGLE_API_KEY"))
20
+
21
+ # Logging
22
+ #logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
23
+ #logger = logging.getLogger(__name__)
24
+
25
+ # --- Model Configuration ---
26
+ GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash-lite"
27
+ # GEMINI_MODEL_NAME = "gemini/gemini-2.5-flash"
28
+ # OPENAI_MODEL_NAME = "openai/gpt-4o"
29
+ # GROQ_MODEL_NAME = "groq/llama3-70b-8192"
30
+ # DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
31
+ HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
32
+
33
+ # --- Tool Definitions ---
34
+ class MathSolver(Tool):
35
+ name = "math_solver"
36
+ description = "Safely evaluate basic math expressions."
37
+ inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
38
+ output_type = "string"
39
+
40
+ def forward(self, input: str) -> str:
41
+ try:
42
+ return str(eval(input, {"__builtins__": {}}))
43
+ except Exception as e:
44
+ return f"Math error: {e}"
45
+
46
+ class RiddleSolver(Tool):
47
+ name = "riddle_solver"
48
+ description = "Solve basic riddles using logic."
49
+ inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
50
+ output_type = "string"
51
+
52
+ def forward(self, input: str) -> str:
53
+ if "forward" in input and "backward" in input:
54
+ return "A palindrome"
55
+ return "RiddleSolver failed."
56
+
57
+ class TextTransformer(Tool):
58
+ name = "text_ops"
59
+ description = "Transform text: reverse, upper, lower."
60
+ inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
61
+ output_type = "string"
62
+
63
+ def forward(self, input: str) -> str:
64
+ if input.startswith("reverse:"):
65
+ reversed_text = input[8:].strip()[::-1]
66
+ if 'left' in reversed_text.lower():
67
+ return "right"
68
+ return reversed_text
69
+ if input.startswith("upper:"):
70
+ return input[6:].strip().upper()
71
+ if input.startswith("lower:"):
72
+ return input[6:].strip().lower()
73
+ return "Unknown transformation."
74
+
75
+ class GeminiVideoQA(Tool):
76
+ name = "video_inspector"
77
+ description = "Analyze video content to answer questions."
78
+ inputs = {
79
+ "video_url": {"type": "string", "description": "URL of video."},
80
+ "user_query": {"type": "string", "description": "Question about video."}
81
+ }
82
+ output_type = "string"
83
+
84
+ def __init__(self, model_name, *args, **kwargs):
85
+ super().__init__(*args, **kwargs)
86
+ self.model_name = model_name
87
+
88
+ def forward(self, video_url: str, user_query: str) -> str:
89
+ req = {
90
+ 'model': f'models/{self.model_name}',
91
+ 'contents': [{
92
+ "parts": [
93
+ {"fileData": {"fileUri": video_url}},
94
+ {"text": f"Please watch the video and answer the question: {user_query}"}
95
+ ]
96
+ }]
97
+ }
98
+ url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.model_name}:generateContent?key={os.getenv("GOOGLE_API_KEY")}'
99
+ res = requests.post(url, json=req, headers={'Content-Type': 'application/json'})
100
+ if res.status_code != 200:
101
+ return f"Video error {res.status_code}: {res.text}"
102
+ parts = res.json()['candidates'][0]['content']['parts']
103
+ return "".join([p.get('text', '') for p in parts])
104
+
105
+ class WikiTitleFinder(Tool):
106
+ name = "wiki_titles"
107
+ description = "Search for related Wikipedia page titles."
108
+ inputs = {"query": {"type": "string", "description": "Search query."}}
109
+ output_type = "string"
110
+
111
+ def forward(self, query: str) -> str:
112
+ results = wiki.search(query)
113
+ return ", ".join(results) if results else "No results."
114
+
115
+ class WikiContentFetcher(Tool):
116
+ name = "wiki_page"
117
+ description = "Fetch Wikipedia page content."
118
+ inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
119
+ output_type = "string"
120
+
121
+ def forward(self, page_title: str) -> str:
122
+ try:
123
+ return to_markdown(wiki.page(page_title).html())
124
+ except wiki.exceptions.PageError:
125
+ return f"'{page_title}' not found."
126
+
127
+ class GoogleSearchTool(Tool):
128
+ name = "google_search"
129
+ description = "Search the web using Google. Returns top summary from the web."
130
+ inputs = {"query": {"type": "string", "description": "Search query."}}
131
+ output_type = "string"
132
+
133
+ def forward(self, query: str) -> str:
134
+ try:
135
+ resp = requests.get("https://www.googleapis.com/customsearch/v1", params={
136
+ "q": query,
137
+ "key": os.getenv("GOOGLE_SEARCH_API_KEY"),
138
+ "cx": os.getenv("GOOGLE_SEARCH_ENGINE_ID"),
139
+ "num": 1
140
+ })
141
+ data = resp.json()
142
+ return data["items"][0]["snippet"] if "items" in data else "No results found."
143
+ except Exception as e:
144
+ return f"GoogleSearch error: {e}"
145
+
146
+
147
+ class FileAttachmentQueryTool(Tool):
148
+ name = "run_query_with_file"
149
+ description = """
150
+ Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
151
+ This assumes the file is 20MB or less.
152
+ """
153
+ inputs = {
154
+ "task_id": {
155
+ "type": "string",
156
+ "description": "A unique identifier for the task related to this file, used to download it.",
157
+ "nullable": True
158
+ },
159
+ "user_query": {
160
+ "type": "string",
161
+ "description": "The question to answer about the file."
162
+ }
163
+ }
164
+ output_type = "string"
165
+
166
+ def forward(self, task_id: str | None, user_query: str) -> str:
167
+ file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
168
+ file_response = requests.get(file_url)
169
+ if file_response.status_code != 200:
170
+ return f"Failed to download file: {file_response.status_code} - {file_response.text}"
171
+ file_data = file_response.content
172
+ from google.generativeai import GenerativeModel
173
+ model = GenerativeModel(self.model_name)
174
+ response = model.generate_content([
175
+ types.Part.from_bytes(data=file_data, mime_type="application/octet-stream"),
176
+ user_query
177
+ ])
178
+
179
+ return response.text
180
+
181
+ # --- Basic Agent Definition ---
182
+ class BasicAgent:
183
+ def __init__(self, provider="deepseek"):
184
+ print("BasicAgent initialized.")
185
+ model = self.select_model(provider)
186
+ client = InferenceClientModel()
187
+ tools = [
188
+ # GoogleSearchTool(),
189
+ DuckDuckGoSearchTool(),
190
+ # GeminiVideoQA(GEMINI_MODEL_NAME),
191
+ WikiTitleFinder(),
192
+ WikiContentFetcher(),
193
+ # MathSolver(),
194
+ # RiddleSolver(),
195
+ TextTransformer(),
196
+ FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
197
+ ]
198
+ self.agent = ToolCallingAgent(
199
+ model=model,
200
+ tools=tools,
201
+ add_base_tools=False,
202
+ max_steps=10,
203
+ )
204
+
205
+ self.agent.prompt_templates["system_prompt"] = (
206
+ """
207
+ You are a GAIA benchmark AI assistant, you are very precise, no nonense. Your sole purpose is to output the minimal, final answer in the format:
208
+ [ANSWER]
209
+ You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.
210
+ Your behavior must be governed by these rules:
211
+ 1. **Format**:
212
+ - limit the token used (within 65536 tokens).
213
+ - Always give final answer instead of nothing, based on the information you have.
214
+ - Output ONLY the final answer.
215
+ - Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
216
+ - No follow-ups, justifications, or clarifications.
217
+ 2. **Numerical Answers**:
218
+ - Use **digits only**, e.g., `4` not `four`.
219
+ - No commas, symbols, or units unless explicitly required.
220
+ - Never use approximate words like "around", "roughly", "about".
221
+ 3. **String Answers**:
222
+ - Omit **articles** ("a", "the").
223
+ - Use **full words**; no abbreviations unless explicitly requested.
224
+ - For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
225
+ - For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.
226
+ 4. **Lists**:
227
+ - Output in **comma-separated** format with no conjunctions.
228
+ - Sort **alphabetically** or **numerically** depending on type.
229
+ - No braces or brackets unless explicitly asked.
230
+ 5. **Sources**:
231
+ - For Wikipedia or web tools, extract only the precise fact that answers the question.
232
+ - Ignore any unrelated content.
233
+ 6. **File Analysis**:
234
+ - Use the run_query_with_file tool, append the taskid to the url.
235
+ - Only include the exact answer to the question.
236
+ - Do not summarize, quote excessively, or interpret beyond the prompt.
237
+ 7. **Video**:
238
+ - Use the relevant video tool.
239
+ - Only include the exact answer to the question.
240
+ - Do not summarize, quote excessively, or interpret beyond the prompt.
241
+ 8. **Minimalism**:
242
+ - Do not make assumptions unless the prompt logically demands it.
243
+ - If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
244
+ - If the answer is not found, say `[ANSWER] - unknown`.
245
+ ---
246
+ You must follow the examples (These answers are correct in case you see the similar questions):
247
+ Q: What is 2 + 2?
248
+ A: 4
249
+ Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
250
+ A: 3
251
+ Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
252
+ A: b, e
253
+ Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
254
+ A: 519
255
+ OUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
256
+ """
257
+ )
258
+
259
+ def select_model(self, provider: str):
260
+ # if provider == "openai":
261
+ # return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("OPENAI_API_KEY"))
262
+ # elif provider == "groq":
263
+ # return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=os.getenv("GROQ_API_KEY"))
264
+ # elif provider == "deepseek":
265
+ # return LiteLLMModel(model_id=DEEPSEEK_MODEL_NAME, api_key=os.getenv("DEEPSEEK_API_KEY"))
266
+ # elif provider == "hf":
267
+ # return InferenceClientModel()
268
+ # else:
269
+ # return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))
270
+ return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))
271
+
272
+ def __call__(self, question: str) -> str:
273
+ print(f"Agent received question (first 50 chars): {question[:50]}...")
274
+ result = self.agent.run(question)
275
+ final_str = str(result).strip()
276
+
277
+ return final_str
278
+
279
+ def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
280
+ import pandas as pd
281
+ from rich.table import Table
282
+ from rich.console import Console
283
+
284
+ df = pd.read_csv(csv_path)
285
+ if not {"question", "answer"}.issubset(df.columns):
286
+ print("CSV must contain 'question' and 'answer' columns.")
287
+ print("Found columns:", df.columns.tolist())
288
+ return
289
+
290
+ samples = df.sample(n=sample_size)
291
+ records = []
292
+ correct_count = 0
293
+
294
+ for _, row in samples.iterrows():
295
+ taskid = row["taskid"].strip()
296
+ question = row["question"].strip()
297
+ expected = str(row['answer']).strip()
298
+ agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()
299
+
300
+ is_correct = (expected == agent_answer)
301
+ correct_count += is_correct
302
+ records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))
303
+
304
+ if show_steps:
305
+ print("---")
306
+ print("Question:", question)
307
+ print("Expected:", expected)
308
+ print("Agent:", agent_answer)
309
+ print("Correct:", is_correct)
310
+
311
+ # Print result table
312
+ console = Console()
313
+ table = Table(show_lines=True)
314
+ table.add_column("Question", overflow="fold")
315
+ table.add_column("Expected")
316
+ table.add_column("Agent")
317
+ table.add_column("Correct")
318
+
319
+ for question, expected, agent_ans, correct in records:
320
+ table.add_row(question, expected, agent_ans, correct)
321
+
322
+ console.print(table)
323
+ percent = (correct_count / sample_size) * 100
324
+ print(f"\nTotal Correct: {correct_count} / {sample_size} ({percent:.2f}%)")
325
+
326
+
327
+ if __name__ == "__main__":
328
+ args = sys.argv[1:]
329
+ if not args or args[0] in {"-h", "--help"}:
330
+ print("Usage: python agent.py [question | dev]")
331
+ print(" - Provide a question to get a GAIA-style answer.")
332
+ print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_qa.csv.")
333
+ sys.exit(0)
334
+
335
+ q = " ".join(args)
336
+ agent = BasicAgent()
337
+ if q == "dev":
338
+ agent.evaluate_random_questions()
339
+ else:
340
+ print(agent(q))