Coool2 committed on
Commit 4c20a23 · 1 Parent(s): 6e0b138

Update agent2.py

Files changed (1)
  1. agent2.py +260 -51
agent2.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import requests
+import base64
 from typing import Dict, Any, List
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -16,6 +17,16 @@ from io import BytesIO
 from time import sleep
 from smolagents import PythonInterpreterTool, SpeechToTextTool
 
+# Langfuse observability imports
+from opentelemetry.sdk.trace import TracerProvider
+from openinference.instrumentation.smolagents import SmolagentsInstrumentor
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry import trace
+from opentelemetry.trace import format_trace_id
+from langfuse import Langfuse
+
+
 class BM25RetrieverTool(Tool):
     """
     BM25 retriever tool for document search when text documents are available
@@ -171,17 +182,20 @@ def save_screenshot_callback(memory_step: ActionStep, agent: CodeAgent) -> None:
 
 class GAIAAgent:
     """
-    Simplified GAIA agent using smolagents with Gemini 2.0 Flash
+    GAIA agent using smolagents with Gemini 2.0 Flash and Langfuse observability
     """
 
-    def __init__(self):
-        """Initialize the agent with Gemini 2.0 Flash and tools"""
+    def __init__(self, user_id: str = None, session_id: str = None):
+        """Initialize the agent with Gemini 2.0 Flash, tools, and Langfuse observability"""
 
-        # Get Gemini API key
+        # Get API keys
         gemini_api_key = os.environ.get("GOOGLE_API_KEY")
         if not gemini_api_key:
             raise ValueError("GOOGLE_API_KEY environment variable not found")
 
+        # Initialize Langfuse observability
+        self._setup_langfuse_observability()
+
         # Initialize Gemini 2.0 Flash model
         self.model = OpenAIServerModel(
             model_id="gemini-2.0-flash",
@@ -189,6 +203,10 @@ class GAIAAgent:
             api_key=gemini_api_key,
         )
 
+        # Store user and session IDs for tracking
+        self.user_id = user_id or "gaia-user"
+        self.session_id = session_id or "gaia-session"
+
         # GAIA system prompt from the leaderboard
         self.system_prompt = """You are a general AI assistant. I will ask you a question. Report your thoughts and reasoning process clearly. You should use the available tools to gather information and solve problems step by step.
 
@@ -214,6 +232,42 @@ Your final answer should be as few words as possible, a number, or a comma-separ
         self.agent = None
         self._create_agent()
 
+        # Initialize Langfuse client
+        self.langfuse = Langfuse()
+
+    def _setup_langfuse_observability(self):
+        """Set up Langfuse observability with OpenTelemetry"""
+        # Get Langfuse keys from environment variables
+        langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
+        langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
+
+        if not langfuse_public_key or not langfuse_secret_key:
+            print("Warning: LANGFUSE_PUBLIC_KEY or LANGFUSE_SECRET_KEY not found. Observability will be limited.")
+            return
+
+        # Set up Langfuse environment variables
+        os.environ["LANGFUSE_HOST"] = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
+
+        langfuse_auth = base64.b64encode(
+            f"{langfuse_public_key}:{langfuse_secret_key}".encode()
+        ).decode()
+
+        os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = os.environ.get("LANGFUSE_HOST") + "/api/public/otel"
+        os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"Authorization=Basic {langfuse_auth}"
+
+        # Create a TracerProvider for OpenTelemetry
+        trace_provider = TracerProvider()
+
+        # Add a SimpleSpanProcessor with the OTLPSpanExporter to send traces
+        trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter()))
+
+        # Set the global default tracer provider
+        trace.set_tracer_provider(trace_provider)
+        self.tracer = trace.get_tracer(__name__)
+
+        # Instrument smolagents with the configured provider
+        SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
+
     def _create_agent(self):
         """Create the CodeAgent with tools"""
         base_tools = [
@@ -228,8 +282,8 @@ Your final answer should be as few words as possible, a number, or a comma-separ
         self.agent = CodeAgent(
             tools=base_tools + [PythonInterpreterTool(), SpeechToTextTool()],
             model=self.model,
-            add_base_tools=False, # Adds web search, python execution, etc.
-            planning_interval=2, # Plan every 2 steps
+            add_base_tools=False,
+            planning_interval=2,
             additional_authorized_imports=["helium", "requests", "BeautifulSoup", "json"],
             step_callbacks=[save_screenshot_callback] if self.driver else [],
             max_steps=10,
@@ -308,70 +362,225 @@ Your final answer should be as few words as possible, a number, or a comma-separ
             print(f"Failed to download file for task {task_id}: {e}")
             return None
 
-    def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
+    def solve_gaia_question(self, question_data: Dict[str, Any], tags: List[str] = None) -> str:
         """
-        Solve a GAIA question
+        Solve a GAIA question with full Langfuse observability
         """
         question = question_data.get("Question", "")
         task_id = question_data.get("task_id", "")
-
-        # Download and load file if task_id provided
+
+        # Prepare tags for observability
+        trace_tags = ["gaia-agent", "question-solving"]
+        if tags:
+            trace_tags.extend(tags)
         if task_id:
-            file_path = self.download_gaia_file(task_id)
-            if file_path:
-                self.load_documents_from_file(file_path)
-                print(f"Loaded file for task {task_id}")
-
-        # Check if this requires web browsing
-        web_indicators = ["navigate", "browser", "website", "webpage", "url", "click", "search on"]
-        needs_browser = any(indicator in question.lower() for indicator in web_indicators)
-
-        if needs_browser and not self.driver:
-            print("Initializing browser for web automation...")
-            self.initialize_browser()
-
-        # Prepare the prompt
-        prompt = f"""
+            trace_tags.append(f"task-{task_id}")
+
+        # Start Langfuse trace with OpenTelemetry
+        with self.tracer.start_as_current_span("GAIA-Question-Solving") as span:
+            try:
+                # Set span attributes for tracking
+                span.set_attribute("langfuse.user.id", self.user_id)
+                span.set_attribute("langfuse.session.id", self.session_id)
+                span.set_attribute("langfuse.tags", trace_tags)
+                span.set_attribute("gaia.task_id", task_id)
+                span.set_attribute("gaia.question_length", len(question))
+
+                # Get trace ID for Langfuse linking
+                current_span = trace.get_current_span()
+                span_context = current_span.get_span_context()
+                trace_id = span_context.trace_id
+                formatted_trace_id = format_trace_id(trace_id)
+
+                # Create Langfuse trace
+                langfuse_trace = self.langfuse.trace(
+                    id=formatted_trace_id,
+                    name="GAIA Question Solving",
+                    input={"question": question, "task_id": task_id},
+                    user_id=self.user_id,
+                    session_id=self.session_id,
+                    tags=trace_tags,
+                    metadata={
+                        "model": self.model.model_id,
+                        "question_length": len(question),
+                        "has_file": bool(task_id)
+                    }
+                )
+
+                # Download and load file if task_id provided
+                file_loaded = False
+                if task_id:
+                    file_path = self.download_gaia_file(task_id)
+                    if file_path:
+                        file_loaded = self.load_documents_from_file(file_path)
+                        span.set_attribute("gaia.file_loaded", file_loaded)
+                        print(f"Loaded file for task {task_id}")
+
+                # Check if this requires web browsing
+                web_indicators = ["navigate", "browser", "website", "webpage", "url", "click", "search on"]
+                needs_browser = any(indicator in question.lower() for indicator in web_indicators)
+                span.set_attribute("gaia.needs_browser", needs_browser)
+
+                if needs_browser and not self.driver:
+                    print("Initializing browser for web automation...")
+                    browser_initialized = self.initialize_browser()
+                    span.set_attribute("gaia.browser_initialized", browser_initialized)
+
+                # Prepare the prompt
+                prompt = f"""
 Question: {question}
 {f'Task ID: {task_id}' if task_id else ''}
-{f'File loaded: Yes' if task_id else 'File loaded: No'}
+{f'File loaded: Yes' if file_loaded else 'File loaded: No'}
 
 Solve this step by step. Use the available tools to gather information and provide a precise answer.
+                """
+
+                if needs_browser:
+                    prompt += "\n" + helium_instructions
+
+                print("=== AGENT REASONING ===")
+                result = self.agent.run(prompt)
+                print("=== END REASONING ===")
+
+                # Update Langfuse trace with result
+                langfuse_trace.update(
+                    output={"answer": str(result)},
+                    end_time=None # Will be set automatically
+                )
+
+                # Add success attributes
+                span.set_attribute("gaia.success", True)
+                span.set_attribute("gaia.answer_length", len(str(result)))
+
+                # Flush Langfuse data
+                self.langfuse.flush()
+
+                return str(result)
+
+            except Exception as e:
+                error_msg = f"Error processing question: {str(e)}"
+                print(error_msg)
+
+                # Log error to span and Langfuse
+                span.set_attribute("gaia.success", False)
+                span.set_attribute("gaia.error", str(e))
+
+                if 'langfuse_trace' in locals():
+                    langfuse_trace.update(
+                        output={"error": error_msg},
+                        level="ERROR"
+                    )
+
+                self.langfuse.flush()
+                return error_msg
+
+            finally:
+                # Clean up browser if initialized
+                if self.driver:
+                    try:
+                        helium.kill_browser()
+                    except:
+                        pass
+
+    def evaluate_answer(self, question: str, answer: str, expected_answer: str = None) -> Dict[str, Any]:
         """
+        Evaluate the agent's answer using LLM-as-a-Judge and optionally compare with expected answer
+        """
+        evaluation_prompt = f"""
+        Please evaluate the following answer to a question on a scale of 1-5:
 
-        if needs_browser:
-            prompt += "\n" + helium_instructions
+        Question: {question}
+        Answer: {answer}
+        {f'Expected Answer: {expected_answer}' if expected_answer else ''}
 
-        try:
-            print("=== AGENT REASONING ===")
-            result = self.agent.run(prompt)
-            print("=== END REASONING ===")
+        Rate the answer on:
+        1. Accuracy (1-5)
+        2. Completeness (1-5)
+        3. Clarity (1-5)
 
-            return str(result)
+        Provide your rating as JSON: {{"accuracy": X, "completeness": Y, "clarity": Z, "overall": W, "reasoning": "explanation"}}
+        """
 
+        try:
+            # Use the same model to evaluate
+            evaluation_result = self.agent.run(evaluation_prompt)
+
+            # Try to parse JSON response
+            import json
+            try:
+                scores = json.loads(evaluation_result)
+                return scores
+            except:
+                # Fallback if JSON parsing fails
+                return {
+                    "accuracy": 3,
+                    "completeness": 3,
+                    "clarity": 3,
+                    "overall": 3,
+                    "reasoning": "Could not parse evaluation response",
+                    "raw_evaluation": evaluation_result
+                }
+
+        except Exception as e:
+            return {
+                "accuracy": 1,
+                "completeness": 1,
+                "clarity": 1,
+                "overall": 1,
+                "reasoning": f"Evaluation failed: {str(e)}"
+            }
+
+    def add_user_feedback(self, trace_id: str, feedback_score: int, comment: str = None):
+        """
+        Add user feedback to a specific trace
+
+        Args:
+            trace_id: The trace ID to add feedback to
+            feedback_score: Score from 0-5 (0=very bad, 5=excellent)
+            comment: Optional comment from user
+        """
+        try:
+            self.langfuse.score(
+                trace_id=trace_id,
+                name="user-feedback",
+                value=feedback_score,
+                comment=comment
+            )
+            self.langfuse.flush()
+            print(f"User feedback added: {feedback_score}/5")
         except Exception as e:
-            error_msg = f"Error processing question: {str(e)}"
-            print(error_msg)
-            return error_msg
-        finally:
-            # Clean up browser if initialized
-            if self.driver:
-                try:
-                    helium.kill_browser()
-                except:
-                    pass
-
-
-# Example usage
+            print(f"Error adding user feedback: {e}")
+
+
+# Example usage with observability
 if __name__ == "__main__":
-    # Test the agent
-    agent = GAIAAgent()
+    # Set up environment variables (you need to set these)
+    # os.environ["GOOGLE_API_KEY"] = "your-gemini-api-key"
+    # os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-..."
+    # os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-..."
+
+    # Test the agent with observability
+    agent = GAIAAgent(
+        user_id="test-user-123",
+        session_id="test-session-456"
+    )
 
     # Example question
     question_data = {
-        "Question": "How many studio albums Mercedes Sosa has published between 2000-2009 ?",
+        "Question": "How many studio albums Mercedes Sosa has published between 2000-2009?",
        "task_id": ""
     }
 
-    answer = agent.solve_gaia_question(question_data)
-    print(f"Answer: {answer}")
+    # Solve with full observability
+    answer = agent.solve_gaia_question(
+        question_data,
+        tags=["music-question", "discography"]
+    )
+    print(f"Answer: {answer}")
+
+    # Evaluate the answer
+    evaluation = agent.evaluate_answer(
+        question_data["Question"],
+        answer
+    )
+    print(f"Evaluation: {evaluation}")