Spaces:
Running
Running
#!/usr/bin/env python3
"""
Model management system for GAIA agent.
Handles model initialization, fallback chains, and lifecycle management.
"""
import functools
import os
import random
import time
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from ..config.settings import Config, ModelType, config
from ..utils.exceptions import (
    ModelError, ModelNotAvailableError, ModelAuthenticationError,
    ModelOverloadedError, create_error
)
class ModelStatus(Enum):
    """Lifecycle states a model provider can be in.

    Every member's value is the lower-case spelling of its name.
    """

    AVAILABLE = "available"            # initialized successfully; create_model may be used
    UNAVAILABLE = "unavailable"        # never initialized, or reset for another attempt
    OVERLOADED = "overloaded"          # last error mentioned "overloaded" or "503"
    AUTHENTICATING = "authenticating"  # set while initialize() is running
    ERROR = "error"                    # any other failure (authentication failures included)
class ModelProvider(ABC):
    """Abstract base class for model providers.

    Concrete providers must implement ``initialize``, ``is_available`` and
    ``create_model``; the ``@abstractmethod`` markers make instantiating a
    subclass that forgets one of them fail immediately with ``TypeError``
    instead of silently returning ``None``.  The base class owns the status
    and error bookkeeping shared by every provider.
    """

    def __init__(self, name: str, model_type: ModelType):
        self.name = name
        self.model_type = model_type
        self.status = ModelStatus.UNAVAILABLE
        self.last_error: Optional[str] = None
        self.retry_count = 0
        # time.time() value of the most recent record_usage() call, or None if never used.
        self.last_used: Optional[float] = None

    @abstractmethod
    def initialize(self) -> bool:
        """Initialize the model provider. Returns True if successful."""

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the model is available for use."""

    @abstractmethod
    def create_model(self, **kwargs):
        """Return a ready-to-use model instance."""

    def reset_error_state(self) -> None:
        """Reset error state for retry attempts.

        Clears ``retry_count`` and ``last_error`` and moves the status back to
        UNAVAILABLE so the provider can be initialized again.
        """
        self.retry_count = 0
        self.last_error = None
        self.status = ModelStatus.UNAVAILABLE

    def record_usage(self) -> None:
        """Record model usage timestamp (``time.time()``) in ``last_used``."""
        self.last_used = time.time()

    def handle_error(self, error: Exception) -> None:
        """Record ``error`` and map it onto a ModelStatus.

        Classification is a substring match on the lower-cased message:
        "overloaded"/"503" -> OVERLOADED; "authentication"/"401"/"403" ->
        ERROR with a fixed message; anything else -> ERROR with the raw
        message.  ``retry_count`` is incremented in every case.
        """
        error_str = str(error).lower()
        if "overloaded" in error_str or "503" in error_str:
            self.status = ModelStatus.OVERLOADED
            self.last_error = "Model overloaded"
        elif "authentication" in error_str or "401" in error_str or "403" in error_str:
            self.status = ModelStatus.ERROR
            self.last_error = "Authentication failed"
        else:
            self.status = ModelStatus.ERROR
            self.last_error = str(error)
        self.retry_count += 1
class LiteLLMProvider(ModelProvider):
    """Provider for LiteLLM-based models (Gemini, Kluster.ai).

    The model type is derived at construction time: names containing
    "gemini" are GEMINI, an ``api_base`` containing "kluster" means KLUSTER,
    and everything else is treated as QWEN.
    """

    def __init__(self, model_name: str, api_key: str, api_base: Optional[str] = None):
        # api_base has to be stored before _determine_model_type() runs,
        # because the Kluster.ai detection reads it.
        self.model_name = model_name
        self.api_key = api_key
        self.api_base = api_base
        self._model_instance = None
        super().__init__(model_name, self._determine_model_type(model_name))

    def _determine_model_type(self, model_name: str) -> ModelType:
        """Classify the model as GEMINI, KLUSTER or QWEN (the default)."""
        if "gemini" in model_name.lower():
            return ModelType.GEMINI
        base = getattr(self, 'api_base', None)
        if base and "kluster" in str(base).lower():
            return ModelType.KLUSTER
        return ModelType.QWEN

    def initialize(self) -> bool:
        """Export credentials to the environment and build the LiteLLMModel.

        Gemini models get ``GEMINI_API_KEY``; other models with an
        ``api_base`` get ``OPENAI_API_KEY`` / ``OPENAI_API_BASE`` (presumably
        read by LiteLLM -- confirm).  Returns True on success; on any
        exception the error is recorded via ``handle_error`` and False is
        returned.
        """
        try:
            # LiteLLMModel lives in the sibling ``providers`` module of this package.
            from .providers import LiteLLMModel

            self.status = ModelStatus.AUTHENTICATING
            if self.model_type == ModelType.GEMINI:
                os.environ["GEMINI_API_KEY"] = self.api_key
            elif self.api_base:
                os.environ.update(OPENAI_API_KEY=self.api_key, OPENAI_API_BASE=self.api_base)

            self._model_instance = LiteLLMModel(
                model_name=self.model_name,
                api_key=self.api_key,
                api_base=self.api_base,
            )
        except Exception as exc:
            self.handle_error(exc)
            return False
        else:
            self.status = ModelStatus.AVAILABLE
            return True

    def is_available(self) -> bool:
        """True once initialize() succeeded and the status is still AVAILABLE."""
        return self._model_instance is not None and self.status == ModelStatus.AVAILABLE

    def create_model(self, **kwargs):
        """Return the cached model instance, stamping ``last_used``.

        ``kwargs`` are accepted for interface compatibility but ignored.

        Raises:
            ModelNotAvailableError: if the provider is not available.
        """
        if self.is_available():
            self.record_usage()
            return self._model_instance
        raise ModelNotAvailableError(f"Model {self.name} is not available")
class HuggingFaceProvider(ModelProvider):
    """Provider for HuggingFace models served through smolagents' InferenceClientModel.

    Every instance is registered as ``ModelType.QWEN``, whatever the model name.
    """

    def __init__(self, model_name: str, api_key: str):
        super().__init__(model_name, ModelType.QWEN)
        self.model_name = model_name
        self.api_key = api_key
        self._model_instance = None

    def initialize(self) -> bool:
        """Build the InferenceClientModel.

        Returns True on success; on any exception (including smolagents not
        being installed) the error is recorded via ``handle_error`` and False
        is returned.
        """
        try:
            from smolagents import InferenceClientModel

            self.status = ModelStatus.AUTHENTICATING
            self._model_instance = InferenceClientModel(model_id=self.model_name, token=self.api_key)
        except Exception as exc:
            self.handle_error(exc)
            return False
        else:
            self.status = ModelStatus.AVAILABLE
            return True

    def is_available(self) -> bool:
        """True once initialize() succeeded and the status is still AVAILABLE."""
        return self._model_instance is not None and self.status == ModelStatus.AVAILABLE

    def create_model(self, **kwargs):
        """Return the cached model instance, stamping ``last_used``.

        ``kwargs`` are accepted for interface compatibility but ignored.

        Raises:
            ModelNotAvailableError: if the provider is not available.
        """
        if self.is_available():
            self.record_usage()
            return self._model_instance
        raise ModelNotAvailableError(f"Model {self.name} is not available")
class ModelManager:
    """Manages model providers and fallback chains.

    Construction creates one provider per backend whose API key is present
    in the config and builds a priority-ordered fallback chain from them
    (Kluster.ai -> Gemini -> Qwen).  ``current_provider`` names the provider
    whose model ``get_current_model`` returns; on failure the manager moves
    along the fallback chain.
    """

    def __init__(self, config_instance: Optional[Config] = None):
        """Build providers from ``config_instance`` (defaults to the global ``config``).

        Raises:
            ModelNotAvailableError: if no provider ends up in the fallback chain.
        """
        self.config = config_instance or config
        self.providers: Dict[str, ModelProvider] = {}
        self.fallback_chain: List[str] = []
        self.current_provider: Optional[str] = None
        self._initialize_providers()

    def _initialize_providers(self) -> None:
        """Create (but do not initialize) providers for every configured API key."""
        # Kluster.ai models: one provider per entry of KLUSTER_MODELS.
        if self.config.has_api_key("kluster"):
            kluster_key = self.config.get_api_key("kluster")
            for model_key, model_name in self.config.model.KLUSTER_MODELS.items():
                provider_name = f"kluster_{model_key}"
                provider = LiteLLMProvider(
                    model_name=model_name,
                    api_key=kluster_key,
                    api_base=self.config.model.KLUSTER_API_BASE
                )
                self.providers[provider_name] = provider

        # Gemini models
        if self.config.has_api_key("gemini"):
            gemini_key = self.config.get_api_key("gemini")
            provider = LiteLLMProvider(
                model_name=self.config.model.GEMINI_MODEL,
                api_key=gemini_key
            )
            self.providers["gemini"] = provider

        # HuggingFace models
        if self.config.has_api_key("huggingface"):
            hf_key = self.config.get_api_key("huggingface")
            provider = HuggingFaceProvider(
                model_name=self.config.model.QWEN_MODEL,
                api_key=hf_key
            )
            self.providers["qwen"] = provider

        self._setup_fallback_chain()

    def _setup_fallback_chain(self) -> None:
        """Set up model fallback chain based on availability and preference.

        Priority order: Kluster.ai (highest tier) -> Gemini -> Qwen.  Only one
        Kluster.ai model enters the chain: qwen3-235b if configured, otherwise
        gemma3-27b.

        Raises:
            ModelNotAvailableError: if the resulting chain is empty.
        """
        priority_providers: List[str] = []

        if "kluster_qwen3-235b" in self.providers:
            priority_providers.append("kluster_qwen3-235b")
        elif "kluster_gemma3-27b" in self.providers:
            priority_providers.append("kluster_gemma3-27b")

        if "gemini" in self.providers:
            priority_providers.append("gemini")
        if "qwen" in self.providers:
            priority_providers.append("qwen")

        self.fallback_chain = priority_providers
        if not self.fallback_chain:
            raise ModelNotAvailableError("No model providers available")

    def initialize_all(self) -> Dict[str, bool]:
        """Initialize all model providers.

        Returns a mapping of provider name to whether its initialization
        succeeded.  If no provider is current yet, the first successful one is
        selected in fallback-chain priority order (providers outside the
        chain only as a last resort), rather than in dict insertion order,
        which could pick an off-chain Kluster model and break fallback.
        """
        results: Dict[str, bool] = {}
        for name, provider in self.providers.items():
            try:
                results[name] = provider.initialize()
            except Exception as e:
                results[name] = False
                provider.handle_error(e)

        if self.current_provider is None:
            off_chain = [n for n in self.providers if n not in self.fallback_chain]
            self.current_provider = next(
                (n for n in self.fallback_chain + off_chain if results.get(n)),
                None,
            )
        return results

    def get_current_model(self, **kwargs):
        """Get current active model, falling back along the chain on failure.

        Raises:
            ModelNotAvailableError: if no provider can be selected at all.
            ModelError: if the current provider fails and no fallback is usable.
        """
        if self.current_provider is None:
            self._select_best_provider()
        if self.current_provider is None:
            raise ModelNotAvailableError("No models available")

        provider = self.providers[self.current_provider]
        try:
            return provider.create_model(**kwargs)
        except Exception as e:
            provider.handle_error(e)
            # Recursion is bounded: each successful switch moves strictly
            # forward in the fallback chain.
            if self._switch_to_fallback():
                return self.get_current_model(**kwargs)
            raise ModelError(f"All models failed: {str(e)}") from e

    def _select_best_provider(self) -> None:
        """Select the best available provider from fallback chain.

        Already-available providers are taken as-is; providers still in the
        UNAVAILABLE state are initialized on demand.  Providers in ERROR or
        OVERLOADED state are skipped.  Leaves ``current_provider`` as None if
        nothing qualifies.
        """
        for provider_name in self.fallback_chain:
            provider = self.providers.get(provider_name)
            if provider and provider.is_available():
                self.current_provider = provider_name
                return
            elif provider and provider.status == ModelStatus.UNAVAILABLE:
                if provider.initialize():
                    self.current_provider = provider_name
                    return
        self.current_provider = None

    def _switch_to_fallback(self) -> bool:
        """Switch to next available model in fallback chain.

        Providers after the current one are tried in chain order.  If the
        current provider is not part of the chain at all (e.g. it was chosen
        via ``switch_to_provider``), the whole chain is tried instead of
        giving up.  Returns True on a successful switch; otherwise clears
        ``current_provider`` and returns False.
        """
        if self.current_provider is None:
            return False

        if self.current_provider in self.fallback_chain:
            start = self.fallback_chain.index(self.current_provider) + 1
        else:
            start = 0

        for provider_name in self.fallback_chain[start:]:
            provider = self.providers[provider_name]
            if provider.is_available() or provider.initialize():
                self.current_provider = provider_name
                return True

        # No fallback available
        self.current_provider = None
        return False

    def retry_current_model(self, max_retries: int = 3) -> bool:
        """Retry (re-initialize) the current model up to ``max_retries`` times.

        Before an attempt, if the provider is OVERLOADED, sleeps for
        ``2 ** attempt`` seconds plus up to one second of random jitter.
        Other failures are retried immediately.  Returns True as soon as an
        initialization succeeds.
        """
        if self.current_provider is None:
            return False

        provider = self.providers[self.current_provider]
        for attempt in range(max_retries):
            if provider.status == ModelStatus.OVERLOADED:
                wait_time = (2 ** attempt) + random.random()
                time.sleep(wait_time)

            provider.reset_error_state()
            if provider.initialize():
                return True
        return False

    def get_model_status(self) -> Dict[str, Dict[str, Any]]:
        """Get status of all model providers, keyed by provider name."""
        return {
            name: {
                "status": provider.status.value,
                "model_type": provider.model_type.value,
                "last_error": provider.last_error,
                "retry_count": provider.retry_count,
                "last_used": provider.last_used,
                "is_current": name == self.current_provider,
            }
            for name, provider in self.providers.items()
        }

    def switch_to_provider(self, provider_name: str) -> bool:
        """Manually switch to specific provider, initializing it if needed.

        Raises:
            ModelNotAvailableError: if ``provider_name`` is unknown.
        """
        if provider_name not in self.providers:
            raise ModelNotAvailableError(f"Provider {provider_name} not found")

        provider = self.providers[provider_name]
        if provider.is_available() or provider.initialize():
            self.current_provider = provider_name
            return True
        return False

    def get_available_providers(self) -> List[str]:
        """Get list of available providers, in registration order."""
        return [name for name, provider in self.providers.items() if provider.is_available()]

    def reset_all_providers(self) -> None:
        """Reset all providers to allow retry, then reselect the best one."""
        for provider in self.providers.values():
            provider.reset_error_state()
        self.current_provider = None
        self._select_best_provider()
# Monkey patch for smolagents compatibility
def monkey_patch_smolagents():
    """Apply compatibility patches for smolagents.

    Wraps ``smolagents.monitoring.Monitor.update_metrics`` so that a step log
    whose ``token_usage`` is a plain dict is converted to a ``TokenUsage``
    before the original method runs.  Both OpenAI/LiteLLM-style keys
    (``prompt_tokens``/``completion_tokens``) and ``input_tokens``/
    ``output_tokens`` are understood; missing or None counts become 0.

    The original ``update_metrics`` is called exactly once per invocation
    and its own exceptions propagate unchanged.  Calling this function again
    is a no-op once ``Monitor`` is patched.  It never raises: failures to
    apply the patch are reported with ``print``.
    """
    try:
        import smolagents.monitoring
        from smolagents.monitoring import TokenUsage

        monitor_cls = smolagents.monitoring.Monitor
        original_update_metrics = monitor_cls.update_metrics
        if getattr(original_update_metrics, "_gaia_token_usage_patch", False):
            # Already patched; wrapping again would stack wrappers.
            return

        def _to_token_usage(token_dict):
            """Build a TokenUsage from a usage dict."""
            input_tokens = token_dict.get('prompt_tokens', token_dict.get('input_tokens'))
            output_tokens = token_dict.get('completion_tokens', token_dict.get('output_tokens'))
            return TokenUsage(
                input_tokens=input_tokens or 0,
                output_tokens=output_tokens or 0
            )

        @functools.wraps(original_update_metrics)
        def patched_update_metrics(self, step_log):
            """Patched version that handles dict token_usage"""
            token_usage = getattr(step_log, 'token_usage', None)
            if isinstance(token_usage, dict):
                try:
                    step_log.token_usage = _to_token_usage(token_usage)
                except Exception as e:
                    # Best effort: leave token_usage as-is if conversion fails.
                    print(f"Token usage patch warning: {e}")
            # Outside the try so errors from the original propagate instead of
            # triggering a second call (and double-counted metrics).
            return original_update_metrics(self, step_log)

        patched_update_metrics._gaia_token_usage_patch = True
        monitor_cls.update_metrics = patched_update_metrics
        print("✅ Applied smolagents token usage compatibility patch")
    except ImportError:
        print("⚠️ smolagents not available, skipping compatibility patch")
    except Exception as e:
        print(f"⚠️ Failed to apply smolagents patch: {e}")
# Import-time side effect: patch smolagents' Monitor as soon as this module
# is loaded, so dict-shaped token usage is handled from the first agent step.
# monkey_patch_smolagents() reports its own failures and does not raise.
monkey_patch_smolagents()