#!/usr/bin/env python3
"""
Model provider implementations for the GAIA agent.
Contains provider-specific model classes and utilities.
"""
import os
import time
from typing import Dict, List, Optional

import litellm

from ..utils.exceptions import ModelError, ModelAuthenticationError


class LiteLLMModel:
    """Custom model adapter to use LiteLLM with smolagents"""

    def __init__(self, model_name: str, api_key: str, api_base: Optional[str] = None):
        if not api_key:
            raise ValueError(f"No API key provided for {model_name}")
        self.model_name = model_name
        self.api_key = api_key
        self.api_base = api_base
        # Configure LiteLLM for the provider, then verify credentials
        self._configure_environment()
        self._test_authentication()

    def _configure_environment(self) -> None:
        """Configure environment variables for the model."""
        try:
            if "gemini" in self.model_name.lower():
                os.environ["GEMINI_API_KEY"] = self.api_key
            elif self.api_base:
                # Custom OpenAI-compatible endpoints (e.g. Kluster.ai) go
                # through the OpenAI provider settings
                os.environ["OPENAI_API_KEY"] = self.api_key
                os.environ["OPENAI_API_BASE"] = self.api_base
            litellm.set_verbose = False  # Reduce verbose logging
        except Exception as e:
            raise ModelError(f"Failed to configure environment for {self.model_name}: {e}")

    def _test_authentication(self) -> None:
        """Test authentication with a minimal request (Gemini only; other
        providers are validated on their first real completion call)."""
        try:
            if "gemini" in self.model_name.lower():
                # Issue a one-token completion to verify the Gemini key works
                litellm.completion(
                    model=self.model_name,
                    messages=[{"role": "user", "content": "test"}],
                    max_tokens=1,
                )
            print(f"✅ Initialized LiteLLM with {self.model_name}"
                  + (f" via {self.api_base}" if self.api_base else ""))
        except Exception as e:
            error_msg = f"Authentication failed for {self.model_name}: {e}"
            print(f"❌ {error_msg}")
            raise ModelAuthenticationError(error_msg, model_name=self.model_name)

    class ChatMessage:
        """Enhanced ChatMessage class for smolagents + LiteLLM compatibility.

        Nested inside LiteLLMModel so the adapter's methods can construct
        instances via ``self.ChatMessage``.
        """

        def __init__(self, content: str, role: str = "assistant"):
            self.content = content
            self.role = role
            self.tool_calls = []
            # Token usage attributes - covering different naming conventions
            self.token_usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }
            # Additional attributes for broader compatibility
            self.input_tokens = 0   # Alternative naming for prompt_tokens
            self.output_tokens = 0  # Alternative naming for completion_tokens
            self.usage = self.token_usage  # Alternative attribute name (same dict)
            # Optional metadata attributes
            self.finish_reason = "stop"
            self.model = None
            self.created = None

        def __str__(self):
            return self.content

        def __repr__(self):
            return f"ChatMessage(role='{self.role}', content='{self.content[:50]}...')"

        def __getitem__(self, key):
            """Make the object dict-like for backward compatibility"""
            if key == 'input_tokens':
                return self.input_tokens
            elif key == 'output_tokens':
                return self.output_tokens
            elif key == 'content':
                return self.content
            elif key == 'role':
                return self.role
            else:
                raise KeyError(f"Key '{key}' not found")

        def get(self, key, default=None):
            """Dict-like get method"""
            try:
                return self[key]
            except KeyError:
                return default
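
    # Illustrative (not part of the original API surface): a ChatMessage can be
    # read both as an object and as a dict, e.g.
    #   msg = LiteLLMModel.ChatMessage("hi")
    #   msg.content == msg["content"] == msg.get("content")  # all "hi"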

    def __call__(self, messages: List[Dict], **kwargs):
        """Make the model callable for smolagents compatibility"""
        try:
            # Format messages for LiteLLM
            formatted_messages = self._format_messages(messages)
            # Execute with retry logic
            return self._execute_with_retry(formatted_messages, **kwargs)
        except Exception as e:
            print(f"❌ LiteLLM error: {e}")
            print(f"Error type: {type(e)}")
            # Heuristic: errors mentioning "content" usually come from response
            # parsing rather than from the API call itself
            if "content" in str(e):
                print("This looks like a response parsing error - returning error as ChatMessage")
                return self.ChatMessage(f"Error in model response: {e}")
            print(f"Debug - Input messages: {messages}")
            # Return the error as a ChatMessage instead of raising, to maintain
            # compatibility with callers that expect a message object
            return self.ChatMessage(f"Error: {e}")

    def _format_messages(self, messages: List[Dict]) -> List[Dict]:
        """Format messages for LiteLLM consumption."""
        formatted_messages = []
        for msg in messages:
            if isinstance(msg, dict):
                if 'content' in msg:
                    content = msg['content']
                    role = msg.get('role', 'user')
                    # Handle complex content structures
                    if isinstance(content, list):
                        text_content = self._extract_text_from_content_list(content)
                        formatted_messages.append({"role": role, "content": text_content})
                    elif isinstance(content, str):
                        formatted_messages.append({"role": role, "content": content})
                    else:
                        formatted_messages.append({"role": role, "content": str(content)})
                else:
                    # Fallback for messages without explicit content
                    formatted_messages.append({"role": "user", "content": str(msg)})
            else:
                # Handle string (or other non-dict) messages
                formatted_messages.append({"role": "user", "content": str(msg)})
        # Ensure we always send at least one message
        if not formatted_messages:
            formatted_messages = [{"role": "user", "content": "Hello"}]
        return formatted_messages
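
    # Illustrative example of the normalization above (hypothetical input):
    #   [{"role": "user", "content": [{"type": "text", "text": "hi"}]}, "extra"]
    # becomes
    #   [{"role": "user", "content": "hi"}, {"role": "user", "content": "extra"}]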

    def _extract_text_from_content_list(self, content_list: List) -> str:
        """Extract text content from complex content structures."""
        text_content = ""
        for item in content_list:
            if isinstance(item, dict):
                if 'content' in item and isinstance(item['content'], list):
                    # Nested content structure
                    for subitem in item['content']:
                        if isinstance(subitem, dict) and subitem.get('type') == 'text':
                            text_content += subitem.get('text', '') + "\n"
                elif item.get('type') == 'text':
                    text_content += item.get('text', '') + "\n"
            else:
                text_content += str(item) + "\n"
        return text_content.strip()

    def _execute_with_retry(self, formatted_messages: List[Dict], **kwargs):
        """Execute the LiteLLM call with exponential-backoff retries."""
        max_retries = 3
        base_delay = 2  # Yields delays of 2s, then 4s, before the final attempt
        for attempt in range(max_retries):
            try:
                # Prepare completion arguments
                completion_kwargs = {
                    "model": self.model_name,
                    "messages": formatted_messages,
                    "temperature": kwargs.get('temperature', 0.7),
                    "max_tokens": kwargs.get('max_tokens', 4000),
                }
                # Add the API base for custom endpoints
                if self.api_base:
                    completion_kwargs["api_base"] = self.api_base
                # Make the API call
                response = litellm.completion(**completion_kwargs)
                # Process and return the response
                return self._process_response(response)
            except Exception as retry_error:
                if self._is_retryable_error(retry_error) and attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)
                    print(f"⏳ Model overloaded (attempt {attempt + 1}/{max_retries}), retrying in {delay}s...")
                    time.sleep(delay)
                    continue
                # For non-retryable errors or the final attempt, re-raise
                raise

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check if an error is retryable (overload/503 errors)."""
        error_str = str(error).lower()
        return "overloaded" in error_str or "503" in error_str

    def _process_response(self, response) -> "LiteLLMModel.ChatMessage":
        """Process a LiteLLM response and wrap it in a ChatMessage."""
        content = None
        if hasattr(response, 'choices') and len(response.choices) > 0:
            choice = response.choices[0]
            if hasattr(choice, 'message') and hasattr(choice.message, 'content'):
                content = choice.message.content
            elif hasattr(choice, 'text'):
                content = choice.text
            else:
                print(f"Warning: Unexpected choice structure: {choice}")
                content = str(choice)
        elif isinstance(response, str):
            content = response
        else:
            print(f"Warning: Unexpected response format: {type(response)}")
            content = str(response)
        # Create a ChatMessage carrying the token usage
        if content:
            chat_msg = self.ChatMessage(content)
            self._extract_token_usage(response, chat_msg)
            return chat_msg
        return self.ChatMessage("Error: No content in response")

    def _extract_token_usage(self, response, chat_msg: "LiteLLMModel.ChatMessage") -> None:
        """Copy token usage from the response onto the ChatMessage."""
        if hasattr(response, 'usage'):
            usage = response.usage
            if hasattr(usage, 'prompt_tokens'):
                chat_msg.input_tokens = usage.prompt_tokens
                chat_msg.token_usage['prompt_tokens'] = usage.prompt_tokens
            if hasattr(usage, 'completion_tokens'):
                chat_msg.output_tokens = usage.completion_tokens
                chat_msg.token_usage['completion_tokens'] = usage.completion_tokens
            if hasattr(usage, 'total_tokens'):
                chat_msg.token_usage['total_tokens'] = usage.total_tokens

    def generate(self, prompt: str, **kwargs):
        """Generate a response for a single prompt"""
        messages = [{"role": "user", "content": prompt}]
        result = self(messages, **kwargs)
        # Ensure we always return a ChatMessage object
        if not isinstance(result, self.ChatMessage):
            return self.ChatMessage(str(result))
        return result


class GeminiProvider:
    """Specialized provider for Gemini models."""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.model_name = "gemini/gemini-2.0-flash"

    def create_model(self) -> LiteLLMModel:
        """Create a Gemini model instance."""
        return LiteLLMModel(self.model_name, self.api_key)


class KlusterProvider:
    """Specialized provider for Kluster.ai models."""

    MODELS = {
        "gemma3-27b": "openai/google/gemma-3-27b-it",
        "qwen3-235b": "openai/Qwen/Qwen3-235B-A22B-FP8",
        "qwen2.5-72b": "openai/Qwen/Qwen2.5-72B-Instruct",
        "llama3.1-405b": "openai/meta-llama/Meta-Llama-3.1-405B-Instruct",
    }

    def __init__(self, api_key: str, model_key: str = "qwen3-235b"):
        # Validate the model key before using it
        if model_key not in self.MODELS:
            raise ValueError(f"Model '{model_key}' not found. Available: {list(self.MODELS.keys())}")
        self.api_key = api_key
        self.model_key = model_key
        self.api_base = "https://api.kluster.ai/v1"
        self.model_name = self.MODELS[model_key]

    def create_model(self) -> LiteLLMModel:
        """Create a Kluster.ai model instance."""
        return LiteLLMModel(self.model_name, self.api_key, self.api_base)
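

# Minimal usage sketch (illustrative, not part of the module's API): assumes a
# GEMINI_API_KEY environment variable is set and network access is available.
if __name__ == "__main__":
    gemini_key = os.getenv("GEMINI_API_KEY", "")
    if gemini_key:
        # Build a Gemini-backed model and run a single-prompt generation
        model = GeminiProvider(gemini_key).create_model()
        reply = model.generate("Reply with one word: hello")
        print(reply.content)
        print(reply.token_usage)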