"""
Test script to verify GAIA agent setup and functionality.
"""
|
|
|
import re

from agent import GAIAAgent
from tools import web_search, read_file, calculate_simple_math
|
|
|
def test_api_connection():
    """Check that the agent can reach the xAI API.

    Constructs a ``GAIAAgent`` and calls its ``test_grok()`` method. The check
    fails if anything raises (including agent construction) or if the response
    text contains "error" (case-insensitive).

    Returns:
        True if the API responded without an apparent error, False otherwise.
    """
    print("Testing xAI API connection...")

    try:
        # Build the agent inside the try: if construction fails (e.g. a
        # configuration problem), main() relies on a False return to print
        # its "check your API key" guidance instead of crashing.
        agent = GAIAAgent()
        response = agent.test_grok()
        print(f"API Response: {response}")

        # A non-string response raises AttributeError here and is reported
        # through the except branch as an error.
        if "error" in response.lower():
            print("❌ API test failed")
            return False

        print("✅ API connection successful")
        return True
    except Exception as e:
        print(f"❌ API test error: {e}")
        return False
|
|
|
def test_basic_reasoning(): |
|
"""Test basic reasoning capabilities.""" |
|
print("\nTesting basic reasoning...") |
|
agent = GAIAAgent() |
|
|
|
test_cases = [ |
|
{ |
|
"task_id": "test_math", |
|
"question": "What is 25 + 17?", |
|
"expected": "42" |
|
}, |
|
{ |
|
"task_id": "test_general", |
|
"question": "What is the capital of Japan?", |
|
"expected": "tokyo" |
|
} |
|
] |
|
|
|
for test_case in test_cases: |
|
print(f"\nTest: {test_case['question']}") |
|
try: |
|
response = agent.process_task(test_case) |
|
predicted = agent.extract_final_answer(response) |
|
print(f"Response: {predicted}") |
|
|
|
|
|
if test_case['expected'].lower() in predicted.lower(): |
|
print("β
Test passed") |
|
else: |
|
print("β Test failed") |
|
|
|
except Exception as e: |
|
print(f"β Test error: {e}") |
|
|
|
def test_tools():
    """Exercise each tool once with a fixed input and print its output.

    Calls ``calculate_simple_math``, ``web_search`` and ``read_file``. Each
    call is guarded separately, so an exception from one tool is reported and
    does not abort the remaining checks or the rest of the suite.
    """
    print("\nTesting tools...")

    print("\n1. Testing math calculation:")
    try:
        result = calculate_simple_math("15 + 27")
        print(f"15 + 27 = {result}")
    except Exception as e:
        print(f"❌ Math tool error: {e}")

    print("\n2. Testing web search:")
    try:
        search_result = web_search("capital of France", None)
        # str() so a non-string result (e.g. None) can still be previewed.
        print(f"Search result: {str(search_result)[:100]}...")
    except Exception as e:
        print(f"❌ Web search error: {e}")

    print("\n3. Testing file reading:")
    try:
        # The file name suggests this intentionally probes the missing-file path.
        file_result = read_file("nonexistent.txt")
        print(f"File read result: {file_result}")
    except Exception as e:
        print(f"❌ File read error: {e}")
|
|
|
def test_sample_task():
    """Run the agent on one GAIA-style arithmetic word problem.

    The prediction from ``agent.extract_final_answer`` must equal the stored
    answer exactly after stripping surrounding whitespace. Any exception
    (including failure to construct the agent) is reported rather than
    propagated.
    """
    print("\nTesting sample GAIA task...")

    sample_task = {
        "task_id": "sample_test",
        "question": "If a store has 150 apples and sells 87 of them, how many apples are left?",
        "answer": "63",
        "file_name": None,
    }

    try:
        agent = GAIAAgent()

        print(f"Question: {sample_task['question']}")
        response = agent.process_task(sample_task)
        predicted = agent.extract_final_answer(response)
        expected = sample_task['answer']

        print(f"Expected: {expected}")
        print(f"Predicted: {predicted}")

        if predicted.strip() == expected:
            print("✅ Sample task passed")
        else:
            print("❌ Sample task failed")
    except Exception as e:
        print(f"❌ Sample task error: {e}")
|
|
|
def main():
    """Run the test suite in order.

    The API connectivity check runs first; if it fails, the remaining tests
    are skipped.

    Returns:
        0 when the suite runs to completion, 1 when the API check fails.
        Used as the process exit status when run as a script.
    """
    print("GAIA Agent Test Suite")
    print("=" * 50)

    if not test_api_connection():
        print("\n❌ API connection failed. Cannot proceed with other tests.")
        print("Please check your API key and internet connection.")
        return 1

    test_basic_reasoning()
    test_tools()
    test_sample_task()

    print("\n" + "=" * 50)
    print("Test suite completed!")
    print("If all tests passed, you can run: python evaluate.py")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a failed API check gives a non-zero exit code.
    raise SystemExit(main())