#!/usr/bin/env python3
"""
GAIA Evaluator

A comprehensive evaluation system for analyzing GAIA agent performance across different dimensions.
"""

import json
import logging
import statistics
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from answer_validator import AnswerValidator

class GAIAEvaluator:
    """
    A comprehensive evaluation system for GAIA benchmark performance analysis.

    Provides detailed metrics, visualizations, and comparative analysis.
    """

    def __init__(self,
                 results_dir: Optional[str] = None,
                 validation_file: Optional[str] = "gaia_validation_metadata.jsonl"):
        """
        Initialize the GAIA evaluator.

        Args:
            results_dir: Directory containing test results (None to provide later)
            validation_file: Path to validation metadata file
        """
        self.logger = logging.getLogger("GAIAEvaluator")
        self.results_dir = Path(results_dir) if results_dir else None
        self.validation_file = Path(validation_file) if validation_file else None
        self.validator = AnswerValidator()

        # Performance metrics
        self.metrics = {}
        self.question_details = {}
        self.validation_data = {}

        # Load validation data if provided
        if self.validation_file and self.validation_file.exists():
            self._load_validation_data()

    def _load_validation_data(self) -> None:
        """Load validation data from the JSONL file."""
        self.logger.info(f"Loading validation data from {self.validation_file}")
        try:
            with open(self.validation_file, 'r') as f:
                for line in f:
                    try:
                        entry = json.loads(line)
                        question_id = entry.get('question_id')
                        if question_id:
                            self.validation_data[question_id] = entry
                    except json.JSONDecodeError:
                        self.logger.warning(f"Could not parse line in validation file: {line[:50]}...")
        except Exception as e:
            self.logger.error(f"Error loading validation data: {e}")

    def set_results_directory(self, results_dir: str) -> None:
        """Set or update the results directory."""
        self.results_dir = Path(results_dir)

    def load_results(self, results_file: Optional[str] = None) -> Dict:
        """
        Load test results from the specified file or search for it.

        Args:
            results_file: Specific results file to load (None to search in results_dir)

        Returns:
            Dict of loaded results
        """
        if results_file:
            file_path = Path(results_file)
        elif self.results_dir:
            # Find the most recent results.json file
            json_files = list(self.results_dir.glob("**/results.json"))
            if not json_files:
                self.logger.error(f"No results.json files found in {self.results_dir}")
                return {}
            # Sort by modification time, newest first
            file_path = sorted(json_files, key=lambda x: x.stat().st_mtime, reverse=True)[0]
        else:
            self.logger.error("No results directory or file specified")
            return {}

        try:
            self.logger.info(f"Loading results from {file_path}")
            with open(file_path, 'r') as f:
                results = json.load(f)
            return results
        except Exception as e:
            self.logger.error(f"Error loading results: {e}")
            return {}

    def evaluate(self, results: Optional[Dict] = None) -> Dict:
        """
        Evaluate GAIA test results with comprehensive metrics.

        Args:
            results: Test results dict (None to load from file)

        Returns:
            Dict of evaluation metrics
        """
        if not results:
            results = self.load_results()
        if not results:
            return {}

        # Calculate basic metrics
        total_questions = len(results)
        correct_answers = 0
        partial_answers = 0
        incorrect_answers = 0
        errors = 0
        timeouts = 0
        classification_accuracy = 0
        total_classified = 0
        processing_times = []
        confidence_scores = []

        # Analyze each question
        question_metrics = {}
        for question_id, data in results.items():
            # Extract validation status
            validation = data.get('validation', {})
            validation_status = validation.get('validation_status', 'error')

            # Basic counters
            if validation_status == 'correct':
                correct_answers += 1
            elif validation_status == 'partial':
                partial_answers += 1
            elif validation_status == 'incorrect':
                incorrect_answers += 1
            elif validation_status == 'error':
                errors += 1
            elif validation_status == 'timeout':
                timeouts += 1

            # Track processing time
            if 'processing_time' in data:
                processing_times.append(data['processing_time'])

            # Track confidence scores
            if 'confidence_score' in validation:
                confidence_scores.append(validation['confidence_score'])

            # Track classification accuracy
            if 'classification' in data:
                classification_data = data['classification']
                total_classified += 1
                if classification_data.get('is_correct', False):
                    classification_accuracy += 1

            # Store detailed metrics per question
            question_metrics[question_id] = {
                'validation_status': validation_status,
                'processing_time': data.get('processing_time'),
                'confidence_score': validation.get('confidence_score'),
                'classification': data.get('classification', {}).get('classification'),
                'is_classification_correct': data.get('classification', {}).get('is_correct', False),
                'tools_used': data.get('tools_used', []),
                'steps_count': len(data.get('steps', [])),
            }

        # Calculate derived metrics
        accuracy = (correct_answers / total_questions) * 100 if total_questions > 0 else 0
        success_rate = ((correct_answers + partial_answers) / total_questions) * 100 if total_questions > 0 else 0
        classification_accuracy_pct = (classification_accuracy / total_classified) * 100 if total_classified > 0 else 0
        avg_processing_time = statistics.mean(processing_times) if processing_times else 0
        median_processing_time = statistics.median(processing_times) if processing_times else 0
        avg_confidence = statistics.mean(confidence_scores) if confidence_scores else 0

        # Store metrics
        self.metrics = {
            'total_questions': total_questions,
            'correct_answers': correct_answers,
            'partial_answers': partial_answers,
            'incorrect_answers': incorrect_answers,
            'errors': errors,
            'timeouts': timeouts,
            'accuracy': accuracy,
            'success_rate': success_rate,
            'classification_accuracy': classification_accuracy_pct,
            'avg_processing_time': avg_processing_time,
            'median_processing_time': median_processing_time,
            'avg_confidence_score': avg_confidence,
        }
        self.question_details = question_metrics
        return self.metrics

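    # Illustrative sketch of one entry in the results dict consumed by evaluate().
    # The field names match what the method reads above; the concrete values are
    # assumptions for illustration only.
    #
    #   "task_001": {
    #       "validation": {"validation_status": "correct", "confidence_score": 0.92},
    #       "processing_time": 14.3,
    #       "classification": {"classification": "web_search", "is_correct": True},
    #       "tools_used": ["browser"],
    #       "steps": [...],
    #   }
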
    def visualize_performance(self, output_dir: Optional[str] = None) -> None:
        """
        Generate visualizations of performance metrics.

        Args:
            output_dir: Directory to save visualizations (None to use results_dir)
        """
        if not self.metrics:
            self.logger.error("No metrics available. Run evaluate() first.")
            return

        if not output_dir:
            output_dir = self.results_dir
        if not output_dir:
            self.logger.error("No output directory or results directory specified")
            return
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Set the style
        sns.set(style="whitegrid")
        plt.rcParams.update({'font.size': 12})

        # Create visualizations
        self._create_accuracy_chart(output_path)
        self._create_timing_chart(output_path)
        self._create_question_type_chart(output_path)
        self._create_confidence_distribution(output_path)

    def _create_accuracy_chart(self, output_path: Path) -> None:
        """Create accuracy breakdown chart."""
        categories = ['Correct', 'Partial', 'Incorrect', 'Error', 'Timeout']
        values = [
            self.metrics['correct_answers'],
            self.metrics['partial_answers'],
            self.metrics['incorrect_answers'],
            self.metrics['errors'],
            self.metrics['timeouts']
        ]

        plt.figure(figsize=(10, 6))
        colors = ['#2ecc71', '#f39c12', '#e74c3c', '#7f8c8d', '#95a5a6']
        plt.bar(categories, values, color=colors)
        for i, v in enumerate(values):
            plt.text(i, v + 0.1, str(v), ha='center')

        plt.title('Accuracy Breakdown')
        plt.ylabel('Number of Questions')
        plt.tight_layout()
        plt.savefig(output_path / 'accuracy_breakdown.png', dpi=300)
        plt.close()

    def _create_timing_chart(self, output_path: Path) -> None:
        """Create timing analysis chart."""
        if not self.question_details:
            return

        # Extract times and statuses
        times = []
        statuses = []
        labels = []
        for q_id, details in self.question_details.items():
            if details.get('processing_time'):
                times.append(details['processing_time'])
                statuses.append(details['validation_status'])
                labels.append(q_id)

        if not times:
            return

        # Convert to dataframe
        df = pd.DataFrame({
            'Question': labels,
            'Time (s)': times,
            'Status': statuses
        })
        # Sort by time
        df = df.sort_values('Time (s)', ascending=False)

        plt.figure(figsize=(12, 8))
        # Color mapping
        color_map = {
            'correct': '#2ecc71',
            'partial': '#f39c12',
            'incorrect': '#e74c3c',
            'error': '#7f8c8d',
            'timeout': '#95a5a6'
        }
        sns.barplot(x='Time (s)', y='Question', hue='Status', data=df,
                    palette=color_map, dodge=False)
        plt.title('Processing Time by Question')
        plt.tight_layout()
        plt.savefig(output_path / 'processing_times.png', dpi=300)
        plt.close()

    def _create_question_type_chart(self, output_path: Path) -> None:
        """Create question type performance chart."""
        if not self.question_details:
            return

        # Group by classification type
        question_types = {}
        for q_id, details in self.question_details.items():
            q_type = details.get('classification') or 'unknown'
            if q_type not in question_types:
                question_types[q_type] = {
                    'total': 0,
                    'correct': 0,
                    'partial': 0,
                    'incorrect': 0,
                    'other': 0
                }
            question_types[q_type]['total'] += 1

            status = details.get('validation_status')
            if status == 'correct':
                question_types[q_type]['correct'] += 1
            elif status == 'partial':
                question_types[q_type]['partial'] += 1
            elif status == 'incorrect':
                question_types[q_type]['incorrect'] += 1
            else:
                question_types[q_type]['other'] += 1

        # Convert to dataframe
        types = []
        statuses = []
        counts = []
        for q_type, stats in question_types.items():
            for status, count in stats.items():
                if status != 'total':
                    types.append(q_type)
                    statuses.append(status)
                    counts.append(count)

        df = pd.DataFrame({
            'Question Type': types,
            'Status': statuses,
            'Count': counts
        })

        plt.figure(figsize=(12, 8))
        # Create grouped bar chart
        sns.barplot(x='Question Type', y='Count', hue='Status', data=df)
        plt.title('Performance by Question Type')
        plt.tight_layout()
        plt.savefig(output_path / 'question_type_performance.png', dpi=300)
        plt.close()

    def _create_confidence_distribution(self, output_path: Path) -> None:
        """Create confidence score distribution chart."""
        if not self.question_details:
            return

        # Extract confidence scores and statuses
        scores = []
        statuses = []
        for details in self.question_details.values():
            conf_score = details.get('confidence_score')
            if conf_score is not None:
                scores.append(conf_score)
                statuses.append(details['validation_status'])

        if not scores:
            return

        # Create dataframe
        df = pd.DataFrame({
            'Confidence Score': scores,
            'Status': statuses
        })

        plt.figure(figsize=(10, 6))
        # Create histogram with KDE
        sns.histplot(data=df, x='Confidence Score', hue='Status', kde=True)
        plt.title('Confidence Score Distribution')
        plt.tight_layout()
        plt.savefig(output_path / 'confidence_distribution.png', dpi=300)
        plt.close()

    def generate_report(self, output_file: Optional[str] = None) -> str:
        """
        Generate a comprehensive evaluation report.

        Args:
            output_file: Path to save the report (None for no saving)

        Returns:
            HTML report as string
        """
        if not self.metrics:
            self.logger.error("No metrics available. Run evaluate() first.")
            return ""

        # Create report HTML
        report = f"""
        <html>
        <head>
            <title>GAIA Performance Evaluation Report</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                h1 {{ color: #2c3e50; }}
                h2 {{ color: #3498db; }}
                .metric-card {{
                    background-color: #f8f9fa;
                    border-radius: 8px;
                    padding: 15px;
                    margin-bottom: 20px;
                    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
                }}
                .metric-title {{ font-weight: bold; margin-bottom: 8px; }}
                .metric-value {{ font-size: 24px; color: #2c3e50; }}
                .good {{ color: #2ecc71; }}
                .medium {{ color: #f39c12; }}
                .poor {{ color: #e74c3c; }}
                table {{ border-collapse: collapse; width: 100%; }}
                th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
                th {{ background-color: #f2f2f2; }}
                tr:hover {{ background-color: #f5f5f5; }}
                .chart-container {{ margin-top: 30px; margin-bottom: 30px; }}
                .chart {{ max-width: 100%; height: auto; }}
            </style>
        </head>
        <body>
            <h1>GAIA Performance Evaluation Report</h1>
            <p>Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>

            <div class="metric-card">
                <h2>Summary Metrics</h2>
                <div class="metric-row">
                    <div class="metric-title">Accuracy</div>
                    <div class="metric-value {self._get_color_class(self.metrics['accuracy'])}">
                        {self.metrics['accuracy']:.2f}%
                    </div>
                </div>
                <div class="metric-row">
                    <div class="metric-title">Success Rate (Correct + Partial)</div>
                    <div class="metric-value {self._get_color_class(self.metrics['success_rate'])}">
                        {self.metrics['success_rate']:.2f}%
                    </div>
                </div>
                <div class="metric-row">
                    <div class="metric-title">Classification Accuracy</div>
                    <div class="metric-value {self._get_color_class(self.metrics['classification_accuracy'])}">
                        {self.metrics['classification_accuracy']:.2f}%
                    </div>
                </div>
                <div class="metric-row">
                    <div class="metric-title">Average Processing Time</div>
                    <div class="metric-value">
                        {self.metrics['avg_processing_time']:.2f} seconds
                    </div>
                </div>
            </div>

            <div class="metric-card">
                <h2>Accuracy Breakdown</h2>
                <table>
                    <tr>
                        <th>Metric</th>
                        <th>Count</th>
                        <th>Percentage</th>
                    </tr>
                    <tr>
                        <td>Correct Answers</td>
                        <td>{self.metrics['correct_answers']}</td>
                        <td>{(self.metrics['correct_answers'] / self.metrics['total_questions'] * 100):.2f}%</td>
                    </tr>
                    <tr>
                        <td>Partial Answers</td>
                        <td>{self.metrics['partial_answers']}</td>
                        <td>{(self.metrics['partial_answers'] / self.metrics['total_questions'] * 100):.2f}%</td>
                    </tr>
                    <tr>
                        <td>Incorrect Answers</td>
                        <td>{self.metrics['incorrect_answers']}</td>
                        <td>{(self.metrics['incorrect_answers'] / self.metrics['total_questions'] * 100):.2f}%</td>
                    </tr>
                    <tr>
                        <td>Errors</td>
                        <td>{self.metrics['errors']}</td>
                        <td>{(self.metrics['errors'] / self.metrics['total_questions'] * 100):.2f}%</td>
                    </tr>
                    <tr>
                        <td>Timeouts</td>
                        <td>{self.metrics['timeouts']}</td>
                        <td>{(self.metrics['timeouts'] / self.metrics['total_questions'] * 100):.2f}%</td>
                    </tr>
                </table>
            </div>

            <!-- Include charts if available -->
            <div class="chart-container">
                <h2>Performance Visualizations</h2>
                <img class="chart" src="accuracy_breakdown.png" alt="Accuracy Breakdown" />
                <img class="chart" src="processing_times.png" alt="Processing Times" />
                <img class="chart" src="question_type_performance.png" alt="Question Type Performance" />
                <img class="chart" src="confidence_distribution.png" alt="Confidence Distribution" />
            </div>

            <!-- Detailed results table -->
            <div class="metric-card">
                <h2>Detailed Question Results</h2>
                <table>
                    <tr>
                        <th>Question ID</th>
                        <th>Status</th>
                        <th>Processing Time (s)</th>
                        <th>Confidence</th>
                        <th>Classification</th>
                    </tr>
                    {self._generate_question_rows()}
                </table>
            </div>
        </body>
        </html>
        """

        # Save if output file provided
        if output_file:
            try:
                with open(output_file, 'w') as f:
                    f.write(report)
                self.logger.info(f"Report saved to {output_file}")
            except Exception as e:
                self.logger.error(f"Error saving report: {e}")

        return report

    def _get_color_class(self, value: float) -> str:
        """Get CSS class based on value."""
        if value >= 80:
            return "good"
        elif value >= 60:
            return "medium"
        else:
            return "poor"

    def _generate_question_rows(self) -> str:
        """Generate HTML table rows for question details."""
        rows = ""
        for q_id, details in self.question_details.items():
            status = details.get('validation_status', 'unknown')
            proc_time = f"{details['processing_time']:.2f}" if details.get('processing_time') is not None else 'N/A'
            confidence = f"{details['confidence_score']:.2f}" if details.get('confidence_score') is not None else 'N/A'
            classification = details.get('classification') or 'unknown'

            # Get status class
            status_class = ""
            if status == 'correct':
                status_class = "good"
            elif status == 'partial':
                status_class = "medium"
            elif status in ('incorrect', 'error', 'timeout'):
                status_class = "poor"

            rows += f"""
                    <tr>
                        <td>{q_id}</td>
                        <td class="{status_class}">{status}</td>
                        <td>{proc_time}</td>
                        <td>{confidence}</td>
                        <td>{classification}</td>
                    </tr>
            """
        return rows

    def compare_runs(self, results_files: List[str], labels: List[str]) -> Dict:
        """
        Compare metrics across multiple test runs.

        Args:
            results_files: List of results files to compare
            labels: Labels for each run

        Returns:
            Dict with comparison data
        """
        if len(results_files) != len(labels):
            self.logger.error("Number of result files must match number of labels")
            return {}

        comparison_data = {
            'runs': {},
            'metrics': ['accuracy', 'success_rate', 'classification_accuracy',
                        'avg_processing_time', 'correct_answers', 'partial_answers',
                        'incorrect_answers', 'errors', 'timeouts']
        }

        for file_path, label in zip(results_files, labels):
            # Create a temporary evaluator to analyze this run
            temp_evaluator = GAIAEvaluator(validation_file=self.validation_file)
            results = temp_evaluator.load_results(file_path)
            metrics = temp_evaluator.evaluate(results)
            if metrics:
                comparison_data['runs'][label] = metrics

        return comparison_data

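    # Illustrative call, pairing compare_runs() with visualize_comparison() below.
    # The file paths and run labels are assumptions for illustration.
    #
    #   comparison = evaluator.compare_runs(
    #       ["runs/baseline/results.json", "runs/tuned/results.json"],
    #       ["baseline", "tuned"],
    #   )
    #   evaluator.visualize_comparison(comparison, "comparison_output")
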
    def visualize_comparison(self, comparison_data: Dict, output_dir: str) -> None:
        """
        Create visualizations comparing multiple runs.

        Args:
            comparison_data: Data from compare_runs method
            output_dir: Directory to save visualizations
        """
        if not comparison_data or not comparison_data.get('runs'):
            self.logger.error("No comparison data available")
            return

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Set style
        sns.set(style="whitegrid")
        plt.rcParams.update({'font.size': 12})

        # Get run labels
        run_labels = list(comparison_data['runs'].keys())

        # Create bar chart for key metrics
        key_metrics = ['accuracy', 'success_rate', 'classification_accuracy']

        # Extract data
        metric_values = {metric: [] for metric in key_metrics}
        for run_label in run_labels:
            run_data = comparison_data['runs'][run_label]
            for metric in key_metrics:
                metric_values[metric].append(run_data.get(metric, 0))

        # Create grouped bar chart
        plt.figure(figsize=(12, 8))
        x = np.arange(len(run_labels))
        width = 0.25
        for i, metric in enumerate(key_metrics):
            plt.bar(x + i * width - width, metric_values[metric], width,
                    label=metric.replace('_', ' ').title())
        plt.xlabel('Test Run')
        plt.ylabel('Percentage')
        plt.title('Key Metrics Comparison')
        plt.xticks(x, run_labels)
        plt.legend()
        plt.tight_layout()
        plt.savefig(output_path / 'metrics_comparison.png', dpi=300)
        plt.close()

        # Create processing time comparison
        times = [comparison_data['runs'][label].get('avg_processing_time', 0) for label in run_labels]
        plt.figure(figsize=(10, 6))
        plt.bar(run_labels, times)
        plt.xlabel('Test Run')
        plt.ylabel('Average Processing Time (s)')
        plt.title('Processing Time Comparison')
        plt.tight_layout()
        plt.savefig(output_path / 'processing_time_comparison.png', dpi=300)
        plt.close()

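# A minimal programmatic usage sketch, assuming a results.json produced by the batch
# test runner already exists under ./test_results (the directory and output names are
# assumptions for illustration):
#
#   evaluator = GAIAEvaluator(results_dir="test_results")
#   metrics = evaluator.evaluate()
#   evaluator.visualize_performance("evaluation_output")
#   evaluator.generate_report("evaluation_output/report.html")
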
if __name__ == "__main__":
    import argparse

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[logging.StreamHandler()]
    )

    # Parse arguments
    parser = argparse.ArgumentParser(description="GAIA Benchmark Evaluation Tool")
    parser.add_argument("--results_dir", type=str, help="Directory containing test results")
    parser.add_argument("--results_file", type=str, help="Specific results file to evaluate")
    parser.add_argument("--validation_file", type=str, default="gaia_validation_metadata.jsonl",
                        help="Path to validation metadata file")
    parser.add_argument("--output_dir", type=str, help="Directory to save evaluation outputs")
    parser.add_argument("--report_file", type=str, help="Path to save HTML report")
    parser.add_argument("--compare", action="store_true", help="Compare multiple test runs")
    parser.add_argument("--compare_files", type=str, nargs="+", help="List of files to compare")
    parser.add_argument("--compare_labels", type=str, nargs="+", help="Labels for comparison runs")
    args = parser.parse_args()

    # Initialize evaluator
    evaluator = GAIAEvaluator(
        results_dir=args.results_dir,
        validation_file=args.validation_file
    )

    # Handle comparison mode
    if args.compare and args.compare_files and args.compare_labels:
        comparison_data = evaluator.compare_runs(args.compare_files, args.compare_labels)
        if comparison_data and args.output_dir:
            evaluator.visualize_comparison(comparison_data, args.output_dir)
            print(f"Comparison visualizations saved to {args.output_dir}")
    else:
        # Regular evaluation
        if args.results_file:
            results = evaluator.load_results(args.results_file)
        else:
            results = evaluator.load_results()

        if results:
            metrics = evaluator.evaluate(results)
            print(f"Evaluation metrics: {metrics}")

            if args.output_dir:
                evaluator.visualize_performance(args.output_dir)
                print(f"Performance visualizations saved to {args.output_dir}")

            if args.report_file:
                evaluator.generate_report(args.report_file)
                print(f"Report saved to {args.report_file}")