tonthatthienvu's picture
feat: major refactoring - transform monolithic architecture into modular system
ba68fc1
#!/usr/bin/env python3
"""
Main GAIA solver with refactored architecture.
Coordinates question classification, tool execution, and answer extraction.
"""
from typing import Dict, Any, Optional
from dataclasses import dataclass
from ..config.settings import Config, config
from ..models.manager import ModelManager
from ..utils.exceptions import GAIAError, ModelError, ClassificationError
from .answer_extractor import AnswerExtractor
from .question_processor import QuestionProcessor
@dataclass
class SolverResult:
    """Result from solving a single GAIA question.

    Attributes:
        answer: The extracted final answer, or an error message if solving failed.
        confidence: Confidence score attached to the answer.
        method_used: Label describing how the answer was produced.
        execution_time: Elapsed time in seconds, if it was measured.
        metadata: Free-form extra details; always a dict after construction.
    """
    answer: str
    confidence: float
    method_used: str
    execution_time: Optional[float] = None
    # Annotated Optional because None is the default sentinel; __post_init__
    # swaps it for a fresh dict so instances never share one mutable default.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        """Replace a missing ``metadata`` with a new, per-instance empty dict."""
        if self.metadata is None:
            self.metadata = {}
class GAIASolver:
    """Main GAIA solver built from modular components.

    Wires together a ``ModelManager`` (model providers), a ``QuestionProcessor``
    (turns a question record into a raw response) and an ``AnswerExtractor``
    (pulls the final answer out of that raw response).
    """

    def __init__(self, config_instance: Optional[Config] = None):
        """Create the solver and eagerly initialize the model providers.

        Args:
            config_instance: Configuration to use. When ``None``, the
                module-level ``config`` is used.

        Raises:
            ModelError: If model provider initialization fails or no provider
                comes up.
        """
        self.config = config_instance if config_instance is not None else config

        # All components share the same config; the question processor is
        # additionally handed the model manager.
        self.model_manager = ModelManager(self.config)
        self.answer_extractor = AnswerExtractor()
        self.question_processor = QuestionProcessor(self.model_manager, self.config)

        self._initialize_models()

        print("βœ… GAIA Solver ready with refactored architecture!")

    def _initialize_models(self) -> None:
        """Initialize all model providers and print a per-provider status line.

        Raises:
            ModelError: If initializing/reporting raises (the original error is
                chained as ``__cause__``), or if no provider initialized
                successfully.
        """
        try:
            results = self.model_manager.initialize_all()

            # Each value is treated as a truthy/falsy success flag.
            success_count = sum(1 for success in results.values() if success)
            total_count = len(results)
            print(f"πŸ€– Initialized {success_count}/{total_count} model providers")

            for name, success in results.items():
                status = "βœ…" if success else "❌"
                print(f" {status} {name}")
        except Exception as e:
            raise ModelError(f"Model initialization failed: {e}") from e

        # Checked outside the try so this error is raised as-is instead of
        # being re-wrapped as "Model initialization failed: ...".
        if success_count == 0:
            raise ModelError("No model providers successfully initialized")

    def solve_question(self, question_data: Dict[str, Any]) -> SolverResult:
        """Solve a single GAIA question.

        Args:
            question_data: Question record. ``"question"`` holds the text and
                ``"task_id"`` (optional) identifies it; the whole record is
                passed to the question processor.

        Returns:
            A ``SolverResult``. This method does not raise: any exception is
            caught and reported as a result with ``method_used="error_fallback"``,
            ``confidence=0.0`` and the error text in ``metadata["error"]``.
        """
        import time

        # perf_counter is monotonic, so the measured duration is unaffected by
        # system clock adjustments (unlike time.time()).
        start_time = time.perf_counter()

        try:
            task_id = question_data.get("task_id", "unknown")
            # `or ""` also covers an explicit None, which would otherwise fail
            # on .strip() with an unhelpful AttributeError message.
            question_text = question_data.get("question") or ""

            if not question_text.strip():
                raise GAIAError("Empty question provided")

            print(f"\n🧩 Solving question {task_id}")
            # Only append an ellipsis when the preview was actually truncated.
            preview = question_text[:100]
            if len(question_text) > 100:
                preview += "..."
            print(f"πŸ“ Question: {preview}")

            raw_response = self.question_processor.process_question(question_data)

            final_answer = self.answer_extractor.extract_final_answer(
                raw_response, question_text
            )

            execution_time = time.perf_counter() - start_time

            return SolverResult(
                answer=final_answer,
                confidence=0.8,  # Fixed placeholder; no real confidence scoring yet.
                method_used="refactored_architecture",
                execution_time=execution_time,
                metadata={
                    "task_id": task_id,
                    "question_length": len(question_text),
                    "response_length": len(raw_response),
                },
            )

        except Exception as e:
            # Deliberate catch-all boundary: one failing question must not
            # abort callers such as solve_multiple_questions.
            execution_time = time.perf_counter() - start_time
            error_msg = f"Error solving question: {str(e)}"
            print(f"❌ {error_msg}")
            return SolverResult(
                answer=error_msg,
                confidence=0.0,
                method_used="error_fallback",
                execution_time=execution_time,
                metadata={"error": str(e)},
            )

    def solve_random_question(self) -> Optional[SolverResult]:
        """Fetch a random question from the question processor and solve it.

        Returns:
            The ``SolverResult``, or ``None`` if no question is available or
            fetching one raised.
        """
        try:
            question = self.question_processor.get_random_question()
            if not question:
                print("❌ No questions available!")
                return None
            return self.solve_question(question)
        except Exception as e:
            print(f"❌ Error getting random question: {e}")
            return None

    def solve_multiple_questions(self, max_questions: int = 5) -> list[SolverResult]:
        """Solve up to ``max_questions`` questions in sequence (for testing).

        Args:
            max_questions: Limit passed to ``QuestionProcessor.get_questions``.

        Returns:
            One ``SolverResult`` per question solved. If fetching or iterating
            the questions raises, the error is printed and the results gathered
            so far (possibly none) are returned.
        """
        print(f"\n🎯 Solving up to {max_questions} questions...")
        results = []
        try:
            questions = self.question_processor.get_questions(max_questions)
            total = len(questions)
            for number, question in enumerate(questions, start=1):
                print(f"\n--- Question {number}/{total} ---")
                results.append(self.solve_question(question))
        except Exception as e:
            print(f"❌ Error in batch processing: {e}")
        return results

    def get_system_status(self) -> Dict[str, Any]:
        """Return a snapshot of model, config and component status."""
        return {
            "models": self.model_manager.get_model_status(),
            "available_providers": self.model_manager.get_available_providers(),
            "current_provider": self.model_manager.current_provider,
            "config": {
                "debug_mode": self.config.debug_mode,
                "log_level": self.config.log_level,
                "available_models": [model.value for model in self.config.get_available_models()],
            },
            "components": {
                "model_manager": "initialized",
                "answer_extractor": "initialized",
                "question_processor": "initialized",
            },
        }

    def switch_model(self, provider_name: str) -> bool:
        """Switch the model manager to ``provider_name``.

        Returns:
            The manager's success flag, or ``False`` if switching raised.
        """
        try:
            success = self.model_manager.switch_to_provider(provider_name)
            if success:
                print(f"βœ… Switched to model provider: {provider_name}")
            else:
                print(f"❌ Failed to switch to provider: {provider_name}")
            return success
        except Exception as e:
            print(f"❌ Error switching model: {e}")
            return False

    def reset_models(self) -> None:
        """Reset all model providers; errors are printed rather than raised."""
        try:
            self.model_manager.reset_all_providers()
            print("βœ… Reset all model providers")
        except Exception as e:
            print(f"❌ Error resetting models: {e}")
# Legacy module-level entry point, kept for backward compatibility.
def extract_final_answer(raw_answer: str, question_text: str) -> str:
    """Extract the final answer from a raw response (legacy API).

    Delegates to a freshly constructed ``AnswerExtractor`` so code written
    against the old module-level function keeps working.

    Args:
        raw_answer: Raw response text to extract the answer from.
        question_text: The original question, forwarded to the extractor.

    Returns:
        Whatever ``AnswerExtractor.extract_final_answer`` returns.
    """
    return AnswerExtractor().extract_final_answer(raw_answer, question_text)