#!/usr/bin/env python3
"""
Model provider implementations for the GAIA agent: a LiteLLM-backed model
adapter plus concrete provider classes for Gemini and Kluster.ai.
"""

import os
import time
import litellm
from typing import List, Dict, Optional

from ..utils.exceptions import ModelError, ModelAuthenticationError


class LiteLLMModel:
    """Custom model adapter to use LiteLLM with smolagents"""
    
    def __init__(self, model_name: str, api_key: str, api_base: Optional[str] = None):
        if not api_key:
            raise ValueError(f"No API key provided for {model_name}")
        
        self.model_name = model_name
        self.api_key = api_key
        self.api_base = api_base
        
        # Configure LiteLLM based on provider
        self._configure_environment()
        self._test_authentication()
    
    def _configure_environment(self) -> None:
        """Configure environment variables for the model."""
        try:
            if "gemini" in self.model_name.lower():
                os.environ["GEMINI_API_KEY"] = self.api_key
            elif self.api_base:
                # For custom API endpoints like Kluster.ai
                os.environ["OPENAI_API_KEY"] = self.api_key
                os.environ["OPENAI_API_BASE"] = self.api_base
            
            litellm.set_verbose = False  # Reduce verbose logging
            
        except Exception as e:
            raise ModelError(f"Failed to configure environment for {self.model_name}: {e}")
    
    def _test_authentication(self) -> None:
        """Test authentication with a minimal request."""
        try:
            if "gemini" in self.model_name.lower():
                # Probe Gemini credentials with a minimal one-token request;
                # other providers are validated lazily on their first real call.
                litellm.completion(
                    model=self.model_name,
                    messages=[{"role": "user", "content": "test"}],
                    max_tokens=1
                )
            
            print(f"✅ Initialized LiteLLM with {self.model_name}" + 
                  (f" via {self.api_base}" if self.api_base else ""))
                  
        except Exception as e:
            error_msg = f"Authentication failed for {self.model_name}: {str(e)}"
            print(f"❌ {error_msg}")
            raise ModelAuthenticationError(error_msg, model_name=self.model_name)
    
    class ChatMessage:
        """Enhanced ChatMessage class for smolagents + LiteLLM compatibility"""
        
        def __init__(self, content: str, role: str = "assistant"):
            self.content = content
            self.role = role
            self.tool_calls = []
            
            # Token usage attributes - covering different naming conventions
            self.token_usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
            
            # Additional attributes for broader compatibility
            self.input_tokens = 0  # Alternative naming for prompt_tokens
            self.output_tokens = 0  # Alternative naming for completion_tokens
            self.usage = self.token_usage  # Alternative attribute name
            
            # Optional metadata attributes
            self.finish_reason = "stop"
            self.model = None
            self.created = None
            
        def __str__(self):
            return self.content
        
        def __repr__(self):
            return f"ChatMessage(role='{self.role}', content='{self.content[:50]}...')"
            
        def __getitem__(self, key):
            """Make the object dict-like for backward compatibility"""
            if key == 'input_tokens':
                return self.input_tokens
            elif key == 'output_tokens':
                return self.output_tokens
            elif key == 'content':
                return self.content
            elif key == 'role':
                return self.role
            else:
                raise KeyError(f"Key '{key}' not found")
        
        def get(self, key, default=None):
            """Dict-like get method"""
            try:
                return self[key]
            except KeyError:
                return default
    
    def __call__(self, messages: List[Dict], **kwargs):
        """Make the model callable for smolagents compatibility"""
        try:
            # Format messages for LiteLLM
            formatted_messages = self._format_messages(messages)
            
            # Execute with retry logic
            return self._execute_with_retry(formatted_messages, **kwargs)
            
        except Exception as e:
            print(f"❌ LiteLLM error: {e}")
            print(f"Error type: {type(e)}")
            if "content" in str(e):
                print("This looks like a response parsing error - returning error as ChatMessage")
                return self.ChatMessage(f"Error in model response: {str(e)}")
            print(f"Debug - Input messages: {messages}")
            # Return error as ChatMessage instead of raising to maintain compatibility
            return self.ChatMessage(f"Error: {str(e)}")
    
    def _format_messages(self, messages: List[Dict]) -> List[Dict]:
        """Format messages for LiteLLM consumption."""
        formatted_messages = []
        
        for msg in messages:
            if isinstance(msg, dict):
                if 'content' in msg:
                    content = msg['content']
                    role = msg.get('role', 'user')
                    
                    # Handle complex content structures
                    if isinstance(content, list):
                        text_content = self._extract_text_from_content_list(content)
                        formatted_messages.append({"role": role, "content": text_content})
                    elif isinstance(content, str):
                        formatted_messages.append({"role": role, "content": content})
                    else:
                        formatted_messages.append({"role": role, "content": str(content)})
                else:
                    # Fallback for messages without explicit content
                    formatted_messages.append({"role": "user", "content": str(msg)})
            else:
                # Handle string messages
                formatted_messages.append({"role": "user", "content": str(msg)})
        
        # Ensure we have at least one message
        if not formatted_messages:
            formatted_messages = [{"role": "user", "content": "Hello"}]
        
        return formatted_messages
    
    def _extract_text_from_content_list(self, content_list: List) -> str:
        """Extract text content from complex content structures."""
        text_content = ""
        
        for item in content_list:
            if isinstance(item, dict):
                if 'content' in item and isinstance(item['content'], list):
                    # Nested content structure
                    for subitem in item['content']:
                        if isinstance(subitem, dict) and subitem.get('type') == 'text':
                            text_content += subitem.get('text', '') + "\n"
                elif item.get('type') == 'text':
                    text_content += item.get('text', '') + "\n"
            else:
                text_content += str(item) + "\n"
        
        return text_content.strip()
    
    def _execute_with_retry(self, formatted_messages: List[Dict], **kwargs):
        """Execute LiteLLM call with retry logic."""
        max_retries = 3
        base_delay = 2
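        # Exponential backoff: waits 2s after the first failure and 4s after
        # the second (base_delay * 2**attempt); the third failure is re-raised.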
        
        for attempt in range(max_retries):
            try:
                # Prepare completion arguments
                completion_kwargs = {
                    "model": self.model_name,
                    "messages": formatted_messages,
                    "temperature": kwargs.get('temperature', 0.7),
                    "max_tokens": kwargs.get('max_tokens', 4000)
                }
                
                # Add API base for custom endpoints
                if self.api_base:
                    completion_kwargs["api_base"] = self.api_base
                
                # Make the API call
                response = litellm.completion(**completion_kwargs)
                
                # Process and return response
                return self._process_response(response)
                
            except Exception as retry_error:
                if self._is_retryable_error(retry_error) and attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)
                    print(f"⏳ Model overloaded (attempt {attempt + 1}/{max_retries}), retrying in {delay}s...")
                    time.sleep(delay)
                    continue
                else:
                    # For non-retryable errors or final attempt, raise
                    raise retry_error
    
    def _is_retryable_error(self, error: Exception) -> bool:
        """Check if error is retryable (overload/503 errors)."""
        error_str = str(error).lower()
        return "overloaded" in error_str or "503" in error_str
    
    def _process_response(self, response) -> "LiteLLMModel.ChatMessage":
        """Process LiteLLM response and return ChatMessage."""
        content = None
        
        if hasattr(response, 'choices') and len(response.choices) > 0:
            choice = response.choices[0]
            if hasattr(choice, 'message') and hasattr(choice.message, 'content'):
                content = choice.message.content
            elif hasattr(choice, 'text'):
                content = choice.text
            else:
                print(f"Warning: Unexpected choice structure: {choice}")
                content = str(choice)
        elif isinstance(response, str):
            content = response
        else:
            print(f"Warning: Unexpected response format: {type(response)}")
            content = str(response)
        
        # Create ChatMessage with token usage
        if content:
            chat_msg = self.ChatMessage(content)
            self._extract_token_usage(response, chat_msg)
            return chat_msg
        else:
            return self.ChatMessage("Error: No content in response")
    
    def _extract_token_usage(self, response, chat_msg: "LiteLLMModel.ChatMessage") -> None:
        """Extract token usage from response."""
        if hasattr(response, 'usage'):
            usage = response.usage
            if hasattr(usage, 'prompt_tokens'):
                chat_msg.input_tokens = usage.prompt_tokens
                chat_msg.token_usage['prompt_tokens'] = usage.prompt_tokens
            if hasattr(usage, 'completion_tokens'):
                chat_msg.output_tokens = usage.completion_tokens
                chat_msg.token_usage['completion_tokens'] = usage.completion_tokens
            if hasattr(usage, 'total_tokens'):
                chat_msg.token_usage['total_tokens'] = usage.total_tokens
    
    def generate(self, prompt: str, **kwargs):
        """Generate response for a single prompt"""
        messages = [{"role": "user", "content": prompt}]
        result = self(messages, **kwargs)
        # Ensure we always return a ChatMessage object
        if not isinstance(result, self.ChatMessage):
            return self.ChatMessage(str(result))
        return result


class GeminiProvider:
    """Specialized provider for Gemini models."""
    
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.model_name = "gemini/gemini-2.0-flash"
    
    def create_model(self) -> LiteLLMModel:
        """Create Gemini model instance."""
        return LiteLLMModel(self.model_name, self.api_key)


class KlusterProvider:
    """Specialized provider for Kluster.ai models."""
    
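    # LiteLLM model strings: the "openai/" prefix routes these through
    # LiteLLM's OpenAI-compatible client, pointed at the Kluster.ai api_base.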
    MODELS = {
        "gemma3-27b": "openai/google/gemma-3-27b-it",
        "qwen3-235b": "openai/Qwen/Qwen3-235B-A22B-FP8",
        "qwen2.5-72b": "openai/Qwen/Qwen2.5-72B-Instruct",
        "llama3.1-405b": "openai/meta-llama/Meta-Llama-3.1-405B-Instruct"
    }
    
    def __init__(self, api_key: str, model_key: str = "qwen3-235b"):
        # Validate before assigning any state.
        if model_key not in self.MODELS:
            raise ValueError(f"Model '{model_key}' not found. Available: {list(self.MODELS.keys())}")
        
        self.api_key = api_key
        self.model_key = model_key
        self.api_base = "https://api.kluster.ai/v1"
        self.model_name = self.MODELS[model_key]
    
    def create_model(self) -> LiteLLMModel:
        """Create Kluster.ai model instance."""
        return LiteLLMModel(self.model_name, self.api_key, self.api_base)
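

if __name__ == "__main__":
    # Minimal usage sketch. The KLUSTER_API_KEY variable name below is an
    # assumption for illustration; GEMINI_API_KEY matches the variable this
    # module itself sets for LiteLLM's Gemini route.
    gemini_key = os.environ.get("GEMINI_API_KEY", "")
    if gemini_key:
        model = GeminiProvider(gemini_key).create_model()
        reply = model.generate("Reply with the single word: ready")
        print(reply.content, reply.token_usage)

    kluster_key = os.environ.get("KLUSTER_API_KEY", "")  # hypothetical env var
    if kluster_key:
        model = KlusterProvider(kluster_key, model_key="qwen2.5-72b").create_model()
        reply = model.generate("Reply with the single word: ready")
        print(reply.content, reply.token_usage)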