#!/usr/bin/env python3 """ Model provider implementations for GAIA agent. Contains specific model provider classes and utilities. """ import os import time import litellm from typing import List, Dict, Any, Optional from ..utils.exceptions import ModelError, ModelAuthenticationError class LiteLLMModel: """Custom model adapter to use LiteLLM with smolagents""" def __init__(self, model_name: str, api_key: str, api_base: str = None): if not api_key: raise ValueError(f"No API key provided for {model_name}") self.model_name = model_name self.api_key = api_key self.api_base = api_base # Configure LiteLLM based on provider self._configure_environment() self._test_authentication() def _configure_environment(self) -> None: """Configure environment variables for the model.""" try: if "gemini" in self.model_name.lower(): os.environ["GEMINI_API_KEY"] = self.api_key elif self.api_base: # For custom API endpoints like Kluster.ai os.environ["OPENAI_API_KEY"] = self.api_key os.environ["OPENAI_API_BASE"] = self.api_base litellm.set_verbose = False # Reduce verbose logging except Exception as e: raise ModelError(f"Failed to configure environment for {self.model_name}: {e}") def _test_authentication(self) -> None: """Test authentication with a minimal request.""" try: if "gemini" in self.model_name.lower(): # Test Gemini authentication test_response = litellm.completion( model=self.model_name, messages=[{"role": "user", "content": "test"}], max_tokens=1 ) print(f"✅ Initialized LiteLLM with {self.model_name}" + (f" via {self.api_base}" if self.api_base else "")) except Exception as e: error_msg = f"Authentication failed for {self.model_name}: {str(e)}" print(f"❌ {error_msg}") raise ModelAuthenticationError(error_msg, model_name=self.model_name) class ChatMessage: """Enhanced ChatMessage class for smolagents + LiteLLM compatibility""" def __init__(self, content: str, role: str = "assistant"): self.content = content self.role = role self.tool_calls = [] # Token usage attributes - covering different naming conventions self.token_usage = { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 } # Additional attributes for broader compatibility self.input_tokens = 0 # Alternative naming for prompt_tokens self.output_tokens = 0 # Alternative naming for completion_tokens self.usage = self.token_usage # Alternative attribute name # Optional metadata attributes self.finish_reason = "stop" self.model = None self.created = None def __str__(self): return self.content def __repr__(self): return f"ChatMessage(role='{self.role}', content='{self.content[:50]}...')" def __getitem__(self, key): """Make the object dict-like for backward compatibility""" if key == 'input_tokens': return self.input_tokens elif key == 'output_tokens': return self.output_tokens elif key == 'content': return self.content elif key == 'role': return self.role else: raise KeyError(f"Key '{key}' not found") def get(self, key, default=None): """Dict-like get method""" try: return self[key] except KeyError: return default def __call__(self, messages: List[Dict], **kwargs): """Make the model callable for smolagents compatibility""" try: # Format messages for LiteLLM formatted_messages = self._format_messages(messages) # Execute with retry logic return self._execute_with_retry(formatted_messages, **kwargs) except Exception as e: print(f"❌ LiteLLM error: {e}") print(f"Error type: {type(e)}") if "content" in str(e): print("This looks like a response parsing error - returning error as ChatMessage") return self.ChatMessage(f"Error in model response: {str(e)}") print(f"Debug - Input messages: {messages}") # Return error as ChatMessage instead of raising to maintain compatibility return self.ChatMessage(f"Error: {str(e)}") def _format_messages(self, messages: List[Dict]) -> List[Dict]: """Format messages for LiteLLM consumption.""" formatted_messages = [] for msg in messages: if isinstance(msg, dict): if 'content' in msg: content = msg['content'] role = msg.get('role', 'user') # Handle complex content structures if isinstance(content, list): text_content = self._extract_text_from_content_list(content) formatted_messages.append({"role": role, "content": text_content}) elif isinstance(content, str): formatted_messages.append({"role": role, "content": content}) else: formatted_messages.append({"role": role, "content": str(content)}) else: # Fallback for messages without explicit content formatted_messages.append({"role": "user", "content": str(msg)}) else: # Handle string messages formatted_messages.append({"role": "user", "content": str(msg)}) # Ensure we have at least one message if not formatted_messages: formatted_messages = [{"role": "user", "content": "Hello"}] return formatted_messages def _extract_text_from_content_list(self, content_list: List) -> str: """Extract text content from complex content structures.""" text_content = "" for item in content_list: if isinstance(item, dict): if 'content' in item and isinstance(item['content'], list): # Nested content structure for subitem in item['content']: if isinstance(subitem, dict) and subitem.get('type') == 'text': text_content += subitem.get('text', '') + "\n" elif item.get('type') == 'text': text_content += item.get('text', '') + "\n" else: text_content += str(item) + "\n" return text_content.strip() def _execute_with_retry(self, formatted_messages: List[Dict], **kwargs): """Execute LiteLLM call with retry logic.""" max_retries = 3 base_delay = 2 for attempt in range(max_retries): try: # Prepare completion arguments completion_kwargs = { "model": self.model_name, "messages": formatted_messages, "temperature": kwargs.get('temperature', 0.7), "max_tokens": kwargs.get('max_tokens', 4000) } # Add API base for custom endpoints if self.api_base: completion_kwargs["api_base"] = self.api_base # Make the API call response = litellm.completion(**completion_kwargs) # Process and return response return self._process_response(response) except Exception as retry_error: if self._is_retryable_error(retry_error) and attempt < max_retries - 1: delay = base_delay * (2 ** attempt) print(f"⏳ Model overloaded (attempt {attempt + 1}/{max_retries}), retrying in {delay}s...") time.sleep(delay) continue else: # For non-retryable errors or final attempt, raise raise retry_error def _is_retryable_error(self, error: Exception) -> bool: """Check if error is retryable (overload/503 errors).""" error_str = str(error).lower() return "overloaded" in error_str or "503" in error_str def _process_response(self, response) -> 'ChatMessage': """Process LiteLLM response and return ChatMessage.""" content = None if hasattr(response, 'choices') and len(response.choices) > 0: choice = response.choices[0] if hasattr(choice, 'message') and hasattr(choice.message, 'content'): content = choice.message.content elif hasattr(choice, 'text'): content = choice.text else: print(f"Warning: Unexpected choice structure: {choice}") content = str(choice) elif isinstance(response, str): content = response else: print(f"Warning: Unexpected response format: {type(response)}") content = str(response) # Create ChatMessage with token usage if content: chat_msg = self.ChatMessage(content) self._extract_token_usage(response, chat_msg) return chat_msg else: return self.ChatMessage("Error: No content in response") def _extract_token_usage(self, response, chat_msg: 'ChatMessage') -> None: """Extract token usage from response.""" if hasattr(response, 'usage'): usage = response.usage if hasattr(usage, 'prompt_tokens'): chat_msg.input_tokens = usage.prompt_tokens chat_msg.token_usage['prompt_tokens'] = usage.prompt_tokens if hasattr(usage, 'completion_tokens'): chat_msg.output_tokens = usage.completion_tokens chat_msg.token_usage['completion_tokens'] = usage.completion_tokens if hasattr(usage, 'total_tokens'): chat_msg.token_usage['total_tokens'] = usage.total_tokens def generate(self, prompt: str, **kwargs): """Generate response for a single prompt""" messages = [{"role": "user", "content": prompt}] result = self(messages, **kwargs) # Ensure we always return a ChatMessage object if not isinstance(result, self.ChatMessage): return self.ChatMessage(str(result)) return result class GeminiProvider: """Specialized provider for Gemini models.""" def __init__(self, api_key: str): self.api_key = api_key self.model_name = "gemini/gemini-2.0-flash" def create_model(self) -> LiteLLMModel: """Create Gemini model instance.""" return LiteLLMModel(self.model_name, self.api_key) class KlusterProvider: """Specialized provider for Kluster.ai models.""" MODELS = { "gemma3-27b": "openai/google/gemma-3-27b-it", "qwen3-235b": "openai/Qwen/Qwen3-235B-A22B-FP8", "qwen2.5-72b": "openai/Qwen/Qwen2.5-72B-Instruct", "llama3.1-405b": "openai/meta-llama/Meta-Llama-3.1-405B-Instruct" } def __init__(self, api_key: str, model_key: str = "qwen3-235b"): self.api_key = api_key self.model_key = model_key self.api_base = "https://api.kluster.ai/v1" if model_key not in self.MODELS: raise ValueError(f"Model '{model_key}' not found. Available: {list(self.MODELS.keys())}") self.model_name = self.MODELS[model_key] def create_model(self) -> LiteLLMModel: """Create Kluster.ai model instance.""" return LiteLLMModel(self.model_name, self.api_key, self.api_base)