#!/usr/bin/env python3
"""
HF Space Async Complete GAIA Test System
Adapted version for Hugging Face Spaces with comprehensive testing capabilities.
"""
import asyncio | |
import json | |
import logging | |
import time | |
import os | |
from datetime import datetime | |
from pathlib import Path | |
from typing import Dict, List, Optional, Tuple | |
import sys | |
# Import core components (adapted for HF Space) | |
from main import GAIASolver | |
from gaia_web_loader import GAIAQuestionLoaderWeb | |
from question_classifier import QuestionClassifier | |
# Import advanced testing infrastructure from source | |
# Optional dependency guard: the advanced testing modules may be absent from
# the Space. ADVANCED_TESTING records whether they imported so the rest of the
# module can choose between the advanced and the basic code path.
try:
    from async_complete_test import AsyncGAIATestSystem
    from async_question_processor import AsyncQuestionProcessor
    from classification_analyzer import ClassificationAnalyzer
    from summary_report_generator import SummaryReportGenerator
    ADVANCED_TESTING = True
except ImportError as e:
    # Best effort: report the missing component and continue in basic mode.
    print(f"⚠️ Advanced testing components not available: {e}")
    ADVANCED_TESTING = False
class HFAsyncGAIATestSystem:
    """Async GAIA test system adapted for Hugging Face Spaces.

    When the optional advanced testing modules imported successfully
    (``ADVANCED_TESTING``), test runs are delegated to ``AsyncGAIATestSystem``.
    Otherwise (or when the advanced run raises) a basic pipeline is used:
    load questions, classify and solve each one in a worker thread under a
    concurrency limit, then summarise and save the results as JSON.
    """

    def __init__(self,
                 max_concurrent: int = 2,  # Lower for HF Spaces
                 timeout_seconds: int = 600,  # 10 minutes for HF
                 output_dir: str = "/tmp/async_test_results"):
        """
        Initialize the HF async test system.

        Args:
            max_concurrent: Maximum concurrent processors (2 for HF Spaces)
            timeout_seconds: Timeout per question (10 minutes for HF)
            output_dir: Directory for test results (use /tmp for HF)
        """
        self.max_concurrent = max_concurrent
        self.timeout_seconds = timeout_seconds
        self.output_dir = Path(output_dir)
        # parents=True so a nested, not-yet-existing output path also works.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Create timestamped session directory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.session_dir = self.output_dir / f"hf_session_{timestamp}"
        self.session_dir.mkdir(exist_ok=True)

        # Basic components start unset; they are built eagerly in basic mode
        # and lazily if an advanced run has to fall back to the basic test.
        self.solver = None
        self.classifier = None
        self.loader = None

        # Initialize components based on available testing infrastructure
        if ADVANCED_TESTING:
            # Use advanced testing system for full functionality
            self.advanced_system = AsyncGAIATestSystem(
                max_concurrent=max_concurrent,
                timeout_seconds=timeout_seconds,
                output_dir=str(output_dir)
            )
            print("✅ Using advanced testing infrastructure with honest accuracy measurement")
        else:
            # Fallback to basic components
            self.advanced_system = None
            self._ensure_basic_components()
            print("⚠️ Using basic testing infrastructure (some features may be limited)")

        # Setup logging
        self.setup_logging()

        # Test results tracking
        self.results: Dict[str, Dict] = {}
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self.progress_callback = None

    def _ensure_basic_components(self) -> None:
        """Create the basic solver, classifier and loader if they are unset.

        Needed because in advanced mode these attributes are left as ``None``,
        yet the basic test is still used as a fallback when the advanced run
        raises.
        """
        if self.solver is None:
            self.solver = GAIASolver()
        if self.classifier is None:
            self.classifier = QuestionClassifier()
        if self.loader is None:
            self.loader = GAIAQuestionLoaderWeb()

    def setup_logging(self):
        """Setup logging for HF Space environment.

        Attaches a file handler (``hf_async_test.log`` in the session
        directory) and a console handler to the ``HFAsyncGAIATest`` logger,
        replacing any handlers left over from an earlier instance.
        """
        log_file = self.session_dir / "hf_async_test.log"

        # Configure logger
        self.logger = logging.getLogger("HFAsyncGAIATest")
        self.logger.setLevel(logging.INFO)

        # Clear existing handlers; close them so that log files opened by a
        # previous session are released instead of leaking.
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
            handler.close()

        # File handler
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(logging.INFO)

        # Console handler for HF logs
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)

        # Formatter
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s'
        )
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(formatter)

        # Add handlers
        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)

    def set_progress_callback(self, callback):
        """Set progress callback for Gradio interface.

        The callback is invoked as ``callback(fraction, message)``.
        """
        self.progress_callback = callback

    def update_progress(self, message: str, current: int, total: int):
        """Update progress for Gradio interface.

        Calls the registered callback (if any) with ``current / total``
        (0 when ``total`` is not positive) and always logs the message.
        """
        if self.progress_callback:
            progress = current / total if total > 0 else 0
            self.progress_callback(progress, message)
        self.logger.info(f"Progress: {message} ({current}/{total})")

    async def load_gaia_questions(self, limit: int = 20) -> List[Dict]:
        """Load GAIA questions (adapted for HF Space).

        Reads up to ``limit`` JSON objects, one per line, from
        ``gaia_questions_list.txt`` in the working directory when that file
        exists; otherwise asks the web loader. Any failure is logged and
        reported as an empty list.
        """
        try:
            # Try to load from local file first
            questions_file = Path("gaia_questions_list.txt")
            if questions_file.exists():
                self.logger.info("Loading questions from local file...")
                questions = []
                with open(questions_file, 'r', encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        # Only lines that look like a JSON object are parsed.
                        if not line.startswith('{'):
                            continue
                        try:
                            questions.append(json.loads(line))
                        except json.JSONDecodeError:
                            # Skip malformed lines rather than abort the load.
                            continue
                        if len(questions) >= limit:
                            break
                self.logger.info(f"Loaded {len(questions)} questions from file")
                return questions[:limit]

            # Fallback to web loader
            self.logger.info("Loading questions from web...")
            questions = await self.loader.load_questions_async(limit=limit)
            self.logger.info(f"Loaded {len(questions)} questions from web")
            return questions
        except Exception as e:
            self.logger.error(f"Failed to load questions: {e}")
            return []

    async def process_single_question(self, question: Dict, semaphore: asyncio.Semaphore) -> Tuple[str, Dict]:
        """Process a single question with semaphore control.

        Classification and solving both run in the default executor so the
        event loop stays responsive.

        Returns:
            ``(task_id, result)`` where ``result['status']`` is one of
            ``'completed'``, ``'timeout'`` or ``'error'``.
        """
        async with semaphore:
            question_id = question.get('task_id', 'unknown')
            start_time = time.time()
            # We are inside a coroutine, so a running loop is guaranteed.
            loop = asyncio.get_running_loop()
            try:
                self.logger.info(f"Starting question {question_id}")

                # Classify question
                classification = await loop.run_in_executor(
                    None, self.classifier.classify_question, question.get('Question', '')
                )

                # Solve question with timeout
                try:
                    # NOTE: on timeout only the await is abandoned; the worker
                    # thread keeps running until the solver returns.
                    result = await asyncio.wait_for(
                        loop.run_in_executor(
                            None, self.solver.solve_question, question
                        ),
                        timeout=self.timeout_seconds
                    )
                    duration = time.time() - start_time

                    # Only a missing result maps to an empty answer; falsy but
                    # real answers such as 0 or False must be preserved.
                    answer = "" if result is None else str(result)

                    # Validate result if possible
                    validation_status = "unknown"
                    if 'Final Answer' in question:
                        expected = str(question['Final Answer']).strip().lower()
                        actual = answer.strip().lower()
                        validation_status = "correct" if expected == actual else "incorrect"

                    return question_id, {
                        'status': 'completed',
                        'answer': answer,
                        'explanation': f"Solved via {classification.get('primary_agent', 'unknown')} agent",
                        'classification': classification,
                        'validation_status': validation_status,
                        'expected_answer': question.get('Final Answer', ''),
                        'duration_seconds': duration,
                        'timestamp': datetime.now().isoformat()
                    }
                except asyncio.TimeoutError:
                    duration = time.time() - start_time
                    self.logger.warning(f"Question {question_id} timed out after {duration:.2f}s")
                    return question_id, {
                        'status': 'timeout',
                        'error': f'Timeout after {self.timeout_seconds}s',
                        'duration_seconds': duration,
                        'timestamp': datetime.now().isoformat()
                    }
            except Exception as e:
                duration = time.time() - start_time
                self.logger.error(f"Question {question_id} failed: {e}")
                return question_id, {
                    'status': 'error',
                    'error': str(e),
                    'duration_seconds': duration,
                    'timestamp': datetime.now().isoformat()
                }

    async def run_comprehensive_test(self, question_limit: int = 20) -> Dict:
        """Run comprehensive test on HF Space with advanced features when available."""
        self.logger.info("=== HF ASYNC GAIA TEST STARTING ===")
        self.start_time = time.time()

        # Use advanced system if available for full functionality
        if ADVANCED_TESTING and self.advanced_system:
            self.update_progress("Using advanced testing system with honest accuracy measurement...", 0, question_limit)
            return await self._run_advanced_test(question_limit)

        # Fallback to basic testing
        self.update_progress("Using basic testing system...", 0, question_limit)
        return await self._run_basic_test(question_limit)

    async def _run_advanced_test(self, question_limit: int) -> Dict:
        """Run test using the advanced testing system.

        Any exception from the advanced system is logged and the basic test
        is run instead.
        """
        try:
            # Use the advanced system directly
            return await self.advanced_system.run_complete_test_async(max_questions=question_limit)
        except Exception as e:
            self.logger.error(f"Advanced test failed: {e}")
            self.update_progress(f"Advanced test failed, falling back to basic test: {e}", 0, question_limit)
            return await self._run_basic_test(question_limit)

    async def _run_basic_test(self, question_limit: int) -> Dict:
        """Run basic test for fallback.

        Returns the summary dictionary, or
        ``{"status": "error", "message": ...}`` when no questions could be
        loaded or an unexpected exception occurred.
        """
        try:
            # In advanced mode the basic components were never created; build
            # them now so the fallback can actually load and solve questions.
            self._ensure_basic_components()
            if self.start_time is None:
                # Reached without going through run_comprehensive_test().
                self.start_time = time.time()

            # Load questions
            self.update_progress("Loading GAIA questions...", 0, question_limit)
            questions = await self.load_gaia_questions(limit=question_limit)
            if not questions:
                return {"status": "error", "message": "No questions loaded"}

            actual_count = len(questions)
            self.logger.info(f"Processing {actual_count} questions")

            # Create semaphore for concurrency control
            semaphore = asyncio.Semaphore(self.max_concurrent)
            tasks = [
                self.process_single_question(question, semaphore)
                for question in questions
            ]

            # Process with progress updates, in completion order
            completed = 0
            results = {}
            for coro in asyncio.as_completed(tasks):
                question_id, result = await coro
                results[question_id] = result
                completed += 1
                status = result.get('status', 'unknown')
                self.update_progress(
                    f"Completed {completed}/{actual_count} questions (last: {status})",
                    completed,
                    actual_count
                )

            self.results = results
            self.end_time = time.time()
            total_duration = self.end_time - self.start_time

            # Generate summary
            summary = self.generate_test_summary(total_duration)

            # Save results
            await self.save_results(summary)
            self.update_progress("Test completed!", actual_count, actual_count)
            return summary
        except Exception as e:
            self.logger.error(f"Test failed: {e}")
            return {"status": "error", "message": str(e)}

    def generate_test_summary(self, duration: float) -> Dict:
        """Generate comprehensive test summary.

        Args:
            duration: Wall-clock duration of the run in seconds.

        Returns:
            Dictionary with per-status, per-validation and per-agent counts,
            accuracy over the questions that had an expected answer, the
            throughput in questions per minute, and the raw results.
        """
        total_questions = len(self.results)
        status_counts = {}
        validation_counts = {}
        classification_counts = {}

        for result in self.results.values():
            # Status counts
            status = result.get('status', 'unknown')
            status_counts[status] = status_counts.get(status, 0) + 1

            # Validation counts
            validation = result.get('validation_status', 'unknown')
            validation_counts[validation] = validation_counts.get(validation, 0) + 1

            # Classification counts
            classification = result.get('classification', {})
            agent_type = classification.get('primary_agent', 'unknown')
            classification_counts[agent_type] = classification_counts.get(agent_type, 0) + 1

        # Calculate accuracy over questions that could be validated
        correct_count = validation_counts.get('correct', 0)
        total_with_answers = correct_count + validation_counts.get('incorrect', 0)
        accuracy = (correct_count / total_with_answers * 100) if total_with_answers > 0 else 0

        # A zero (or negative) duration would otherwise divide by zero.
        if duration > 0:
            questions_per_minute = round(total_questions / (duration / 60), 2)
        else:
            questions_per_minute = 0.0

        return {
            "session_id": self.session_dir.name,
            "timestamp": datetime.now().isoformat(),
            "duration_seconds": duration,
            "total_questions": total_questions,
            "status_counts": status_counts,
            "validation_counts": validation_counts,
            "classification_counts": classification_counts,
            "accuracy_percent": round(accuracy, 1),
            "questions_per_minute": questions_per_minute,
            "results": self.results
        }

    async def save_results(self, summary: Dict):
        """Save test results to files.

        Writes ``hf_test_summary.json`` and ``individual_results.json`` into
        the session directory. Failures are logged, not raised.
        """
        try:
            # Save main summary
            summary_file = self.session_dir / "hf_test_summary.json"
            with open(summary_file, 'w', encoding="utf-8") as f:
                json.dump(summary, f, indent=2)

            # Save individual results
            results_file = self.session_dir / "individual_results.json"
            with open(results_file, 'w', encoding="utf-8") as f:
                json.dump(self.results, f, indent=2)

            self.logger.info(f"Results saved to {self.session_dir}")
        except Exception as e:
            self.logger.error(f"Failed to save results: {e}")
async def run_hf_comprehensive_test(
    question_limit: int = 20,
    max_concurrent: int = 2,
    progress_callback=None,
    timeout_seconds: int = 600
) -> Dict:
    """
    Run comprehensive GAIA test for HF Space.

    Args:
        question_limit: Number of questions to test
        max_concurrent: Maximum concurrent processors
        progress_callback: Gradio progress callback, called as
            ``callback(fraction, message)``; ``None`` disables reporting
        timeout_seconds: Timeout per question in seconds (default 600,
            i.e. 10 minutes, the value that used to be hard-coded)

    Returns:
        Test summary dictionary
    """
    system = HFAsyncGAIATestSystem(
        max_concurrent=max_concurrent,
        timeout_seconds=timeout_seconds
    )
    # Compare against None so that a callable with falsy truthiness is
    # still registered.
    if progress_callback is not None:
        system.set_progress_callback(progress_callback)
    return await system.run_comprehensive_test(question_limit)
if __name__ == "__main__":
    # Manual smoke test: run five questions and dump the summary as JSON.
    async def main():
        summary = await run_hf_comprehensive_test(question_limit=5)
        print(json.dumps(summary, indent=2))

    asyncio.run(main())