File size: 7,263 Bytes
ba68fc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
#!/usr/bin/env python3
"""
Main GAIA solver with refactored architecture.
Coordinates question classification, tool execution, and answer extraction.
"""

from dataclasses import dataclass, field
from typing import Dict, Any, Optional

from ..config.settings import Config, config
from ..models.manager import ModelManager
from ..utils.exceptions import GAIAError, ModelError, ClassificationError
from .answer_extractor import AnswerExtractor
from .question_processor import QuestionProcessor


@dataclass
class SolverResult:
    """Result from solving a question.

    Attributes:
        answer: Final answer text (the solver stores an error message here
            when solving fails).
        confidence: Confidence score attached by the solver.
        method_used: Label of the strategy that produced the answer.
        execution_time: Elapsed time in seconds, if it was measured.
        metadata: Free-form extra details; never ``None`` after construction.
    """

    answer: str
    confidence: float
    method_used: str
    execution_time: Optional[float] = None
    # default_factory gives every instance its own fresh dict instead of a
    # ``None`` placeholder that contradicts the ``Dict`` annotation.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Callers may still pass ``metadata=None`` explicitly; normalize it so
        # consumers can always treat ``metadata`` as a dict.
        if self.metadata is None:
            self.metadata = {}


class GAIASolver:
    """Main GAIA solver using the refactored architecture.

    Wires together a ``ModelManager``, an ``AnswerExtractor`` and a
    ``QuestionProcessor``. Each question is turned into a raw response by the
    question processor, and the final answer is then extracted from that
    response by the answer extractor.
    """

    def __init__(self, config_instance: Optional[Config] = None):
        """Build all components and initialize the model providers.

        Args:
            config_instance: Configuration to use. Falls back to the
                module-level ``config`` when omitted (or falsy).

        Raises:
            ModelError: If provider initialization fails or no provider
                initializes successfully.
        """
        self.config = config_instance or config

        # All components share the same config; the question processor is
        # handed the model manager (presumably to query the models).
        self.model_manager = ModelManager(self.config)
        self.answer_extractor = AnswerExtractor()
        self.question_processor = QuestionProcessor(self.model_manager, self.config)

        self._initialize_models()

        print("✅ GAIA Solver ready with refactored architecture!")

    def _initialize_models(self) -> None:
        """Initialize every model provider and print a per-provider report.

        Raises:
            ModelError: If initialization (or reading its results) fails, or
                if no provider initialized successfully.
        """
        try:
            # Mapping of provider name -> success flag.
            results = self.model_manager.initialize_all()
            success_count = sum(1 for success in results.values() if success)
            total_count = len(results)

            print(f"🤖 Initialized {success_count}/{total_count} model providers")
            for name, success in results.items():
                status = "✅" if success else "❌"
                print(f"  {status} {name}")
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ModelError(f"Model initialization failed: {e}") from e

        # Checked outside the ``try`` so this error is not caught and
        # re-wrapped as "Model initialization failed: ..." by the handler above.
        if success_count == 0:
            raise ModelError("No model providers successfully initialized")

    def solve_question(self, question_data: Dict[str, Any]) -> SolverResult:
        """Solve a single GAIA question.

        Args:
            question_data: Question record. ``"question"`` holds the text and
                the optional ``"task_id"`` identifies it (defaults to
                ``"unknown"``).

        Returns:
            A ``SolverResult``. Errors are never raised to the caller: on any
            failure the result has ``confidence=0.0``,
            ``method_used="error_fallback"`` and the error text in ``answer``
            and ``metadata["error"]``.
        """
        import time

        # perf_counter is monotonic, so the measured duration is not skewed
        # by wall-clock adjustments the way time.time() can be.
        start_time = time.perf_counter()

        try:
            task_id = question_data.get("task_id", "unknown")
            # ``or ""`` also maps an explicit ``None`` to the empty string so it
            # is reported as an empty question rather than an AttributeError.
            question_text = question_data.get("question") or ""

            if not question_text.strip():
                raise GAIAError("Empty question provided")

            # Only mark the preview as truncated when it actually is.
            if len(question_text) > 100:
                preview = f"{question_text[:100]}..."
            else:
                preview = question_text
            print(f"\n🧩 Solving question {task_id}")
            print(f"📝 Question: {preview}")

            raw_response = self.question_processor.process_question(question_data)
            final_answer = self.answer_extractor.extract_final_answer(
                raw_response, question_text
            )

            execution_time = time.perf_counter() - start_time

            return SolverResult(
                answer=final_answer,
                confidence=0.8,  # fixed placeholder; no real confidence scoring yet
                method_used="refactored_architecture",
                execution_time=execution_time,
                metadata={
                    "task_id": task_id,
                    "question_length": len(question_text),
                    "response_length": len(raw_response),
                },
            )

        except Exception as e:
            # Best-effort contract: convert any failure into an error result.
            execution_time = time.perf_counter() - start_time
            error_msg = f"Error solving question: {str(e)}"
            print(f"❌ {error_msg}")

            return SolverResult(
                answer=error_msg,
                confidence=0.0,
                method_used="error_fallback",
                execution_time=execution_time,
                metadata={"error": str(e)},
            )

    def solve_random_question(self) -> Optional[SolverResult]:
        """Solve one random question from the question processor.

        Returns:
            The ``SolverResult``, or ``None`` when no question is available or
            fetching it fails.
        """
        try:
            question = self.question_processor.get_random_question()
            if not question:
                print("❌ No questions available!")
                return None

            return self.solve_question(question)

        except Exception as e:
            print(f"❌ Error getting random question: {e}")
            return None

    def solve_multiple_questions(self, max_questions: int = 5) -> list[SolverResult]:
        """Solve up to ``max_questions`` questions in sequence (for testing).

        Args:
            max_questions: Upper bound passed to the question processor.

        Returns:
            The results gathered so far. If fetching or iterating the questions
            fails, the error is printed and the partial list is returned.
        """
        print(f"\n🎯 Solving up to {max_questions} questions...")
        results = []

        try:
            questions = self.question_processor.get_questions(max_questions)
            total = len(questions)

            for number, question in enumerate(questions, start=1):
                print(f"\n--- Question {number}/{total} ---")
                results.append(self.solve_question(question))

        except Exception as e:
            print(f"❌ Error in batch processing: {e}")

        return results

    def get_system_status(self) -> Dict[str, Any]:
        """Return a snapshot of model, config and component status.

        Returns:
            Dict with ``"models"``, ``"available_providers"``,
            ``"current_provider"``, ``"config"`` and ``"components"`` keys.
        """
        return {
            "models": self.model_manager.get_model_status(),
            "available_providers": self.model_manager.get_available_providers(),
            "current_provider": self.model_manager.current_provider,
            "config": {
                "debug_mode": self.config.debug_mode,
                "log_level": self.config.log_level,
                "available_models": [model.value for model in self.config.get_available_models()],
            },
            # Reaching this method implies __init__ built every component.
            "components": {
                "model_manager": "initialized",
                "answer_extractor": "initialized",
                "question_processor": "initialized",
            },
        }

    def switch_model(self, provider_name: str) -> bool:
        """Switch the active model provider.

        Args:
            provider_name: Name of the provider to activate.

        Returns:
            The model manager's success value, or ``False`` if switching
            raised an exception (the error is printed, not raised).
        """
        try:
            success = self.model_manager.switch_to_provider(provider_name)
            if success:
                print(f"✅ Switched to model provider: {provider_name}")
            else:
                print(f"❌ Failed to switch to provider: {provider_name}")
            return success
        except Exception as e:
            print(f"❌ Error switching model: {e}")
            return False

    def reset_models(self) -> None:
        """Reset all model providers; errors are printed, not raised."""
        try:
            self.model_manager.reset_all_providers()
            print("✅ Reset all model providers")
        except Exception as e:
            print(f"❌ Error resetting models: {e}")


# Backward compatibility function
def extract_final_answer(raw_answer: str, question_text: str) -> str:
    """Extract the final answer from a raw model response.

    Retained so callers of the old module-level ``extract_final_answer`` keep
    working; it delegates to a freshly constructed ``AnswerExtractor``.

    Args:
        raw_answer: Raw response text produced by the model.
        question_text: The original question, passed through to the extractor.

    Returns:
        Whatever ``AnswerExtractor.extract_final_answer`` returns.
    """
    return AnswerExtractor().extract_final_answer(raw_answer, question_text)