|
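"""Utilities for finding an available Hugging Face inference provider for a model.

Reads the model's provider mapping from the Hub, prioritizes preferred
providers, and probes each candidate with a minimal chat completion until
one responds.
"""
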
import logging
import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient, model_info

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

PREFERRED_PROVIDERS = ["sambanova", "novita"]


def filter_providers(providers):
    """Filter providers to keep only the preferred ones."""
    return [provider for provider in providers if provider in PREFERRED_PROVIDERS]


def prioritize_providers(providers):
    """Move preferred providers to the front, keeping all others."""
    preferred = [provider for provider in providers if provider in PREFERRED_PROVIDERS]
    non_preferred = [provider for provider in providers if provider not in PREFERRED_PROVIDERS]
    return preferred + non_preferred
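
# For example, prioritize_providers(["hf-inference", "novita", "sambanova"])
# returns ["novita", "sambanova", "hf-inference"]: preferred providers move to
# the front, and the relative order within each group is preserved.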
|
|
|
|
|
def is_vision_model(model_name: str) -> bool:
    """
    Check if the model is a vision model based on its name.

    Args:
        model_name: Name of the model

    Returns:
        True if it's a vision model, False otherwise
    """
    # Indicators must be lowercase, since the model name is lowercased below.
    vision_indicators = ["-vl-", "vision", "clip", "image"]
    return any(indicator in model_name.lower() for indicator in vision_indicators)
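
# For example, is_vision_model("Qwen/Qwen2.5-VL-72B-Instruct") is True (the
# lowercased name contains "-vl-"), while
# is_vision_model("Qwen/Qwen2.5-72B-Instruct") is False.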
|
|
|
def get_test_payload(model_name: str) -> dict:
    """
    Get the test payload for a model.

    The payload is currently the same minimal text-generation request for
    every model; model_name is accepted so the function can later dispatch
    on model type (e.g. vision models).

    Args:
        model_name: Name of the model

    Returns:
        Dictionary containing the test payload
    """
    return {
        "inputs": "Hello",
        "parameters": {
            "max_new_tokens": 5
        }
    }
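
# For example, get_test_payload("Qwen/QwQ-32B") returns
# {"inputs": "Hello", "parameters": {"max_new_tokens": 5}}, the request body
# shape the Hugging Face Inference API expects for text generation.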
|
|
|
def test_provider(model_name: str, provider: str, verbose: bool = False) -> bool:
    """
    Test if a specific provider is available for a model using InferenceClient.

    Args:
        model_name: Name of the model
        provider: Provider to test
        verbose: Whether to log detailed information

    Returns:
        True if the provider is available, False otherwise
    """
    try:
        load_dotenv()

        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not defined in environment")

        if verbose:
            logger.info(f"Testing provider {provider} for model {model_name}")

        client = InferenceClient(
            model=model_name,
            token=hf_token,
            provider=provider,
            timeout=10
        )

        try:
            # A minimal chat completion is enough to confirm the provider responds.
            client.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=5
            )

            if verbose:
                logger.info(f"Provider {provider} is available for {model_name}")
            return True

        except Exception as e:
            if verbose:
                error_message = str(e)
                logger.error(f"Error with provider {provider}: {error_message}")

                # Map the most common failure modes to actionable hints.
                if "status_code=429" in error_message:
                    logger.warning(f"Provider {provider} rate limited. You may need to wait or upgrade your plan.")
                elif "status_code=401" in error_message:
                    logger.warning(f"Authentication failed for provider {provider}. Check your token.")
                elif "status_code=503" in error_message:
                    logger.warning(f"Provider {provider} service unavailable. Model may be loading or provider is down.")
                elif "timed out" in error_message.lower():
                    logger.warning(f"Timeout with provider {provider} - request timed out after 10 seconds.")
            return False

    except Exception as e:
        if verbose:
            logger.error(f"Error in test_provider: {str(e)}")
        return False
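
# For example, test_provider("Qwen/QwQ-32B", "sambanova", verbose=True) returns
# True only if sambanova answers the probe request; rate limits, auth failures,
# and timeouts are logged and reported as False.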
|
|
|
def get_available_model_provider(model_name, verbose=False):
    """
    Get the first available provider for a given model.

    Args:
        model_name: Name of the model on the Hub
        verbose: Whether to log detailed information

    Returns:
        First available provider, or None if none are available
    """
    try:
        load_dotenv()

        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not defined in environment")

        # Ask the Hub which inference providers serve this model.
        info = model_info(model_name, expand="inferenceProviderMapping")
        if not hasattr(info, "inference_provider_mapping"):
            if verbose:
                logger.info(f"No inference providers found for {model_name}")
            return None

        providers = list(info.inference_provider_mapping.keys())
        if not providers:
            if verbose:
                logger.info(f"Empty list of providers for {model_name}")
            return None

        providers = prioritize_providers(providers)

        if verbose:
            logger.info(f"Available providers for {model_name}: {', '.join(providers)}")

        # Probe providers in priority order and return the first that responds.
        for provider in providers:
            if test_provider(model_name, provider, verbose):
                return provider

        return None

    except Exception as e:
        if verbose:
            logger.error(f"Error in get_available_model_provider: {str(e)}")
        return None
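
# For example, get_available_model_provider("Qwen/QwQ-32B") might return
# "sambanova" if that provider responds first, or None if no provider is
# reachable.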
|
|
|
if __name__ == "__main__":
    models = [
        "Qwen/QwQ-32B",
        "Qwen/Qwen2.5-72B-Instruct",
        "meta-llama/Llama-3.3-70B-Instruct",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        "mistralai/Mistral-Small-24B-Instruct-2501",
    ]

    providers = []

    for model in models:
        provider = get_available_model_provider(model, verbose=True)
        providers.append(provider)

    print(f"Providers {len(providers)}: {providers}")
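
    # Example output (actual values depend on which providers respond):
    # Providers 5: ['sambanova', 'sambanova', 'novita', 'novita', 'novita']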