#!/usr/bin/env python3
"""
Model provider implementations for the GAIA agent.
Contains the specific model provider classes and utilities for driving
LiteLLM-backed models from smolagents.
"""

import os
import time
from typing import Dict, List, Optional

import litellm

from ..utils.exceptions import ModelError, ModelAuthenticationError


class LiteLLMModel:
    """Custom model adapter to use LiteLLM with smolagents."""

    def __init__(self, model_name: str, api_key: str, api_base: Optional[str] = None):
        if not api_key:
            raise ValueError(f"No API key provided for {model_name}")

        self.model_name = model_name
        self.api_key = api_key
        self.api_base = api_base

        # Configure LiteLLM based on the provider, then verify credentials
        self._configure_environment()
        self._test_authentication()

    def _configure_environment(self) -> None:
        """Configure environment variables for the model."""
        try:
            if "gemini" in self.model_name.lower():
                os.environ["GEMINI_API_KEY"] = self.api_key
            elif self.api_base:
                # For custom OpenAI-compatible endpoints such as Kluster.ai
                os.environ["OPENAI_API_KEY"] = self.api_key
                os.environ["OPENAI_API_BASE"] = self.api_base

            litellm.set_verbose = False  # Reduce verbose logging
        except Exception as e:
            raise ModelError(f"Failed to configure environment for {self.model_name}: {e}")

    def _test_authentication(self) -> None:
        """Test authentication with a minimal request.

        Only Gemini models are probed with a live call here; custom endpoints
        are validated on their first real completion.
        """
        try:
            if "gemini" in self.model_name.lower():
                # Probe Gemini authentication with a one-token request
                litellm.completion(
                    model=self.model_name,
                    messages=[{"role": "user", "content": "test"}],
                    max_tokens=1,
                )

            print(f"✅ Initialized LiteLLM with {self.model_name}" +
                  (f" via {self.api_base}" if self.api_base else ""))
        except Exception as e:
            error_msg = f"Authentication failed for {self.model_name}: {e}"
            print(f"❌ {error_msg}")
            raise ModelAuthenticationError(error_msg, model_name=self.model_name)

    class ChatMessage:
        """Enhanced ChatMessage class for smolagents + LiteLLM compatibility.

        Nested inside the adapter so responses can be wrapped via self.ChatMessage.
        """

        def __init__(self, content: str, role: str = "assistant"):
            self.content = content
            self.role = role
            self.tool_calls = []

            # Token usage attributes, covering different naming conventions
            self.token_usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }

            # Additional attributes for broader compatibility
            self.input_tokens = 0   # Alternative naming for prompt_tokens
            self.output_tokens = 0  # Alternative naming for completion_tokens
            self.usage = self.token_usage  # Alternative attribute name

            # Optional metadata attributes
            self.finish_reason = "stop"
            self.model = None
            self.created = None

        def __str__(self):
            return self.content

        def __repr__(self):
            return f"ChatMessage(role='{self.role}', content='{self.content[:50]}...')"

        def __getitem__(self, key):
            """Make the object dict-like for backward compatibility."""
            if key == 'input_tokens':
                return self.input_tokens
            elif key == 'output_tokens':
                return self.output_tokens
            elif key == 'content':
                return self.content
            elif key == 'role':
                return self.role
            else:
                raise KeyError(f"Key '{key}' not found")

        def get(self, key, default=None):
            """Dict-like get method."""
            try:
                return self[key]
            except KeyError:
                return default
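
    # Illustrative note (comment only, not executed): a ChatMessage can be read
    # either as an object or as a dict, which keeps older call sites working:
    #   msg = LiteLLMModel.ChatMessage("Hello")
    #   msg.content           -> "Hello"
    #   msg["output_tokens"]  -> 0
    #   msg.get("missing", 1) -> 1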

    def __call__(self, messages: List[Dict], **kwargs):
        """Make the model callable for smolagents compatibility."""
        try:
            # Format messages for LiteLLM
            formatted_messages = self._format_messages(messages)
            # Execute with retry logic
            return self._execute_with_retry(formatted_messages, **kwargs)
        except Exception as e:
            print(f"❌ LiteLLM error: {e}")
            print(f"Error type: {type(e)}")
            if "content" in str(e):
                print("This looks like a response parsing error - returning error as ChatMessage")
                return self.ChatMessage(f"Error in model response: {e}")
            print(f"Debug - Input messages: {messages}")
            # Return the error as a ChatMessage instead of raising, to maintain compatibility
            return self.ChatMessage(f"Error: {e}")

    def _format_messages(self, messages: List[Dict]) -> List[Dict]:
        """Format messages for LiteLLM consumption."""
        formatted_messages = []
        for msg in messages:
            if isinstance(msg, dict):
                if 'content' in msg:
                    content = msg['content']
                    role = msg.get('role', 'user')
                    # Handle complex content structures
                    if isinstance(content, list):
                        text_content = self._extract_text_from_content_list(content)
                        formatted_messages.append({"role": role, "content": text_content})
                    elif isinstance(content, str):
                        formatted_messages.append({"role": role, "content": content})
                    else:
                        formatted_messages.append({"role": role, "content": str(content)})
                else:
                    # Fallback for messages without explicit content
                    formatted_messages.append({"role": "user", "content": str(msg)})
            else:
                # Handle plain string messages
                formatted_messages.append({"role": "user", "content": str(msg)})

        # Ensure we have at least one message
        if not formatted_messages:
            formatted_messages = [{"role": "user", "content": "Hello"}]

        return formatted_messages

    def _extract_text_from_content_list(self, content_list: List) -> str:
        """Extract text content from complex (nested) content structures."""
        text_content = ""
        for item in content_list:
            if isinstance(item, dict):
                if 'content' in item and isinstance(item['content'], list):
                    # Nested content structure
                    for subitem in item['content']:
                        if isinstance(subitem, dict) and subitem.get('type') == 'text':
                            text_content += subitem.get('text', '') + "\n"
                elif item.get('type') == 'text':
                    text_content += item.get('text', '') + "\n"
            else:
                text_content += str(item) + "\n"
        return text_content.strip()

    def _execute_with_retry(self, formatted_messages: List[Dict], **kwargs):
        """Execute the LiteLLM call with exponential-backoff retry logic."""
        max_retries = 3
        base_delay = 2

        for attempt in range(max_retries):
            try:
                # Prepare completion arguments
                completion_kwargs = {
                    "model": self.model_name,
                    "messages": formatted_messages,
                    "temperature": kwargs.get('temperature', 0.7),
                    "max_tokens": kwargs.get('max_tokens', 4000),
                }

                # Add API base for custom endpoints
                if self.api_base:
                    completion_kwargs["api_base"] = self.api_base

                # Make the API call
                response = litellm.completion(**completion_kwargs)

                # Process and return the response
                return self._process_response(response)

            except Exception as retry_error:
                if self._is_retryable_error(retry_error) and attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)
                    print(f"⏳ Model overloaded (attempt {attempt + 1}/{max_retries}), retrying in {delay}s...")
                    time.sleep(delay)
                    continue
                # Non-retryable error or final attempt: re-raise
                raise retry_error

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check whether an error is retryable (overload/503 errors)."""
        error_str = str(error).lower()
        return "overloaded" in error_str or "503" in error_str

    def _process_response(self, response) -> 'ChatMessage':
        """Process a LiteLLM response and return a ChatMessage."""
        content = None

        if hasattr(response, 'choices') and len(response.choices) > 0:
            choice = response.choices[0]
            if hasattr(choice, 'message') and hasattr(choice.message, 'content'):
                content = choice.message.content
            elif hasattr(choice, 'text'):
                content = choice.text
            else:
                print(f"Warning: Unexpected choice structure: {choice}")
                content = str(choice)
        elif isinstance(response, str):
            content = response
        else:
            print(f"Warning: Unexpected response format: {type(response)}")
            content = str(response)

        # Create a ChatMessage carrying the token usage
        if content:
            chat_msg = self.ChatMessage(content)
            self._extract_token_usage(response, chat_msg)
            return chat_msg
        return self.ChatMessage("Error: No content in response")

    def _extract_token_usage(self, response, chat_msg: 'ChatMessage') -> None:
        """Copy token usage counts from the response onto the ChatMessage."""
        if hasattr(response, 'usage'):
            usage = response.usage
            if hasattr(usage, 'prompt_tokens'):
                chat_msg.input_tokens = usage.prompt_tokens
                chat_msg.token_usage['prompt_tokens'] = usage.prompt_tokens
            if hasattr(usage, 'completion_tokens'):
                chat_msg.output_tokens = usage.completion_tokens
                chat_msg.token_usage['completion_tokens'] = usage.completion_tokens
            if hasattr(usage, 'total_tokens'):
                chat_msg.token_usage['total_tokens'] = usage.total_tokens

    def generate(self, prompt: str, **kwargs):
        """Generate a response for a single prompt."""
        messages = [{"role": "user", "content": prompt}]
        result = self(messages, **kwargs)
        # Ensure we always return a ChatMessage object
        if not isinstance(result, self.ChatMessage):
            return self.ChatMessage(str(result))
        return result


class GeminiProvider:
    """Specialized provider for Gemini models."""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.model_name = "gemini/gemini-2.0-flash"

    def create_model(self) -> LiteLLMModel:
        """Create a Gemini model instance."""
        return LiteLLMModel(self.model_name, self.api_key)


class KlusterProvider:
    """Specialized provider for Kluster.ai models."""

    MODELS = {
        "gemma3-27b": "openai/google/gemma-3-27b-it",
        "qwen3-235b": "openai/Qwen/Qwen3-235B-A22B-FP8",
        "qwen2.5-72b": "openai/Qwen/Qwen2.5-72B-Instruct",
        "llama3.1-405b": "openai/meta-llama/Meta-Llama-3.1-405B-Instruct",
    }

    def __init__(self, api_key: str, model_key: str = "qwen3-235b"):
        # Validate the requested model before storing any configuration
        if model_key not in self.MODELS:
            raise ValueError(f"Model '{model_key}' not found. Available: {list(self.MODELS.keys())}")

        self.api_key = api_key
        self.model_key = model_key
        self.api_base = "https://api.kluster.ai/v1"
        self.model_name = self.MODELS[model_key]

    def create_model(self) -> LiteLLMModel:
        """Create a Kluster.ai model instance."""
        return LiteLLMModel(self.model_name, self.api_key, self.api_base)
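

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the provider API). The environment
# variable names below (GEMINI_API_KEY, KLUSTER_API_KEY) are assumptions for
# this example, and constructing a model performs a live authentication probe,
# so valid keys are required. Because this module uses relative imports, run it
# as part of its package (python -m <package>.<module>), not as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    gemini_key = os.getenv("GEMINI_API_KEY")
    if gemini_key:
        # Gemini-backed model: one-shot prompt through the LiteLLM adapter
        model = GeminiProvider(gemini_key).create_model()
        print(model.generate("Reply with the single word: ready").content)

    kluster_key = os.getenv("KLUSTER_API_KEY")
    if kluster_key:
        # Kluster.ai-backed model (Qwen3-235B by default) via its OpenAI-compatible endpoint
        model = KlusterProvider(kluster_key, model_key="qwen3-235b").create_model()
        print(model.generate("Reply with the single word: ready").content)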