# Final_Assignment / tests / clean_batch_test.py
# GAIA Developer
# 🧪 Add comprehensive test infrastructure and async testing system
# c262d1a
#!/usr/bin/env python3
"""
Clean Batch Test - No overrides, pure LLM reasoning with tools
Based on test_specific_question.py but for all questions at once
"""
import os
import sys
import json
import time
from pathlib import Path
from dotenv import load_dotenv
from concurrent.futures import ThreadPoolExecutor, as_completed
# Load environment variables
load_dotenv()
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))
# Local imports
from gaia_web_loader import GAIAQuestionLoaderWeb
from main import GAIASolver
from question_classifier import QuestionClassifier
def load_validation_answers(validation_path=None):
    """Load expected answers from the GAIA validation metadata (JSON Lines).

    Each non-blank line is parsed as a JSON object. It contributes
    ``task_id -> "Final answer"`` when ``task_id`` is truthy and the answer
    is present and not empty/whitespace-only.

    Args:
        validation_path: Optional path to the ``.jsonl`` file. Defaults to
            ``gaia_validation_metadata.jsonl`` two directories above this
            file (``Path(__file__).parent.parent``).

    Returns:
        dict mapping task_id to its expected final answer (value kept as it
        appears in the JSON). A missing or unreadable file yields whatever
        was loaded so far (usually an empty dict). Malformed lines are
        reported and skipped one by one instead of aborting the whole load.
    """
    answers = {}
    if validation_path is None:
        validation_path = Path(__file__).parent.parent / 'gaia_validation_metadata.jsonl'
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252) can fail
        # on non-ASCII answers.
        with open(validation_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    # One bad line should not discard every answer after it.
                    print(f"⚠️ Could not load validation data: {e}")
                    continue
                if not isinstance(data, dict):
                    continue
                task_id = data.get('task_id')
                final_answer = data.get('Final answer')
                # Whitespace-only answers strip to "" downstream, and "" is a
                # substring of everything, so they are unusable for grading.
                # Numeric answers such as 0 are still kept.
                if task_id and final_answer is not None and str(final_answer).strip():
                    answers[task_id] = final_answer
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError covers UnicodeDecodeError.
        print(f"⚠️ Could not load validation data: {e}")
    return answers
def validate_answer(task_id: str, our_answer: str, validation_answers: dict):
    """Grade our answer against the expected answer for ``task_id``.

    Both sides are converted with ``str()``, stripped, and compared
    case-insensitively:

    * ``CORRECT``   - the two strings are equal;
    * ``PARTIAL``   - the (non-empty) expected answer appears inside ours;
    * ``INCORRECT`` - anything else.

    Returns:
        None when ``task_id`` has no entry in ``validation_answers``;
        otherwise a dict with ``status``, ``expected`` and ``our`` (the
        stripped strings, original case preserved).
    """
    if task_id not in validation_answers:
        return None

    expected = str(validation_answers[task_id]).strip()
    our_clean = str(our_answer).strip()
    expected_key = expected.lower()
    our_key = our_clean.lower()

    if our_key == expected_key:
        status = "CORRECT"
    # An empty expected answer is a substring of every string; without this
    # guard any output would earn PARTIAL credit.
    elif expected_key and expected_key in our_key:
        status = "PARTIAL"
    else:
        status = "INCORRECT"

    return {"status": status, "expected": expected, "our": our_clean}
def test_single_question(question_data, validation_answers, model="qwen3-235b"):
    """Solve one GAIA question with no answer overrides and grade the result.

    Despite the ``test_`` prefix this is a helper that needs real arguments,
    not a pytest test. ``__test__ = False`` is set right after the definition
    so pytest does not collect it from this ``*_test.py`` module and fail on
    missing fixtures.

    Args:
        question_data: Question dict. ``task_id``, ``question`` and
            ``file_name`` are read here, and the whole dict is passed to
            ``GAIASolver.solve_question``.
        validation_answers: Mapping task_id -> expected answer, as produced
            by ``load_validation_answers``.
        model: Kluster model name forwarded to ``GAIASolver``.

    Returns:
        Result dict with task_id, question_type, complexity, confidence,
        our_answer, expected_answer, status, duration and question_preview.
        On success, ``duration`` is the solve time only. Any exception is
        caught (deliberately: one failing question must not stop a batch)
        and turned into status ``'ERROR'``, with the message under
        ``'error'`` and ``duration`` set to the time spent before failing.
    """
    task_id = question_data.get('task_id', 'unknown')
    # `or ''` also covers an explicit None question, which would otherwise
    # make the preview slice raise inside the except handler.
    question_text = question_data.get('question') or ''
    question_preview = question_text[:50] + "..."
    # Started before any work so the error path can report time actually spent.
    attempt_start = time.perf_counter()
    try:
        print(f"πŸ§ͺ [{task_id[:8]}...] Starting...")
        # Initialize solver and classifier (a fresh pair for every question)
        solver = GAIASolver(use_kluster=True, kluster_model=model)
        classifier = QuestionClassifier()

        # Classify the question
        file_name = question_data.get('file_name', '')
        classification = classifier.classify_question(question_text, file_name)

        # Solve the question (NO OVERRIDES - pure LLM reasoning)
        solve_start = time.perf_counter()
        answer = solver.solve_question(question_data)
        duration = time.perf_counter() - solve_start

        # Validate answer
        validation_result = validate_answer(task_id, answer, validation_answers)

        result = {
            'task_id': task_id,
            'question_type': classification['primary_agent'],
            'complexity': classification['complexity'],
            'confidence': classification['confidence'],
            'our_answer': str(answer),
            'expected_answer': validation_result['expected'] if validation_result else 'N/A',
            'status': validation_result['status'] if validation_result else 'NO_VALIDATION',
            'duration': duration,
            'question_preview': question_preview
        }

        status_icon = "βœ…" if result['status'] == "CORRECT" else "🟑" if result['status'] == "PARTIAL" else "❌"
        print(f"{status_icon} [{task_id[:8]}...] {result['status']} | {result['question_type']} | {duration:.1f}s")
        return result

    except Exception as e:
        print(f"❌ [{task_id[:8]}...] ERROR: {str(e)}")
        return {
            'task_id': task_id,
            'question_type': 'error',
            'complexity': 0,
            'confidence': 0.0,
            'our_answer': '',
            'expected_answer': validation_answers.get(task_id, 'N/A'),
            'status': 'ERROR',
            'duration': time.perf_counter() - attempt_start,
            'error': str(e),
            'question_preview': question_preview
        }


# Keep pytest from collecting this helper as a test (the module matches *_test.py).
test_single_question.__test__ = False
def _ratio(count, total):
    """Return count / total, or 0.0 when total is 0 (an empty run must not divide by zero)."""
    return count / total if total else 0.0


def _status_icon(status):
    """Return the console icon for a result status (anything not CORRECT/PARTIAL gets the failure icon)."""
    if status == "CORRECT":
        return "βœ…"
    if status == "PARTIAL":
        return "🟑"
    return "❌"


def _classification_stats(results):
    """Group results by question_type, counting total/correct/partial per group."""
    stats = {}
    for result in results:
        entry = stats.setdefault(result['question_type'], {'total': 0, 'correct': 0, 'partial': 0})
        entry['total'] += 1
        if result['status'] == 'CORRECT':
            entry['correct'] += 1
        elif result['status'] == 'PARTIAL':
            entry['partial'] += 1
    return stats


def _save_results(payload, timestamp):
    """Write the JSON report to logs/clean_batch_test_<timestamp>.json and return that path."""
    results_file = f"logs/clean_batch_test_{timestamp}.json"
    # Create logs/ up front. Otherwise open() raises FileNotFoundError at the
    # very end of a long run, and every collected result is lost.
    os.makedirs(os.path.dirname(results_file), exist_ok=True)
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2)
    return results_file


def run_clean_batch_test(model="qwen3-235b"):
    """Run every GAIA question through the solver (no overrides) and report accuracy.

    Questions are processed sequentially and each is graded against the
    validation metadata. A summary is printed, and the full results are
    written to ``logs/clean_batch_test_<timestamp>.json``, relative to the
    current working directory; ``logs/`` is created if needed.

    Args:
        model: Kluster model name forwarded to ``test_single_question`` and
            recorded in the saved metadata. Defaults to the model this script
            has always used.

    Returns:
        Tuple ``(accuracy_rate, results)``. ``accuracy_rate`` is the
        percentage of CORRECT answers (0.0 when there are no questions), and
        ``results`` is the list of per-question result dicts.
    """
    print("πŸ§ͺ CLEAN BATCH TEST - NO OVERRIDES")
    print("=" * 60)
    print("🎯 Goal: Measure real accuracy with pure LLM reasoning")
    print("🚫 No hardcoded answers or overrides")
    print("πŸ€– Pure LLM + Tools reasoning only")
    print()

    # Load questions and validation data
    print("πŸ“‹ Loading GAIA questions...")
    loader = GAIAQuestionLoaderWeb()
    all_questions = loader.questions
    validation_answers = load_validation_answers()

    print(f"βœ… Loaded {len(all_questions)} questions")
    print(f"βœ… Loaded {len(validation_answers)} validation answers")

    # Show question preview (first 5 only)
    print(f"\nπŸ“‹ Questions to test:")
    for i, q in enumerate(all_questions[:5]):
        task_id = q.get('task_id', 'unknown')
        question_preview = q.get('question', '')[:40] + "..."
        level = q.get('Level', 'Unknown')
        has_file = "πŸ“Ž" if q.get('file_name') else "πŸ“"
        print(f" {i+1}. {task_id[:8]}... | L{level} | {has_file} | {question_preview}")
    if len(all_questions) > 5:
        print(f" ... and {len(all_questions) - 5} more questions")

    print(f"\nπŸš€ Starting clean batch test...")
    print(f"⏱️ Estimated time: ~{len(all_questions) * 2} minutes")

    # Process all questions sequentially (to avoid resource conflicts)
    start_time = time.time()
    results = []
    for i, question_data in enumerate(all_questions):
        print(f"\nπŸ“Š Progress: {i+1}/{len(all_questions)}")
        results.append(test_single_question(question_data, validation_answers, model=model))
    total_duration = time.time() - start_time

    # Analyze results
    print(f"\n" + "=" * 60)
    print(f"🏁 CLEAN BATCH TEST RESULTS")
    print(f"=" * 60)

    total_questions = len(results)
    correct_answers = sum(1 for r in results if r['status'] == 'CORRECT')
    partial_answers = sum(1 for r in results if r['status'] == 'PARTIAL')
    incorrect_answers = sum(1 for r in results if r['status'] == 'INCORRECT')
    errors = sum(1 for r in results if r['status'] == 'ERROR')

    accuracy_rate = _ratio(correct_answers, total_questions) * 100
    success_rate = _ratio(correct_answers + partial_answers, total_questions) * 100

    print(f"⏱️ Total Duration: {int(total_duration // 60)}m {int(total_duration % 60)}s")
    print(f"βœ… Pure Accuracy: {accuracy_rate:.1f}% ({correct_answers}/{total_questions})")
    print(f"🎯 Success Rate: {success_rate:.1f}% (including partial)")
    print(f"⚑ Avg per Question: {_ratio(total_duration, total_questions):.1f}s")

    print(f"\nπŸ“Š DETAILED BREAKDOWN:")
    print(f" βœ… CORRECT: {correct_answers} ({_ratio(correct_answers, total_questions):.1%})")
    print(f" 🟑 PARTIAL: {partial_answers} ({_ratio(partial_answers, total_questions):.1%})")
    print(f" ❌ INCORRECT: {incorrect_answers} ({_ratio(incorrect_answers, total_questions):.1%})")
    print(f" πŸ’₯ ERROR: {errors} ({_ratio(errors, total_questions):.1%})")

    # Classification performance
    print(f"\n🎯 CLASSIFICATION PERFORMANCE:")
    classification_stats = _classification_stats(results)
    for classification, stats in sorted(classification_stats.items()):
        total = stats['total']
        correct = stats['correct']
        partial = stats['partial']
        accuracy = _ratio(correct, total) * 100
        success = _ratio(correct + partial, total) * 100
        print(f" {classification:15} | {accuracy:5.1f}% acc | {success:5.1f}% success | {total:2d} questions")

    # Detailed results
    print(f"\nπŸ“‹ DETAILED QUESTION RESULTS:")
    for i, result in enumerate(results):
        status_icon = _status_icon(result['status'])
        print(f" {i+1:2d}. {status_icon} {result['task_id'][:8]}... | {result['question_type']:12} | {result['status']:9} | {result['duration']:5.1f}s")
        print(f" Expected: {result['expected_answer']}")
        print(f" Got: {result['our_answer']}")
        if 'error' in result:
            print(f" Error: {result['error']}")
        print()

    # Save results
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    results_file = _save_results({
        'test_metadata': {
            'timestamp': timestamp,
            'test_type': 'clean_batch_no_overrides',
            'total_questions': total_questions,
            'duration_seconds': total_duration,
            'model': model
        },
        'metrics': {
            'accuracy_rate': accuracy_rate,
            'success_rate': success_rate,
            'correct_answers': correct_answers,
            'partial_answers': partial_answers,
            'incorrect_answers': incorrect_answers,
            'errors': errors
        },
        'classification_performance': classification_stats,
        'detailed_results': results
    }, timestamp)
    print(f"πŸ“ Results saved to: {results_file}")

    # Final assessment
    print(f"\n🎯 FINAL ASSESSMENT:")
    if accuracy_rate >= 70:
        print(f"πŸ† EXCELLENT: {accuracy_rate:.1f}% accuracy achieves 70%+ target!")
    elif accuracy_rate >= 50:
        print(f"πŸ”§ GOOD PROGRESS: {accuracy_rate:.1f}% accuracy, approaching target")
    elif accuracy_rate >= 30:
        print(f"⚠️ MODERATE: {accuracy_rate:.1f}% accuracy, significant room for improvement")
    else:
        print(f"🚨 NEEDS WORK: {accuracy_rate:.1f}% accuracy requires major improvements")

    print(f"\nπŸ” This is the REAL accuracy without any hardcoded answers!")
    print(f"πŸ“Š Pure LLM + Tools Performance: {accuracy_rate:.1f}%")

    return accuracy_rate, results
if __name__ == "__main__":
    # Script entry point: run the full batch, then repeat the headline figure.
    accuracy, results = run_clean_batch_test()
    for closing_line in (
        f"\nπŸŽ‰ Clean batch test completed!",
        f"πŸ“Š Real Accuracy: {accuracy:.1f}%",
    ):
        print(closing_line)