#!/usr/bin/env python3
"""
Async Batch Processor for GAIA Questions
Comprehensive concurrent processing with progress tracking and error handling
"""

import asyncio
import time
from collections import Counter
from datetime import datetime
from typing import List, Dict, Any, Optional, Callable
from pathlib import Path
import sys

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from tests.async_batch_logger import AsyncBatchLogger, QuestionResult
from tests.async_batch_gaia_solver import AsyncGAIASolver
from main import GAIASolver
from question_classifier import QuestionClassifier


class BatchQuestionProcessor:
    """
    Comprehensive async batch processor for GAIA questions

    Features: Concurrency control, progress tracking, error resilience,
    real-time logging.
    """

    def __init__(self,
                 max_concurrent: int = 3,
                 question_timeout: int = 300,   # 5 minutes per question
                 progress_interval: int = 10):  # Progress update every 10 seconds
        """
        Args:
            max_concurrent: Maximum number of questions processed simultaneously.
            question_timeout: Per-question wall-clock limit in seconds.
            progress_interval: Seconds between background progress log updates.
        """
        self.max_concurrent = max_concurrent
        self.question_timeout = question_timeout
        self.progress_interval = progress_interval

        # Semaphore for concurrency control
        self.semaphore = asyncio.Semaphore(max_concurrent)

        # Progress tracking (reset at the start of every batch)
        self.completed_count = 0
        self.total_questions = 0
        self.start_time: Optional[float] = None

        # Logger
        self.logger = AsyncBatchLogger()

    async def process_questions_batch(self,
                                      questions: List[Dict[str, Any]],
                                      solver_kwargs: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Process a batch of questions with full async concurrency

        Args:
            questions: List of question dictionaries
            solver_kwargs: Kwargs to pass to GAIASolver initialization

        Returns:
            Comprehensive batch results with classification analysis
        """
        self.total_questions = len(questions)
        # FIX: reset the counter so a reused processor instance reports
        # correct progress on subsequent batches.
        self.completed_count = 0
        self.start_time = time.time()

        # Initialize batch logging
        await self.logger.log_batch_start(self.total_questions, self.max_concurrent)

        # Default solver configuration
        if solver_kwargs is None:
            solver_kwargs = {
                "use_kluster": True,
                "kluster_model": "qwen3-235b"
            }

        # Create async solver
        async_solver = AsyncGAIASolver(
            solver_class=GAIASolver,
            classifier_class=QuestionClassifier,
            **solver_kwargs
        )

        # Start background progress tracking task
        progress_task = asyncio.create_task(self._track_progress())

        try:
            # Process all questions concurrently
            print(f"๐Ÿš€ Starting concurrent processing of {len(questions)} questions...")
            print(f"๐Ÿ“Š Max concurrent: {self.max_concurrent} | Timeout: {self.question_timeout}s")

            tasks = [
                asyncio.create_task(self._process_single_question(async_solver, question_data))
                for question_data in questions
            ]

            # return_exceptions=True keeps one crashed task from cancelling
            # the remaining questions; stray exceptions are filtered out in
            # _compile_batch_results.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Compile and return aggregate results
            batch_results = await self._compile_batch_results(results, questions)
            await self.logger.log_batch_complete()
            return batch_results
        finally:
            # Stop progress tracking even if gathering failed
            progress_task.cancel()
            try:
                await progress_task
            except asyncio.CancelledError:
                pass

    async def _process_single_question(self,
                                       async_solver: AsyncGAIASolver,
                                       question_data: Dict[str, Any]) -> QuestionResult:
        """
        Process a single question with full error handling and logging.

        Always returns a QuestionResult — timeouts and unexpected errors are
        converted into TIMEOUT/ERROR results instead of propagating.
        """
        task_id = question_data.get('task_id', 'unknown')

        async with self.semaphore:  # Acquire semaphore for concurrency control
            # FIX: per-question timer; the previous code measured the error-path
            # duration from the *batch* start time, not the question start.
            question_start = time.time()
            try:
                # Log question start
                await self.logger.log_question_start(task_id, question_data)

                # Process with timeout
                result = await asyncio.wait_for(
                    async_solver.solve_question_async(question_data, task_id),
                    timeout=self.question_timeout
                )

                # Flatten the solver's nested result dict into a QuestionResult
                classification = result.get('classification', {})
                validation = result.get('validation', {})
                timing = result.get('timing_info', {})
                question_result = QuestionResult(
                    task_id=task_id,
                    question_text=question_data.get('question', ''),
                    classification=classification.get('primary_agent', 'unknown'),
                    complexity=classification.get('complexity', 0),
                    confidence=classification.get('confidence', 0.0),
                    expected_answer=validation.get('expected', ''),
                    our_answer=result.get('answer', ''),
                    status=validation.get('status', 'UNKNOWN'),
                    accuracy_score=validation.get('accuracy_score', 0.0),
                    total_duration=timing.get('total_duration', 0.0),
                    classification_time=timing.get('classification_time', 0.0),
                    solving_time=timing.get('solving_time', 0.0),
                    validation_time=timing.get('validation_time', 0.0),
                    error_type=result.get('error_type'),
                    error_details=str(result.get('error_details', '')),
                    tools_used=classification.get('tools_needed', []),
                    anti_hallucination_applied=False,  # TODO: Track this from solver
                    override_reason=None
                )

                # Log classification details
                if result.get('classification'):
                    await self.logger.log_classification(task_id, result['classification'])

                # Log answer processing (if available in result)
                if result.get('answer'):
                    await self.logger.log_answer_processing(
                        task_id,
                        str(result.get('answer', '')),
                        str(result.get('answer', ''))
                    )

                # Log question completion
                await self.logger.log_question_complete(task_id, question_result)

                # Update progress
                self.completed_count += 1
                return question_result

            except asyncio.TimeoutError:
                print(f"โฑ๏ธ [{task_id[:8]}...] Question timed out after {self.question_timeout}s")
                timeout_result = QuestionResult(
                    task_id=task_id,
                    question_text=question_data.get('question', ''),
                    classification='timeout',
                    complexity=0,
                    confidence=0.0,
                    expected_answer='',
                    our_answer='',
                    status='TIMEOUT',
                    accuracy_score=0.0,
                    total_duration=self.question_timeout,
                    classification_time=0.0,
                    solving_time=self.question_timeout,
                    validation_time=0.0,
                    error_type='timeout',
                    error_details=f'Question processing timed out after {self.question_timeout} seconds',
                    tools_used=[],
                    anti_hallucination_applied=False,
                    override_reason=None
                )
                await self.logger.log_question_complete(task_id, timeout_result)
                self.completed_count += 1
                return timeout_result

            except Exception as e:
                print(f"โŒ [{task_id[:8]}...] Unexpected error: {str(e)}")
                error_result = QuestionResult(
                    task_id=task_id,
                    question_text=question_data.get('question', ''),
                    classification='error',
                    complexity=0,
                    confidence=0.0,
                    expected_answer='',
                    our_answer='',
                    status='ERROR',
                    accuracy_score=0.0,
                    # FIX: duration of this question only (was batch elapsed time)
                    total_duration=time.time() - question_start,
                    classification_time=0.0,
                    solving_time=0.0,
                    validation_time=0.0,
                    error_type='unexpected_error',
                    error_details=str(e),
                    tools_used=[],
                    anti_hallucination_applied=False,
                    override_reason=None
                )
                await self.logger.log_question_complete(task_id, error_result)
                self.completed_count += 1
                return error_result

    async def _track_progress(self):
        """Background task for real-time progress tracking; runs until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.progress_interval)
                await self.logger.log_batch_progress()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Progress logging is best-effort; never let it kill the batch.
                print(f"โš ๏ธ Progress tracking error: {e}")

    async def _compile_batch_results(self,
                                     results: List[QuestionResult],
                                     questions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Compile comprehensive batch results with analysis.

        Args:
            results: Outputs of asyncio.gather — QuestionResult objects, plus
                possibly raw exceptions (return_exceptions=True); non-results
                are excluded from every statistic.
            questions: Original question dicts (kept for signature stability).

        Returns:
            Summary dict with accuracy metrics, classification breakdown,
            performance metrics and the detailed per-question results.
        """
        # Count results by status. NOTE(review): statuses outside this set
        # (e.g. 'UNKNOWN') are intentionally not counted.
        status_counts = {
            "CORRECT": 0,
            "PARTIAL": 0,
            "INCORRECT": 0,
            "TIMEOUT": 0,
            "ERROR": 0
        }
        classification_counts: Counter = Counter()
        total_duration = 0.0

        # Filter out raw exceptions once; every metric below uses real results.
        completed = [r for r in results if isinstance(r, QuestionResult)]

        for result in completed:
            if result.status in status_counts:
                status_counts[result.status] += 1
            classification_counts[result.classification] += 1
            total_duration += result.total_duration

        # Accuracy metrics (guard against an empty batch)
        total_completed = len(completed)
        accuracy_rate = status_counts["CORRECT"] / total_completed if total_completed > 0 else 0.0
        success_rate = (status_counts["CORRECT"] + status_counts["PARTIAL"]) / total_completed if total_completed > 0 else 0.0

        # Performance metrics
        avg_duration = total_duration / total_completed if total_completed > 0 else 0.0

        batch_summary = {
            "timestamp": datetime.now().isoformat(),
            "total_questions": self.total_questions,
            "completed_questions": total_completed,
            "accuracy_metrics": {
                "accuracy_rate": accuracy_rate,
                "success_rate": success_rate,
                "correct_answers": status_counts["CORRECT"],
                "partial_answers": status_counts["PARTIAL"],
                "incorrect_answers": status_counts["INCORRECT"],
                "timeouts": status_counts["TIMEOUT"],
                "errors": status_counts["ERROR"]
            },
            "classification_breakdown": dict(classification_counts),
            "performance_metrics": {
                "total_duration": total_duration,
                "average_duration": avg_duration,
                "max_concurrent": self.max_concurrent,
                "question_timeout": self.question_timeout
            },
            "detailed_results": completed
        }
        return batch_summary


async def main():
    """Test the async batch processor with a small subset of questions"""
    try:
        # Import required classes (local import: only needed for this test run)
        from gaia_web_loader import GAIAQuestionLoaderWeb

        print("๐Ÿงช Testing Async Batch Processor")
        print("=" * 60)

        # Load a few test questions
        print("๐Ÿ“‹ Loading test questions...")
        loader = GAIAQuestionLoaderWeb()
        all_questions = loader.questions

        # Use first 3 questions for testing
        test_questions = all_questions[:3]
        print(f"โœ… Loaded {len(test_questions)} test questions")
        for i, q in enumerate(test_questions):
            task_id = q.get('task_id', 'unknown')
            question = q.get('question', '')[:50] + "..."
            print(f"   {i+1}. {task_id[:8]}... - {question}")

        # Initialize processor
        print(f"\n๐Ÿš€ Initializing batch processor...")
        processor = BatchQuestionProcessor(
            max_concurrent=2,      # Lower concurrency for testing
            question_timeout=180,  # 3 minutes timeout for testing
            progress_interval=5    # Progress updates every 5 seconds
        )

        # Process batch
        print(f"\n๐Ÿ”„ Starting batch processing...")
        results = await processor.process_questions_batch(test_questions)

        # Display results
        print(f"\n๐Ÿ“Š BATCH RESULTS:")
        print("=" * 60)
        accuracy = results["accuracy_metrics"]["accuracy_rate"]
        success = results["accuracy_metrics"]["success_rate"]
        print(f"โœ… Accuracy Rate: {accuracy:.1%}")
        print(f"๐ŸŽฏ Success Rate: {success:.1%}")
        print(f"โฑ๏ธ Total Duration: {results['performance_metrics']['total_duration']:.1f}s")
        print(f"โšก Average Duration: {results['performance_metrics']['average_duration']:.1f}s")

        print(f"\n๐Ÿ“‹ Classification Breakdown:")
        for classification, count in results["classification_breakdown"].items():
            print(f"   - {classification}: {count}")

        print(f"\n๐Ÿ“ˆ Status Breakdown:")
        for status, count in results["accuracy_metrics"].items():
            # Rates are floats; only the integer counters are statuses
            if isinstance(count, int):
                print(f"   - {status}: {count}")

        print(f"\nโœ… Async batch processing test completed successfully!")

    except Exception as e:
        print(f"โŒ Test failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())