#!/usr/bin/env python3
"""
Enhanced GAIA Testing with Classification Filtering and Error Analysis
Test all questions by agent type with comprehensive error tracking and iterative improvement workflow.
"""

import json
import re
import time
import argparse
import logging
import sys
from datetime import datetime
from typing import Dict, List, Optional
from collections import defaultdict
from pathlib import Path

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from gaia_web_loader import GAIAQuestionLoaderWeb
from main import GAIASolver
from question_classifier import QuestionClassifier


class GAIAClassificationTester:
    """Enhanced GAIA testing with classification-based filtering and error analysis"""

    def __init__(self):
        self.loader = GAIAQuestionLoaderWeb()
        self.classifier = QuestionClassifier()
        self.solver = GAIASolver()
        self.results = []
        self.error_patterns = defaultdict(list)

        # Create logs directory if it doesn't exist
        Path("logs").mkdir(exist_ok=True)

        # Setup logging
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = f"logs/classification_test_{timestamp}.log"
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(self.log_file),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

        # Load validation answers after logger is set up
        self.validation_answers = self.load_validation_answers()

    def load_validation_answers(self):
        """Load correct answers from GAIA validation metadata"""
        answers = {}
        try:
            validation_path = Path(__file__).parent.parent / 'gaia_validation_metadata.jsonl'
            with open(validation_path, 'r') as f:
                for line in f:
                    if line.strip():
                        data = json.loads(line.strip())
                        task_id = data.get('task_id')
                        final_answer = data.get('Final answer')
                        if task_id and final_answer:
                            answers[task_id] = final_answer
            self.logger.info(f"📋 Loaded {len(answers)} validation answers")
        except Exception as e:
            self.logger.error(f"⚠️ Could not load validation data: {e}")
        return answers
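
    # Illustrative shape of one line in gaia_validation_metadata.jsonl; only
    # 'task_id' and 'Final answer' are read above, any other fields are assumed
    # and simply ignored:
    #   {"task_id": "<uuid>", "Final answer": "<expected answer>", ...}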

    def validate_answer(self, task_id: str, our_answer: str):
        """Validate our answer against the correct answer with format normalization"""
        if task_id not in self.validation_answers:
            return {"status": "NO_VALIDATION", "expected": "N/A", "our": our_answer}

        expected = str(self.validation_answers[task_id]).strip()
        our_clean = str(our_answer).strip()

        # Exact match (case-insensitive)
        if our_clean.lower() == expected.lower():
            return {"status": "CORRECT", "expected": expected, "our": our_clean}

        # ENHANCED: Format normalization for comprehensive comparison
        def normalize_format(text):
            """Enhanced normalization for fair comparison"""
            text = str(text).lower().strip()
            # Remove currency symbols and normalize numbers
            text = re.sub(r'[$€£¥]', '', text)
            # Normalize spacing around commas and punctuation
            text = re.sub(r'\s*,\s*', ', ', text)   # "b,e" -> "b, e"
            text = re.sub(r'\s*;\s*', '; ', text)   # "a;b" -> "a; b"
            text = re.sub(r'\s*:\s*', ': ', text)   # "a:b" -> "a: b"
            # Remove extra whitespace
            text = re.sub(r'\s+', ' ', text).strip()
            # Normalize decimal places and numbers
            text = re.sub(r'(\d+)\.0+$', r'\1', text)          # "89706.00" -> "89706"
            # Tolerate the space inserted by the comma-spacing rule above
            text = re.sub(r'(\d+),\s*(\d{3})', r'\1\2', text)  # "89,706" -> "89706"
            # Remove common formatting artifacts
            text = re.sub(r'[“”‘’`]', '"', text)    # Normalize quotes
            text = re.sub(r'[–—]', '-', text)       # Normalize dashes
            text = re.sub(r'[^\w\s,.-]', '', text)  # Remove special characters
            # Handle common answer formats
            text = re.sub(r'^the answer is\s*', '', text)
            text = re.sub(r'^answer:\s*', '', text)
            text = re.sub(r'^final answer:\s*', '', text)
            return text
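
        # Illustrative effect of normalize_format, traced through the rules above
        # (inputs are hypothetical):
        #   "$89,706.00"        -> "89706"
        #   "The answer is b,e" -> "b, e"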
        normalized_expected = normalize_format(expected)
        normalized_our = normalize_format(our_clean)

        # Check normalized exact match
        if normalized_our == normalized_expected:
            return {"status": "CORRECT", "expected": expected, "our": our_clean}

        # For list-type answers, try element-wise comparison
        if ',' in expected and ',' in our_clean:
            expected_items = [item.strip().lower() for item in expected.split(',')]
            our_items = [item.strip().lower() for item in our_clean.split(',')]

            # Sort both lists for comparison (handles different ordering)
            if sorted(expected_items) == sorted(our_items):
                return {"status": "CORRECT", "expected": expected, "our": our_clean}

            # Check if most items match (partial credit)
            matching_items = set(expected_items) & set(our_items)
            if len(matching_items) >= len(expected_items) * 0.7:  # 70% match threshold
                return {"status": "PARTIAL", "expected": expected, "our": our_clean}

        # Check if our answer contains the expected answer (broader match)
        if normalized_expected in normalized_our or normalized_our in normalized_expected:
            return {"status": "PARTIAL", "expected": expected, "our": our_clean}

        # ENHANCED: Numeric equivalence checking
        expected_numbers = re.findall(r'\d+(?:\.\d+)?', expected)
        our_numbers = re.findall(r'\d+(?:\.\d+)?', our_clean)
        if expected_numbers and our_numbers:
            try:
                # Compare primary numbers
                expected_num = float(expected_numbers[0])
                our_num = float(our_numbers[0])
                # Allow small floating point differences
                if abs(expected_num - our_num) < 0.01:
                    return {"status": "CORRECT", "expected": expected, "our": our_clean}
                # Check for percentage differences (e.g., rounding errors)
                if expected_num > 0:
                    percentage_diff = abs(expected_num - our_num) / expected_num
                    if percentage_diff < 0.01:  # 1% tolerance
                        return {"status": "CORRECT", "expected": expected, "our": our_clean}
            except (ValueError, IndexError):
                pass

        # ENHANCED: Fuzzy matching for near-correct answers
        def fuzzy_similarity(str1, str2):
            """Calculate simple character-based similarity"""
            if not str1 or not str2:
                return 0.0
            # Convert to character sets
            chars1 = set(str1.lower())
            chars2 = set(str2.lower())
            # Calculate Jaccard similarity
            intersection = len(chars1 & chars2)
            union = len(chars1 | chars2)
            return intersection / union if union > 0 else 0.0

        # Check fuzzy similarity for near matches
        similarity = fuzzy_similarity(normalized_expected, normalized_our)
        if similarity > 0.8:  # 80% character similarity
            return {"status": "PARTIAL", "expected": expected, "our": our_clean}

        # Final check: word-level matching
        expected_words = set(normalized_expected.split())
        our_words = set(normalized_our.split())
        if expected_words and our_words:
            word_overlap = len(expected_words & our_words) / len(expected_words)
            if word_overlap > 0.7:  # 70% word overlap
                return {"status": "PARTIAL", "expected": expected, "our": our_clean}

        return {"status": "INCORRECT", "expected": expected, "our": our_clean}

    def classify_all_questions(self) -> Dict[str, List[Dict]]:
        """Classify all questions and group by agent type"""
        self.logger.info("🧠 Classifying all GAIA questions...")

        questions_by_agent = defaultdict(list)
        classification_stats = defaultdict(int)

        for question_data in self.loader.questions:
            task_id = question_data.get('task_id', 'unknown')
            question_text = question_data.get('question', '')
            file_name = question_data.get('file_name', '')

            try:
                classification = self.classifier.classify_question(question_text, file_name)
                primary_agent = classification['primary_agent']

                # Add classification to question data
                question_data['classification'] = classification
                question_data['routing'] = self.classifier.get_routing_recommendation(classification)

                questions_by_agent[primary_agent].append(question_data)
                classification_stats[primary_agent] += 1

                self.logger.info(f"  {task_id[:8]}... → {primary_agent} (confidence: {classification['confidence']:.3f})")
            except Exception as e:
                self.logger.error(f"  ❌ Classification failed for {task_id[:8]}...: {e}")
                questions_by_agent['error'].append(question_data)

        # Print classification summary
        self.logger.info(f"\n📊 CLASSIFICATION SUMMARY:")
        total_questions = len(self.loader.questions)
        for agent_type, count in sorted(classification_stats.items()):
            percentage = (count / total_questions) * 100
            self.logger.info(f"  {agent_type}: {count} questions ({percentage:.1f}%)")

        return dict(questions_by_agent)
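
    # Shape of the returned mapping (agent names come from the classifier, e.g.
    # the categories exposed on the CLI below; contents illustrative):
    #   {
    #       "research":   [<question dict>, ...],
    #       "multimedia": [<question dict>, ...],
    #       "error":      [<questions whose classification raised>],
    #   }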

    def test_agent_type(self, agent_type: str, questions: List[Dict], test_all: bool = False) -> List[Dict]:
        """Test all questions for a specific agent type"""
        if not questions:
            self.logger.warning(f"No questions found for agent type: {agent_type}")
            return []

        self.logger.info(f"\n🤖 TESTING {agent_type.upper()} AGENT")
        self.logger.info(f"=" * 60)
        self.logger.info(f"Questions to test: {len(questions)}")

        agent_results = []
        success_count = 0

        for i, question_data in enumerate(questions, 1):
            task_id = question_data.get('task_id', 'unknown')
            question_text = question_data.get('question', '')
            file_name = question_data.get('file_name', '')

            self.logger.info(f"\n[{i}/{len(questions)}] Testing {task_id[:8]}...")
            self.logger.info(f"Question: {question_text[:100]}...")
            if file_name:
                self.logger.info(f"File: {file_name}")

            try:
                start_time = time.time()
                answer = self.solver.solve_question(question_data)
                solve_time = time.time() - start_time

                # Validate answer against expected result
                validation_result = self.validate_answer(task_id, answer)

                # Log results with validation
                self.logger.info(f"✅ Answer: {answer[:100]}...")
                self.logger.info(f"⏱️ Time: {solve_time:.1f}s")
                self.logger.info(f"📋 Expected: {validation_result['expected']}")
                self.logger.info(f"🔍 Validation: {validation_result['status']}")

                if validation_result['status'] == 'CORRECT':
                    self.logger.info(f"✅ PERFECT MATCH!")
                    actual_status = 'correct'
                elif validation_result['status'] == 'PARTIAL':
                    self.logger.info(f"🟡 PARTIAL MATCH - contains correct answer")
                    actual_status = 'partial'
                elif validation_result['status'] == 'INCORRECT':
                    self.logger.error(f"❌ INCORRECT - answers don't match")
                    actual_status = 'incorrect'
                else:
                    self.logger.warning(f"⚠️ NO VALIDATION DATA")
                    actual_status = 'no_validation'

                result = {
                    'question_id': task_id,
                    'question': question_text,
                    'file_name': file_name,
                    'agent_type': agent_type,
                    'classification': question_data.get('classification'),
                    'routing': question_data.get('routing'),
                    'answer': answer,
                    'solve_time': solve_time,
                    'status': 'completed',
                    'validation_status': validation_result['status'],
                    'expected_answer': validation_result['expected'],
                    'actual_status': actual_status,
                    'error_type': None,
                    'error_details': None
                }
                agent_results.append(result)
                if actual_status == 'correct':
                    success_count += 1

            except Exception as e:
                solve_time = time.time() - start_time
                error_type = self.categorize_error(str(e))
                self.logger.error(f"❌ Error: {e}")
                self.logger.error(f"Error Type: {error_type}")

                result = {
                    'question_id': task_id,
                    'question': question_text,
                    'file_name': file_name,
                    'agent_type': agent_type,
                    'classification': question_data.get('classification'),
                    'routing': question_data.get('routing'),
                    'answer': f"Error: {str(e)}",
                    'solve_time': solve_time,
                    'status': 'error',
                    'error_type': error_type,
                    'error_details': str(e)
                }
                agent_results.append(result)

                self.error_patterns[agent_type].append({
                    'question_id': task_id,
                    'error_type': error_type,
                    'error_details': str(e),
                    'question_preview': question_text[:100]
                })

            # Small delay to avoid overwhelming APIs
            time.sleep(1)

        # Agent type summary with accuracy metrics
        error_count = len([r for r in agent_results if r['status'] == 'error'])
        completed_count = len([r for r in agent_results if r['status'] == 'completed'])
        correct_count = len([r for r in agent_results if r.get('actual_status') == 'correct'])
        partial_count = len([r for r in agent_results if r.get('actual_status') == 'partial'])
        incorrect_count = len([r for r in agent_results if r.get('actual_status') == 'incorrect'])

        accuracy_rate = (correct_count / len(questions)) * 100 if questions else 0
        completion_rate = (completed_count / len(questions)) * 100 if questions else 0

        self.logger.info(f"\n📊 {agent_type.upper()} AGENT RESULTS:")
        self.logger.info(f"  Completed: {completed_count}/{len(questions)} ({completion_rate:.1f}%)")
        self.logger.info(f"  ✅ Correct: {correct_count}/{len(questions)} ({accuracy_rate:.1f}%)")
        self.logger.info(f"  🟡 Partial: {partial_count}/{len(questions)}")
        self.logger.info(f"  ❌ Incorrect: {incorrect_count}/{len(questions)}")
        self.logger.info(f"  💥 Errors: {error_count}/{len(questions)}")

        if agent_results:
            completed_results = [r for r in agent_results if r['status'] == 'completed']
            if completed_results:
                avg_time = sum(r['solve_time'] for r in completed_results) / len(completed_results)
                self.logger.info(f"  ⏱️ Average Solve Time: {avg_time:.1f}s")

        return agent_results

    def categorize_error(self, error_message: str) -> str:
        """Categorize error types for analysis"""
        error_message_lower = error_message.lower()

        if '503' in error_message or 'service unavailable' in error_message_lower:
            return 'API_OVERLOAD'
        elif 'timeout' in error_message_lower or 'time out' in error_message_lower:
            return 'TIMEOUT'
        elif 'api' in error_message_lower and ('key' in error_message_lower or 'auth' in error_message_lower):
            return 'AUTHENTICATION'
        elif 'wikipedia' in error_message_lower or 'wiki' in error_message_lower:
            return 'WIKIPEDIA_TOOL'
        elif 'chess' in error_message_lower or 'fen' in error_message_lower:
            return 'CHESS_TOOL'
        elif 'excel' in error_message_lower or 'xlsx' in error_message_lower:
            return 'EXCEL_TOOL'
        elif 'video' in error_message_lower or 'youtube' in error_message_lower:
            return 'VIDEO_TOOL'
        elif 'gemini' in error_message_lower:
            return 'GEMINI_API'
        elif 'download' in error_message_lower or 'file' in error_message_lower:
            return 'FILE_PROCESSING'
        elif 'hallucination' in error_message_lower or 'fabricat' in error_message_lower:
            return 'HALLUCINATION'
        elif 'parsing' in error_message_lower or 'extract' in error_message_lower:
            return 'PARSING_ERROR'
        else:
            return 'UNKNOWN'
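
    # Examples of how messages map to categories (checks run top-down, so the
    # first matching keyword wins; the messages are hypothetical):
    #   "503 Service Unavailable"            -> 'API_OVERLOAD'
    #   "Request timeout while calling tool" -> 'TIMEOUT'
    #   "Invalid API key provided"           -> 'AUTHENTICATION'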

    def analyze_errors_by_agent(self):
        """Analyze error patterns by agent type"""
        if not self.error_patterns:
            self.logger.info("🎉 No errors found across all agent types!")
            return

        self.logger.info(f"\n🔍 ERROR ANALYSIS BY AGENT TYPE")
        self.logger.info("=" * 60)

        for agent_type, errors in self.error_patterns.items():
            if not errors:
                continue

            self.logger.info(f"\n🚨 {agent_type.upper()} AGENT ERRORS ({len(errors)} total):")

            # Group errors by type
            error_type_counts = defaultdict(int)
            for error in errors:
                error_type_counts[error['error_type']] += 1

            for error_type, count in sorted(error_type_counts.items(), key=lambda x: x[1], reverse=True):
                percentage = (count / len(errors)) * 100
                self.logger.info(f"  {error_type}: {count} errors ({percentage:.1f}%)")

            # Show specific examples
            self.logger.info(f"  Examples:")
            for error in errors[:3]:  # Show first 3 errors
                self.logger.info(f"    - {error['question_id'][:8]}...: {error['error_type']} - {error['question_preview']}...")

    def generate_improvement_recommendations(self):
        """Generate specific recommendations for improving each agent type"""
        self.logger.info(f"\n💡 IMPROVEMENT RECOMMENDATIONS")
        self.logger.info("=" * 60)

        all_results = [r for agent_results in self.results for r in agent_results]

        # Calculate success rates by agent type
        agent_stats = defaultdict(lambda: {'total': 0, 'success': 0, 'errors': []})
        for result in all_results:
            agent_type = result['agent_type']
            agent_stats[agent_type]['total'] += 1
            if result['status'] == 'completed':
                agent_stats[agent_type]['success'] += 1
            else:
                agent_stats[agent_type]['errors'].append(result)

        # Generate recommendations for each agent type
        for agent_type, stats in agent_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0

            self.logger.info(f"\n🎯 {agent_type.upper()} AGENT (Success Rate: {success_rate:.1f}%):")

            if success_rate >= 90:
                self.logger.info(f"  ✅ Excellent performance! Minor optimizations only.")
            elif success_rate >= 75:
                self.logger.info(f"  ⚠️ Good performance with room for improvement.")
            elif success_rate >= 50:
                self.logger.info(f"  🔧 Moderate performance - needs attention.")
            else:
                self.logger.info(f"  🚨 Poor performance - requires major improvements.")

            # Analyze common error patterns for this agent
            error_types = defaultdict(int)
            for error in stats['errors']:
                if error['error_type']:
                    error_types[error['error_type']] += 1

            if error_types:
                self.logger.info(f"  Common Issues:")
                for error_type, count in sorted(error_types.items(), key=lambda x: x[1], reverse=True):
                    self.logger.info(f"    - {error_type}: {count} occurrences")
                    self.suggest_fix_for_error_type(error_type, agent_type)

    def suggest_fix_for_error_type(self, error_type: str, agent_type: str):
        """Suggest specific fixes for common error types"""
        suggestions = {
            'API_OVERLOAD': "Implement exponential backoff and retry logic",
            'TIMEOUT': "Increase timeout limits or optimize processing pipeline",
            'AUTHENTICATION': "Check API keys and authentication configuration",
            'WIKIPEDIA_TOOL': "Enhance Wikipedia search logic and error handling",
            'CHESS_TOOL': "Improve FEN parsing and chess engine integration",
            'EXCEL_TOOL': "Add better Excel format validation and error recovery",
            'VIDEO_TOOL': "Implement fallback mechanisms for video processing",
            'GEMINI_API': "Add Gemini API error handling and fallback models",
            'FILE_PROCESSING': "Improve file download and validation logic",
            'HALLUCINATION': "Strengthen anti-hallucination prompts and tool output validation",
            'PARSING_ERROR': "Enhance output parsing logic and format validation"
        }

        suggestion = suggestions.get(error_type, "Investigate error cause and implement appropriate fix")
        self.logger.info(f"      → Fix: {suggestion}")

    def save_comprehensive_results(self, questions_by_agent: Dict[str, List[Dict]]):
        """Save comprehensive test results with error analysis"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"gaia_classification_test_results_{timestamp}.json"

        # Flatten all results
        all_results = []
        for agent_results in self.results:
            all_results.extend(agent_results)

        # Create comprehensive results
        comprehensive_results = {
            'test_metadata': {
                'timestamp': timestamp,
                'total_questions': len(self.loader.questions),
                'questions_by_agent': {agent: len(questions) for agent, questions in questions_by_agent.items()},
                'log_file': self.log_file
            },
            'overall_stats': {
                'total_questions': len(all_results),
                'successful': len([r for r in all_results if r['status'] == 'completed']),
                'errors': len([r for r in all_results if r['status'] == 'error']),
                'success_rate': len([r for r in all_results if r['status'] == 'completed']) / len(all_results) * 100 if all_results else 0
            },
            'agent_performance': {},
            'error_patterns': dict(self.error_patterns),
            'detailed_results': all_results
        }

        # Calculate per-agent performance
        agent_stats = defaultdict(lambda: {'total': 0, 'success': 0, 'avg_time': 0})
        for result in all_results:
            agent_type = result['agent_type']
            agent_stats[agent_type]['total'] += 1
            if result['status'] == 'completed':
                agent_stats[agent_type]['success'] += 1
                agent_stats[agent_type]['avg_time'] += result['solve_time']

        for agent_type, stats in agent_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            avg_time = stats['avg_time'] / stats['success'] if stats['success'] > 0 else 0
            comprehensive_results['agent_performance'][agent_type] = {
                'total_questions': stats['total'],
                'successful': stats['success'],
                'success_rate': success_rate,
                'average_solve_time': avg_time
            }

        # Save results
        with open(results_file, 'w') as f:
            json.dump(comprehensive_results, f, indent=2, ensure_ascii=False)

        self.logger.info(f"\n💾 Comprehensive results saved to: {results_file}")
        return results_file
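
    # Top-level layout of the saved JSON file (values illustrative):
    #   {
    #     "test_metadata":     {"timestamp": ..., "total_questions": ..., "questions_by_agent": {...}, "log_file": ...},
    #     "overall_stats":     {"total_questions": ..., "successful": ..., "errors": ..., "success_rate": ...},
    #     "agent_performance": {"<agent>": {"total_questions": ..., "successful": ..., "success_rate": ..., "average_solve_time": ...}},
    #     "error_patterns":    {"<agent>": [<error records>]},
    #     "detailed_results":  [<per-question result dicts>]
    #   }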

    def run_classification_test(self, agent_types: Optional[List[str]] = None, test_all: bool = True):
        """Run the complete classification-based testing workflow"""
        self.logger.info("🚀 GAIA CLASSIFICATION-BASED TESTING")
        self.logger.info("=" * 70)
        self.logger.info(f"Log file: {self.log_file}")

        # Step 1: Classify all questions
        questions_by_agent = self.classify_all_questions()

        # Step 2: Filter agent types to test
        if agent_types:
            agent_types_to_test = [agent for agent in agent_types if agent in questions_by_agent]
            if not agent_types_to_test:
                self.logger.error(f"No questions found for specified agent types: {agent_types}")
                return
        else:
            agent_types_to_test = list(questions_by_agent.keys())

        self.logger.info(f"\nTesting agent types: {agent_types_to_test}")

        # Step 3: Test each agent type
        for agent_type in agent_types_to_test:
            if agent_type == 'error':  # Skip classification errors for now
                continue
            questions = questions_by_agent[agent_type]
            agent_results = self.test_agent_type(agent_type, questions, test_all)
            self.results.append(agent_results)

        # Step 4: Comprehensive analysis
        self.analyze_errors_by_agent()
        self.generate_improvement_recommendations()

        # Step 5: Save results
        results_file = self.save_comprehensive_results(questions_by_agent)

        self.logger.info(f"\n✅ CLASSIFICATION TESTING COMPLETE!")
        self.logger.info(f"📄 Results saved to: {results_file}")
        self.logger.info(f"📋 Log file: {self.log_file}")


def main():
    """Main CLI interface for classification-based testing"""
    parser = argparse.ArgumentParser(description="GAIA Classification-Based Testing with Error Analysis")
    parser.add_argument(
        '--agent-types',
        nargs='+',
        choices=['multimedia', 'research', 'logic_math', 'file_processing', 'general'],
        help='Specific agent types to test (default: all)'
    )
    parser.add_argument(
        '--failed-only',
        action='store_true',
        help='Test only questions that failed in previous runs'
    )
    parser.add_argument(
        '--quick-test',
        action='store_true',
        help='Run a quick test with limited questions per agent type'
    )

    args = parser.parse_args()

    # Initialize and run tester
    tester = GAIAClassificationTester()

    print("🎯 Starting GAIA Classification-Based Testing...")
    if args.agent_types:
        print(f"📋 Testing specific agent types: {args.agent_types}")
    else:
        print("📊 Testing all agent types")

    tester.run_classification_test(
        agent_types=args.agent_types,
        test_all=not args.quick_test
    )


if __name__ == "__main__":
    main()