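"""Evaluate a GAIAAgent on a GAIA-style question-answering dataset.

Tasks are loaded from a JSON/JSONL file (falling back to a small built-in
sample set), the agent answers each task, answers are scored by normalized
exact match, and a submission.jsonl file is written with one entry per task.

Command-line example (illustrative; the script's file name may differ):

    python evaluate.py --dataset gaia_test.json --max-tasks 5
"""
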
import json
import os
from typing import Dict, List, Optional

from agent import GAIAAgent


def normalize_answer(answer: str) -> str:
    """Normalize an answer string for comparison."""
    if not answer:
        return ""

    answer = answer.strip()

    # Strip a single layer of surrounding quotes, if present.
    if (answer.startswith('"') and answer.endswith('"')) or (answer.startswith("'") and answer.endswith("'")):
        answer = answer[1:-1]

    return answer.lower().strip()

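# Illustrative behaviour of normalize_answer (not executed): whitespace and one
# layer of surrounding quotes are stripped before lowercasing, so
#   normalize_answer('  "Paris" ') -> 'paris'
#   normalize_answer("'42'")       -> '42'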


def extract_final_answer(response: str) -> str:
    """Extract the final answer from the model response."""
    if "FINAL ANSWER:" in response:
        answer = response.split("FINAL ANSWER:")[1].strip()
        # Keep only the first line after the marker.
        answer = answer.split('\n')[0].strip()
        return answer

    # No marker found: fall back to the last line of the response.
    lines = response.strip().split('\n')
    return lines[-1].strip()

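# Illustrative behaviour of extract_final_answer (not executed): for a response
# such as "Some reasoning...\nFINAL ANSWER: 42\nThanks!" it returns "42";
# without the "FINAL ANSWER:" marker it falls back to the last line of the
# stripped response.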


def load_gaia_dataset(dataset_path: str) -> List[Dict]:
    """Load a GAIA dataset from a JSON or JSONL file."""
    tasks = []

    if not os.path.exists(dataset_path):
        print(f"Dataset file not found: {dataset_path}")
        return tasks

    try:
        with open(dataset_path, "r", encoding="utf-8") as f:
            if dataset_path.endswith('.jsonl'):
                # JSONL: one task object per line.
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:
                        try:
                            task = json.loads(line)
                            tasks.append(task)
                        except json.JSONDecodeError as e:
                            print(f"Error parsing line {line_num}: {e}")
            else:
                # JSON: either a list of tasks or an object with a 'tasks' key.
                data = json.load(f)
                if isinstance(data, list):
                    tasks = data
                elif isinstance(data, dict) and 'tasks' in data:
                    tasks = data['tasks']
                else:
                    print("Unexpected JSON format")

    except Exception as e:
        print(f"Error loading dataset: {e}")

    print(f"Loaded {len(tasks)} tasks from {dataset_path}")
    return tasks

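# Expected on-disk formats, inferred from the parsing above (file names and
# field values are illustrative):
#   gaia_test.jsonl : one JSON object per line, e.g.
#       {"task_id": "abc", "question": "...", "answer": "...", "level": 1, "file_name": null}
#   gaia_test.json  : either a JSON list of such objects, or {"tasks": [...]}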


def create_sample_dataset() -> List[Dict]:
    """Create a sample dataset for testing if no GAIA dataset is available."""
    sample_tasks = [
        {
            "task_id": "sample_1",
            "question": "What is 15 + 27?",
            "answer": "42",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_2",
            "question": "What is the capital of France?",
            "answer": "Paris",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_3",
            "question": "How many days are in a leap year?",
            "answer": "366",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_4",
            "question": "What is 2 * 6 * 7?",
            "answer": "84",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_5",
            "question": "What year did World War II end?",
            "answer": "1945",
            "level": 1,
            "file_name": None
        }
    ]

    print("Using sample dataset for testing")
    return sample_tasks

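# Note: each sample task uses the same keys the evaluation loop reads
# ("task_id", "question", "answer"), plus "level" and "file_name" placeholders;
# "file_name" is None here because none of the sample questions reference an
# attached file.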


def evaluate_agent(dataset_path: Optional[str] = None, max_tasks: Optional[int] = None) -> float:
    """Evaluate the GAIA agent on the dataset and return the score as a percentage."""
    if dataset_path and os.path.exists(dataset_path):
        tasks = load_gaia_dataset(dataset_path)
    else:
        print("No dataset file found, using sample tasks for testing")
        tasks = create_sample_dataset()

    if not tasks:
        print("No tasks to evaluate")
        return 0.0

    if max_tasks:
        tasks = tasks[:max_tasks]
        print(f"Evaluating on first {len(tasks)} tasks")

    print("Initializing GAIA agent...")
    agent = GAIAAgent()

    print("Testing API connection...")
    test_response = agent.test_grok()
    if "error" in test_response.lower():
        print(f"API test failed: {test_response}")
        return 0.0
    else:
        print("API connection successful!")

    correct = 0
    total = len(tasks)
    submission_entries = []

    print(f"\nStarting evaluation on {total} tasks...")
    print("=" * 50)

    for i, task in enumerate(tasks, 1):
        task_id = task.get("task_id", f"task_{i}")
        question = task.get("question", "")
        expected_answer = task.get("answer", "")

        print(f"\nTask {i}/{total}: {task_id}")
        print(f"Question: {question[:100]}{'...' if len(question) > 100 else ''}")

        try:
            response = agent.process_task(task)
            predicted_answer = extract_final_answer(response)

            print(f"Expected: {expected_answer}")
            print(f"Predicted: {predicted_answer}")

            is_correct = normalize_answer(predicted_answer) == normalize_answer(expected_answer)

            if is_correct:
                correct += 1
                print("✅ CORRECT")
            else:
                print("❌ INCORRECT")

            submission_entries.append({
                "task_id": task_id,
                "model_answer": predicted_answer,
                "reasoning_trace": response
            })

        except Exception as e:
            print(f"Error processing task {task_id}: {e}")
            submission_entries.append({
                "task_id": task_id,
                "model_answer": "ERROR",
                "reasoning_trace": f"Error: {str(e)}"
            })

        # Running score after each task.
        current_score = (correct / i) * 100
        print(f"Current score: {correct}/{i} = {current_score:.1f}%")
        print("-" * 30)

    final_score = (correct / total) * 100

    try:
        with open("submission.jsonl", "w", encoding="utf-8") as f:
            for entry in submission_entries:
                f.write(json.dumps(entry) + "\n")
        print("\nSubmission saved to submission.jsonl")
    except Exception as e:
        print(f"Error saving submission: {e}")

    print("=" * 50)
    print("FINAL RESULTS")
    print("=" * 50)
    print(f"Total tasks: {total}")
    print(f"Correct answers: {correct}")
    print(f"Final score: {final_score:.2f}%")

    if final_score >= 30:
        print("🎉 CONGRATULATIONS! Score ≥ 30% - Certificate achieved!")
    else:
        print(f"Score below 30%. Need {30 - final_score:.2f}% more for certificate.")

    return final_score

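# Programmatic usage (illustrative; the module name depends on this file's name):
#   from evaluate import evaluate_agent
#   score = evaluate_agent("gaia_test.json", max_tasks=5)
# Each line of submission.jsonl carries the keys "task_id", "model_answer",
# and "reasoning_trace".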


def main():
    """Command-line entry point for the evaluation."""
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate GAIA agent")
    parser.add_argument("--dataset", type=str, default="gaia_test.json",
                        help="Path to GAIA dataset file")
    parser.add_argument("--max-tasks", type=int, default=None,
                        help="Maximum number of tasks to evaluate")

    args = parser.parse_args()

    score = evaluate_agent(args.dataset, args.max_tasks)

    print(f"\nFinal evaluation score: {score:.2f}%")

    if score >= 30:
        print("Certificate requirements met! 🎉")
    else:
        print("Keep working to reach 30% for the certificate! 💪")


if __name__ == "__main__":
    main()