Final_Assignment / async_complete_test_hf.py
tonthatthienvu's picture
Clean repository without binary files
37cadfb
raw
history blame
14.1 kB
#!/usr/bin/env python3
"""
HF Space Async Complete GAIA Test System
Adapted version for Hugging Face Spaces with comprehensive testing capabilities.
"""
import asyncio
import json
import logging
import time
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import sys
# Import core components (adapted for HF Space)
from main import GAIASolver
from gaia_web_loader import GAIAQuestionLoaderWeb
from question_classifier import QuestionClassifier
class HFAsyncGAIATestSystem:
    """Async GAIA test system adapted for Hugging Face Spaces.

    Loads GAIA questions, runs the solver over them with bounded concurrency
    and a per-question timeout, summarises the outcome and writes the summary
    as JSON into a timestamped session directory.
    """

    def __init__(self,
                 max_concurrent: int = 2,  # Lower for HF Spaces
                 timeout_seconds: int = 600,  # 10 minutes for HF
                 output_dir: str = "/tmp/async_test_results"):
        """
        Initialize the HF async test system.

        Args:
            max_concurrent: Maximum concurrent processors (2 for HF Spaces)
            timeout_seconds: Timeout per question (10 minutes for HF)
            output_dir: Directory for test results (use /tmp for HF)
        """
        self.max_concurrent = max_concurrent
        self.timeout_seconds = timeout_seconds
        self.output_dir = Path(output_dir)
        # parents=True so a nested output_dir that does not exist yet works.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Create timestamped session directory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.session_dir = self.output_dir / f"hf_session_{timestamp}"
        self.session_dir.mkdir(exist_ok=True)

        # Initialize components
        self.solver = GAIASolver()
        self.classifier = QuestionClassifier()
        self.loader = GAIAQuestionLoaderWeb()

        # Setup logging
        self.setup_logging()

        # Test results tracking
        self.results: Dict[str, Dict] = {}
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self.progress_callback = None

    def setup_logging(self):
        """Attach a session log file handler and a console handler to the logger.

        The logger is looked up by a fixed name, so handlers left over from an
        earlier instance are detached first; otherwise every message would be
        emitted once per instance ever created.
        """
        log_file = self.session_dir / "hf_async_test.log"

        # Configure logger
        self.logger = logging.getLogger("HFAsyncGAIATest")
        self.logger.setLevel(logging.INFO)

        # Clear existing handlers, closing them so the previous session's log
        # file handle is released instead of leaking.
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
            handler.close()

        # File handler
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(logging.INFO)

        # Console handler for HF logs
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)

        # Formatter
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s'
        )
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(formatter)

        # Add handlers
        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)

    def set_progress_callback(self, callback):
        """Set progress callback for Gradio interface.

        The callback is invoked as ``callback(fraction, message)``.
        """
        self.progress_callback = callback

    def update_progress(self, message: str, current: int, total: int):
        """Report progress to the callback (if any) and to the log.

        Args:
            message: Human-readable status text.
            current: Number of items finished so far.
            total: Number of items overall; a non-positive total reports 0.
        """
        if self.progress_callback:
            progress = current / total if total > 0 else 0
            self.progress_callback(progress, message)
        self.logger.info(f"Progress: {message} ({current}/{total})")

    @staticmethod
    def _read_questions_file(path: Path, limit: int) -> List[Dict]:
        """Read up to ``limit`` questions from a JSON-lines file.

        Lines that do not start with ``{`` or that fail to parse as JSON are
        skipped.
        """
        questions: List[Dict] = []
        if limit <= 0:
            return questions
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line.startswith('{'):
                    continue
                try:
                    questions.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
                if len(questions) >= limit:
                    break
        return questions

    async def load_gaia_questions(self, limit: int = 20) -> List[Dict]:
        """Load GAIA questions (adapted for HF Space).

        Prefers the local ``gaia_questions_list.txt`` file in the working
        directory and falls back to the web loader when it is absent.

        Args:
            limit: Maximum number of questions to return.

        Returns:
            The loaded questions, or an empty list if loading failed.
        """
        try:
            # Try to load from local file first
            questions_file = Path("gaia_questions_list.txt")
            if questions_file.exists():
                self.logger.info("Loading questions from local file...")
                questions = self._read_questions_file(questions_file, limit)
                self.logger.info(f"Loaded {len(questions)} questions from file")
                return questions

            # Fallback to web loader
            self.logger.info("Loading questions from web...")
            questions = await self.loader.load_questions_async(limit=limit)
            self.logger.info(f"Loaded {len(questions)} questions from web")
            return questions
        except Exception as e:
            # Best effort: the caller treats an empty list as "nothing loaded".
            self.logger.error(f"Failed to load questions: {e}")
            return []

    async def process_single_question(self, question: Dict, semaphore: asyncio.Semaphore) -> Tuple[str, Dict]:
        """Process a single question with semaphore control.

        Classification and solving are run in the event loop's default
        executor. Only the solve step is bounded by ``self.timeout_seconds``.

        Args:
            question: Question record; reads ``task_id``, ``Question`` and,
                when present, ``Final Answer`` for validation.
            semaphore: Limits how many questions are processed at once.

        Returns:
            ``(question_id, result)`` where ``result['status']`` is one of
            ``'completed'``, ``'timeout'`` or ``'error'``. This method does
            not raise for solver/classifier failures; they are reported in
            the result dict.
        """
        async with semaphore:
            question_id = question.get('task_id', 'unknown')
            started = time.monotonic()
            loop = asyncio.get_running_loop()
            try:
                self.logger.info(f"Starting question {question_id}")

                # Classify question
                classification = await loop.run_in_executor(
                    None, self.classifier.classify_question, question.get('Question', '')
                )

                # Solve question with timeout
                try:
                    result = await asyncio.wait_for(
                        loop.run_in_executor(
                            None, self.solver.solve_question, question
                        ),
                        timeout=self.timeout_seconds
                    )
                except asyncio.TimeoutError:
                    # NOTE: the worker thread cannot be interrupted; it keeps
                    # running in the executor after we stop waiting for it.
                    duration = time.monotonic() - started
                    self.logger.warning(f"Question {question_id} timed out after {duration:.2f}s")
                    return question_id, {
                        'status': 'timeout',
                        'error': f'Timeout after {self.timeout_seconds}s',
                        'duration_seconds': duration,
                        'timestamp': datetime.now().isoformat()
                    }

                duration = time.monotonic() - started

                # Only a missing result becomes ""; falsy answers such as 0
                # are real answers and must survive the conversion.
                answer = "" if result is None else str(result)

                # Validate result if possible
                validation_status = "unknown"
                if 'Final Answer' in question:
                    expected = str(question['Final Answer']).strip().lower()
                    actual = answer.strip().lower()
                    validation_status = "correct" if expected == actual else "incorrect"

                # A non-dict classification must not turn a solved question
                # into an error.
                if isinstance(classification, dict):
                    agent = classification.get('primary_agent', 'unknown')
                else:
                    agent = 'unknown'

                return question_id, {
                    'status': 'completed',
                    'answer': answer,
                    'explanation': f"Solved via {agent} agent",
                    'classification': classification,
                    'validation_status': validation_status,
                    'expected_answer': question.get('Final Answer', ''),
                    'duration_seconds': duration,
                    'timestamp': datetime.now().isoformat()
                }
            except Exception as e:
                duration = time.monotonic() - started
                self.logger.error(f"Question {question_id} failed: {e}")
                return question_id, {
                    'status': 'error',
                    'error': str(e),
                    'duration_seconds': duration,
                    'timestamp': datetime.now().isoformat()
                }

    async def run_comprehensive_test(self, question_limit: int = 20) -> Dict:
        """Run comprehensive test on HF Space.

        Args:
            question_limit: Maximum number of questions to load and process.

        Returns:
            The summary from ``generate_test_summary`` on success, otherwise
            ``{"status": "error", "message": ...}``.
        """
        self.logger.info("=== HF ASYNC GAIA TEST STARTING ===")
        self.start_time = time.time()
        try:
            # Load questions
            self.update_progress("Loading GAIA questions...", 0, question_limit)
            questions = await self.load_gaia_questions(limit=question_limit)
            if not questions:
                return {"status": "error", "message": "No questions loaded"}

            actual_count = len(questions)
            self.logger.info(f"Processing {actual_count} questions")

            # Create semaphore for concurrency control. Clamp to at least 1:
            # a zero-valued semaphore would never admit any question and the
            # run would hang forever.
            semaphore = asyncio.Semaphore(max(1, self.max_concurrent))

            tasks = [
                self.process_single_question(question, semaphore)
                for question in questions
            ]

            # Process with progress updates
            completed = 0
            results = {}
            for coro in asyncio.as_completed(tasks):
                question_id, result = await coro
                # NOTE: results are keyed by task_id, so questions sharing an
                # id (including the 'unknown' fallback) overwrite each other.
                results[question_id] = result
                completed += 1
                status = result.get('status', 'unknown')
                self.update_progress(
                    f"Completed {completed}/{actual_count} questions (last: {status})",
                    completed,
                    actual_count
                )

            self.results = results
            self.end_time = time.time()
            total_duration = self.end_time - self.start_time

            # Generate summary
            summary = self.generate_test_summary(total_duration)

            # Save results
            await self.save_results(summary)
            self.update_progress("Test completed!", actual_count, actual_count)
            return summary
        except Exception as e:
            # Top-level boundary: log with traceback and report the failure.
            self.logger.exception(f"Test failed: {e}")
            return {"status": "error", "message": str(e)}

    def generate_test_summary(self, duration: float) -> Dict:
        """Generate comprehensive test summary from ``self.results``.

        Args:
            duration: Wall-clock duration of the run in seconds.

        Returns:
            Dict with per-status, per-validation and per-agent counts,
            accuracy over the validated answers, throughput and the raw
            results. Results lacking a field are counted under ``'unknown'``.
        """
        total_questions = len(self.results)
        status_counts: Dict[str, int] = {}
        validation_counts: Dict[str, int] = {}
        classification_counts: Dict[str, int] = {}

        for result in self.results.values():
            # Status counts
            status = result.get('status', 'unknown')
            status_counts[status] = status_counts.get(status, 0) + 1

            # Validation counts
            validation = result.get('validation_status', 'unknown')
            validation_counts[validation] = validation_counts.get(validation, 0) + 1

            # Classification counts (tolerate a missing or non-dict payload)
            classification = result.get('classification')
            if isinstance(classification, dict):
                agent_type = classification.get('primary_agent', 'unknown')
            else:
                agent_type = 'unknown'
            classification_counts[agent_type] = classification_counts.get(agent_type, 0) + 1

        # Calculate accuracy over the questions that had an expected answer
        correct_count = validation_counts.get('correct', 0)
        total_with_answers = correct_count + validation_counts.get('incorrect', 0)
        accuracy = (correct_count / total_with_answers * 100) if total_with_answers > 0 else 0

        # A zero (or negative) duration would otherwise divide by zero.
        if duration > 0:
            questions_per_minute = round(total_questions / (duration / 60), 2)
        else:
            questions_per_minute = 0.0

        return {
            "session_id": self.session_dir.name,
            "timestamp": datetime.now().isoformat(),
            "duration_seconds": duration,
            "total_questions": total_questions,
            "status_counts": status_counts,
            "validation_counts": validation_counts,
            "classification_counts": classification_counts,
            "accuracy_percent": round(accuracy, 1),
            "questions_per_minute": questions_per_minute,
            "results": self.results
        }

    async def save_results(self, summary: Dict):
        """Save test results to files.

        Writes ``hf_test_summary.json`` and ``individual_results.json`` into
        the session directory. Failures are logged, not raised, so a storage
        problem does not discard the in-memory summary.

        Args:
            summary: Summary dict as produced by ``generate_test_summary``.
        """
        try:
            # default=str is deliberate: result payloads come from the solver
            # and classifier and may hold values that are not JSON-native;
            # stringifying them beats losing the whole results file.
            # Save main summary
            summary_file = self.session_dir / "hf_test_summary.json"
            with open(summary_file, 'w', encoding='utf-8') as f:
                json.dump(summary, f, indent=2, default=str)

            # Save individual results
            results_file = self.session_dir / "individual_results.json"
            with open(results_file, 'w', encoding='utf-8') as f:
                json.dump(self.results, f, indent=2, default=str)

            self.logger.info(f"Results saved to {self.session_dir}")
        except Exception as e:
            self.logger.error(f"Failed to save results: {e}")
async def run_hf_comprehensive_test(
    question_limit: int = 20,
    max_concurrent: int = 2,
    progress_callback=None,
    timeout_seconds: int = 600,
) -> Dict:
    """
    Run comprehensive GAIA test for HF Space.

    Builds a fresh HFAsyncGAIATestSystem, wires in the optional progress
    callback and runs the test to completion.

    Args:
        question_limit: Number of questions to test
        max_concurrent: Maximum concurrent processors
        progress_callback: Gradio progress callback
        timeout_seconds: Timeout per question in seconds (default 10 minutes)

    Returns:
        Test summary dictionary
    """
    system = HFAsyncGAIATestSystem(
        max_concurrent=max_concurrent,
        timeout_seconds=timeout_seconds,
    )
    if progress_callback is not None:
        system.set_progress_callback(progress_callback)
    return await system.run_comprehensive_test(question_limit)
if __name__ == "__main__":
    # Manual smoke run: process a handful of questions and dump the summary.
    async def main():
        summary = await run_hf_comprehensive_test(question_limit=5)
        print(json.dumps(summary, indent=2))

    asyncio.run(main())