#!/usr/bin/env python3
"""
Async Batch Processor for GAIA Questions
Comprehensive concurrent processing with progress tracking and error handling
"""
import asyncio
import sys
import time
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional, Callable

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from tests.async_batch_logger import AsyncBatchLogger, QuestionResult
from tests.async_batch_gaia_solver import AsyncGAIASolver
from main import GAIASolver
from question_classifier import QuestionClassifier
class BatchQuestionProcessor:
    """
    Comprehensive async batch processor for GAIA questions.

    Features: concurrency control, progress tracking, error resilience,
    real-time logging.
    """

    # Statuses that get their own counter in the batch summary. Any other
    # status value is tallied under "other_statuses".
    _TRACKED_STATUSES = ("CORRECT", "PARTIAL", "INCORRECT", "TIMEOUT", "ERROR")

    def __init__(self,
                 max_concurrent: int = 3,
                 question_timeout: int = 300,   # 5 minutes per question
                 progress_interval: int = 10):  # Progress update every 10 seconds
        """
        Configure the processor.

        Args:
            max_concurrent: Maximum number of questions processed at once.
            question_timeout: Per-question time limit in seconds.
            progress_interval: Seconds between progress log updates.

        Raises:
            ValueError: If max_concurrent is smaller than 1.
        """
        # A semaphore with zero permits never admits a question, so the batch
        # would wait forever. Reject that configuration up front.
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

        self.max_concurrent = max_concurrent
        self.question_timeout = question_timeout
        self.progress_interval = progress_interval

        # Semaphore for concurrency control.
        # NOTE: on Python < 3.10 asyncio primitives bind to the current event
        # loop when constructed, so build the processor inside the running loop.
        self.semaphore = asyncio.Semaphore(max_concurrent)

        # Progress tracking
        self.completed_count = 0
        self.total_questions = 0
        self.start_time = None

        # Logger
        self.logger = AsyncBatchLogger()

    async def process_questions_batch(self,
                                      questions: List[Dict[str, Any]],
                                      solver_kwargs: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Process a batch of questions with full async concurrency.

        Args:
            questions: List of question dictionaries.
            solver_kwargs: Kwargs to pass to GAIASolver initialization.

        Returns:
            Comprehensive batch results with classification analysis.
        """
        self.total_questions = len(questions)
        # Reset per-batch state so a reused processor does not carry the
        # previous batch's progress into this one.
        self.completed_count = 0
        self.start_time = time.time()

        # Initialize batch logging
        await self.logger.log_batch_start(self.total_questions, self.max_concurrent)

        # Default solver configuration
        if solver_kwargs is None:
            solver_kwargs = {
                "use_kluster": True,
                "kluster_model": "qwen3-235b",
            }

        # Create async solver
        async_solver = AsyncGAIASolver(
            solver_class=GAIASolver,
            classifier_class=QuestionClassifier,
            **solver_kwargs
        )

        # Start progress tracking task
        progress_task = asyncio.create_task(self._track_progress())
        try:
            # Process all questions concurrently
            print(f"π Starting concurrent processing of {len(questions)} questions...")
            print(f"π Max concurrent: {self.max_concurrent} | Timeout: {self.question_timeout}s")
            tasks = [
                asyncio.create_task(self._process_single_question(async_solver, question_data))
                for question_data in questions
            ]

            # Wait for all questions to complete. Exceptions are returned in
            # place so one failing task cannot abort the whole batch.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Process results
            batch_results = await self._compile_batch_results(results, questions)

            # Complete batch logging
            await self.logger.log_batch_complete()
            return batch_results
        finally:
            # Stop progress tracking
            progress_task.cancel()
            try:
                await progress_task
            except asyncio.CancelledError:
                pass

    @staticmethod
    def _section(payload: Dict[str, Any], key: str) -> Dict[str, Any]:
        """Return payload[key], or an empty dict when it is missing or None."""
        return payload.get(key) or {}

    def _build_success_result(self,
                              task_id: Any,
                              question_data: Dict[str, Any],
                              result: Dict[str, Any]) -> "QuestionResult":
        """Translate a solver result dictionary into a QuestionResult."""
        classification = self._section(result, 'classification')
        validation = self._section(result, 'validation')
        timing = self._section(result, 'timing_info')
        return QuestionResult(
            task_id=task_id,
            question_text=question_data.get('question', ''),
            classification=classification.get('primary_agent', 'unknown'),
            complexity=classification.get('complexity', 0),
            confidence=classification.get('confidence', 0.0),
            expected_answer=validation.get('expected', ''),
            our_answer=result.get('answer', ''),
            status=validation.get('status', 'UNKNOWN'),
            accuracy_score=validation.get('accuracy_score', 0.0),
            total_duration=timing.get('total_duration', 0.0),
            classification_time=timing.get('classification_time', 0.0),
            solving_time=timing.get('solving_time', 0.0),
            validation_time=timing.get('validation_time', 0.0),
            error_type=result.get('error_type'),
            error_details=str(result.get('error_details', '')),
            tools_used=classification.get('tools_needed', []),
            anti_hallucination_applied=False,  # TODO: Track this from solver
            override_reason=None
        )

    def _build_failure_result(self,
                              task_id: Any,
                              question_data: Dict[str, Any],
                              status: str,
                              classification: str,
                              error_type: str,
                              error_details: str,
                              total_duration: float,
                              solving_time: float) -> "QuestionResult":
        """Build the QuestionResult recorded for a timed-out or failed question."""
        return QuestionResult(
            task_id=task_id,
            question_text=question_data.get('question', ''),
            classification=classification,
            complexity=0,
            confidence=0.0,
            expected_answer='',
            our_answer='',
            status=status,
            accuracy_score=0.0,
            total_duration=total_duration,
            classification_time=0.0,
            solving_time=solving_time,
            validation_time=0.0,
            error_type=error_type,
            error_details=error_details,
            tools_used=[],
            anti_hallucination_applied=False,
            override_reason=None
        )

    async def _record_completion(self,
                                 task_id: Any,
                                 question_result: "QuestionResult") -> "QuestionResult":
        """Log a finished question, bump the progress counter and return the result."""
        await self.logger.log_question_complete(task_id, question_result)
        self.completed_count += 1
        return question_result

    async def _process_single_question(self,
                                       async_solver: "AsyncGAIASolver",
                                       question_data: Dict[str, Any]) -> "QuestionResult":
        """Process a single question with full error handling and logging."""
        task_id = question_data.get('task_id', 'unknown')
        # Task ids are not guaranteed to be strings; stringify before slicing
        # so the error handlers below cannot fail themselves.
        short_id = str(task_id)[:8]

        async with self.semaphore:  # Acquire semaphore for concurrency control
            # Measured after the semaphore is acquired, so time spent queued
            # behind other questions is not charged to this one.
            question_start = time.time()
            try:
                # Log question start
                await self.logger.log_question_start(task_id, question_data)

                # Process with timeout
                result = await asyncio.wait_for(
                    async_solver.solve_question_async(question_data, task_id),
                    timeout=self.question_timeout
                )

                question_result = self._build_success_result(task_id, question_data, result)

                # Log classification details
                if result.get('classification'):
                    await self.logger.log_classification(task_id, result['classification'])

                # Log answer processing (if available in result)
                if result.get('answer'):
                    answer_text = str(result.get('answer', ''))
                    await self.logger.log_answer_processing(task_id, answer_text, answer_text)

                # Log question completion and update progress
                return await self._record_completion(task_id, question_result)

            except asyncio.TimeoutError:
                print(f"β±οΈ [{short_id}...] Question timed out after {self.question_timeout}s")
                timeout_result = self._build_failure_result(
                    task_id,
                    question_data,
                    status='TIMEOUT',
                    classification='timeout',
                    error_type='timeout',
                    error_details=f'Question processing timed out after {self.question_timeout} seconds',
                    total_duration=self.question_timeout,
                    solving_time=self.question_timeout
                )
                return await self._record_completion(task_id, timeout_result)

            except Exception as e:
                print(f"β [{short_id}...] Unexpected error: {str(e)}")
                error_result = self._build_failure_result(
                    task_id,
                    question_data,
                    status='ERROR',
                    classification='error',
                    error_type='unexpected_error',
                    error_details=str(e),
                    # Duration of this question only, not of the whole batch.
                    total_duration=time.time() - question_start,
                    solving_time=0.0
                )
                return await self._record_completion(task_id, error_result)

    async def _track_progress(self):
        """Background task for real-time progress tracking.

        Sleeps for ``progress_interval`` seconds, logs batch progress, and
        repeats until the task is cancelled. Other errors are printed and the
        loop keeps running.
        """
        while True:
            try:
                await asyncio.sleep(self.progress_interval)
                await self.logger.log_batch_progress()
            except asyncio.CancelledError:
                return
            except Exception as e:
                print(f"β οΈ Progress tracking error: {e}")

    async def _compile_batch_results(self,
                                     results: List[Any],
                                     questions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Compile comprehensive batch results with analysis.

        Args:
            results: One entry per question, in the same order as ``questions``.
                Each entry is either a question result or the exception that
                escaped the question's task.
            questions: The question dictionaries the results belong to; used
                to attribute escaped exceptions to a task id.

        Returns:
            Summary dictionary with accuracy metrics, classification breakdown,
            performance metrics, the detailed results and the failed tasks.
        """
        completed = []
        failed_tasks = []
        for index, result in enumerate(results):
            if isinstance(result, BaseException):
                # Keep a record of tasks whose exception escaped instead of
                # silently dropping them from the summary.
                question = questions[index] if index < len(questions) else {}
                failed_tasks.append({
                    "task_id": question.get('task_id', 'unknown'),
                    "error_type": type(result).__name__,
                    "error_details": str(result),
                })
            else:
                completed.append(result)

        # Count results by status
        status_tally = Counter(item.status for item in completed)
        status_counts = {status: status_tally[status] for status in self._TRACKED_STATUSES}
        other_statuses = len(completed) - sum(status_counts.values())

        # Count by classification
        classification_counts = dict(Counter(item.classification for item in completed))

        # Timing analysis (sum of per-question durations)
        total_duration = sum((item.total_duration for item in completed), 0.0)

        # Calculate accuracy metrics
        total_completed = len(completed)
        if total_completed > 0:
            accuracy_rate = status_counts["CORRECT"] / total_completed
            success_rate = (status_counts["CORRECT"] + status_counts["PARTIAL"]) / total_completed
            avg_duration = total_duration / total_completed
        else:
            accuracy_rate = 0.0
            success_rate = 0.0
            avg_duration = 0.0

        # Elapsed wall-clock time since the batch started
        wall_clock_duration = time.time() - self.start_time if self.start_time else 0.0

        batch_summary = {
            "timestamp": datetime.now().isoformat(),
            "total_questions": self.total_questions,
            "completed_questions": total_completed,
            "accuracy_metrics": {
                "accuracy_rate": accuracy_rate,
                "success_rate": success_rate,
                "correct_answers": status_counts["CORRECT"],
                "partial_answers": status_counts["PARTIAL"],
                "incorrect_answers": status_counts["INCORRECT"],
                "timeouts": status_counts["TIMEOUT"],
                "errors": status_counts["ERROR"],
                "other_statuses": other_statuses
            },
            "classification_breakdown": classification_counts,
            "performance_metrics": {
                "total_duration": total_duration,
                "average_duration": avg_duration,
                "wall_clock_duration": wall_clock_duration,
                "max_concurrent": self.max_concurrent,
                "question_timeout": self.question_timeout
            },
            "detailed_results": completed,
            "failed_tasks": failed_tasks
        }
        return batch_summary
async def main():
    """Test the async batch processor with a small subset of questions.

    Loads questions through GAIAQuestionLoaderWeb, runs the first three
    through a BatchQuestionProcessor and prints a summary. Any exception is
    caught, reported with its traceback, and not re-raised.
    """
    try:
        # Imported inside the try block so a missing loader module is
        # reported by the handler below like any other failure.
        from gaia_web_loader import GAIAQuestionLoaderWeb

        print("π§ͺ Testing Async Batch Processor")
        print("=" * 60)

        # Load a few test questions
        print("π Loading test questions...")
        loader = GAIAQuestionLoaderWeb()
        all_questions = loader.questions

        # Use first 3 questions for testing
        test_questions = all_questions[:3]
        print(f"β Loaded {len(test_questions)} test questions")
        for i, q in enumerate(test_questions):
            # Stringify before slicing: ids and question texts are not
            # guaranteed to be strings.
            task_id = str(q.get('task_id', 'unknown'))
            question = str(q.get('question', ''))[:50] + "..."
            print(f" {i+1}. {task_id[:8]}... - {question}")

        # Initialize processor
        print("\nπ Initializing batch processor...")
        processor = BatchQuestionProcessor(
            max_concurrent=2,      # Lower concurrency for testing
            question_timeout=180,  # 3 minutes timeout for testing
            progress_interval=5    # Progress updates every 5 seconds
        )

        # Process batch
        print("\nπ Starting batch processing...")
        results = await processor.process_questions_batch(test_questions)

        # Display results
        print("\nπ BATCH RESULTS:")
        print("=" * 60)
        accuracy = results["accuracy_metrics"]["accuracy_rate"]
        success = results["accuracy_metrics"]["success_rate"]
        print(f"β Accuracy Rate: {accuracy:.1%}")
        print(f"π― Success Rate: {success:.1%}")
        print(f"β±οΈ Total Duration: {results['performance_metrics']['total_duration']:.1f}s")
        print(f"β‘ Average Duration: {results['performance_metrics']['average_duration']:.1f}s")

        print("\nπ Classification Breakdown:")
        for classification, count in results["classification_breakdown"].items():
            print(f" - {classification}: {count}")

        print("\nπ Status Breakdown:")
        for status, count in results["accuracy_metrics"].items():
            # Only the integer counters are listed; the float rates were
            # already printed above.
            if isinstance(count, int):
                print(f" - {status}: {count}")

        print("\nβ Async batch processing test completed successfully!")
    except Exception as e:
        print(f"β Test failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())