import os
import re
import traceback
from typing import List, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import ChatCompletionInputMessage, ChatCompletionInputTool
from pydantic import BaseModel

import litellm

litellm.ssl_verify = False  # NB: disables SSL certificate verification for all LiteLLM calls

from litellm.router import Router

load_dotenv()

app = FastAPI()

# Collect every Groq API key exposed as an environment variable named GROQ_0, GROQ_1, ...
api_keys = [v for k, v in os.environ.items() if re.match(r"^GROQ_\d+$", k)]

# Groq rate limits per model: requests/minute, requests/day, tokens/minute, tokens/day.
models_data = {
    "allam-2-7b": {"rpm": 30, "rpd": 7000, "tpm": 6000},
    "compound-beta": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "compound-beta-mini": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "deepseek-r1-distill-llama-70b": {"rpm": 30, "rpd": 1000, "tpm": 6000},
    "gemma2-9b-it": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "llama-3.1-8b-instant": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama-3.3-70b-versatile": {"rpm": 30, "rpd": 1000, "tpm": 12000, "tpd": 100000},
    "llama3-70b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama3-8b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "meta-llama/llama-4-maverick-17b-128e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 6000, "tpd": None},
    "meta-llama/llama-4-scout-17b-16e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 30000, "tpd": None},
    "meta-llama/llama-guard-4-12b": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "meta-llama/llama-prompt-guard-2-22m": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": None},
    "meta-llama/llama-prompt-guard-2-86m": {"rpm": 30, "rpd": 14400, "tpm": None, "tpd": None},
}

# One router deployment per (model, API key) pair. The first key keeps the bare
# model name; every other key gets a "_<key_idx>" suffix so each name is unique.
model_list = [
    {
        "model_name": f"{model_name}_{key_idx}" if key_idx != 0 else model_name,
        "litellm_params": {
            "model": f"groq/{model_name}",
            "api_key": api_key,
        },
        "timeout": 120,
        "max_retries": 5,
    }
    for model_name in models_data
    for key_idx, api_key in enumerate(api_keys)
]


def generate_fallbacks_per_key():
    """For each per-key version of a model, fall back to the same model under the other keys.

    With two keys this yields entries like {"allam-2-7b": ["allam-2-7b_1"]} and
    {"allam-2-7b_1": ["allam-2-7b"]}.
    """
    fallbacks = []  # A list of dicts rather than a single dict: the format LiteLLM expects
    excluded_models = {"compound-beta", "compound-beta-mini"}
    for model_name in models_data:
        if model_name in excluded_models:
            continue
        for key_idx in range(len(api_keys)):
            current_model = f"{model_name}_{key_idx}" if key_idx != 0 else model_name
            fallback_versions = [
                f"{model_name}_{other_key_idx}" if other_key_idx != 0 else model_name
                for other_key_idx in range(len(api_keys))
                if other_key_idx != key_idx
            ]
            fallbacks.append({current_model: fallback_versions})
    return fallbacks


fallbacks = generate_fallbacks_per_key()

router = Router(
    model_list=model_list,
    fallbacks=fallbacks,
    num_retries=5,
    retry_after=10,
)

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"],
)


class ChatRequest(BaseModel):
    models: List[str]
    messages: List[ChatCompletionInputMessage]
    tools: Optional[List[ChatCompletionInputTool]] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None


def clean_message(msg) -> dict:
    """Convert a message to a dict, handling the different object types."""
    if hasattr(msg, "model_dump"):  # Pydantic objects
        return {k: v for k, v in msg.model_dump().items() if v is not None}
    elif hasattr(msg, "__dict__"):  # Plain objects with attributes
        return {k: v for k, v in msg.__dict__.items() if v is not None}
    elif isinstance(msg, dict):  # Already a dict
        return {k: v for k, v in msg.items() if v is not None}
    else:  # Generic conversion
        return dict(msg)


@app.get("/")
def main_page():
    return {"status": "ok"}


@app.post("/chat")
def chat_with_groq(req: ChatRequest):
    models = req.models
    if len(models) == 1 and (models[0] == "" or models[0] not in models_data):
        raise HTTPException(400, detail="Empty or unknown model field")
    messages = [clean_message(m) for m in req.messages]
    # Forward every optional completion parameter the client actually set.
    params = req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True)
    if len(models) == 1:
        try:
            resp = router.completion(model=models[0], messages=messages, **params)
            print("Asked to", models[0], ":", messages)
            return {"error": False, "content": resp.choices[0].message.content}
        except Exception as e:
            traceback.print_exception(e)
            return {"error": True, "content": "No API key works with the selected model, please wait..."}
    else:
        # Try each requested model in turn until one succeeds.
        for model in models:
            if model not in models_data:
                print(f"Error: {model} does not exist")
                continue
            try:
                resp = router.completion(model=model, messages=messages, **params)
                print("Asked to", model, ":", messages)
                return {"error": False, "content": resp.choices[0].message.content}
            except Exception as e:
                traceback.print_exception(e)
                continue
        return {"error": True, "content": "None of the models worked with any of the keys, please wait..."}
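

# --- Usage sketch -------------------------------------------------------------
# A minimal client call against the /chat endpoint, assuming the app is served
# locally, e.g. `uvicorn app:app --port 8000` (the module name, port, and the
# `requests` dependency are assumptions, not part of this file):
#
#     import requests
#
#     payload = {
#         "models": ["llama-3.1-8b-instant"],
#         "messages": [{"role": "user", "content": "Hello!"}],
#         "max_tokens": 256,
#     }
#     r = requests.post("http://localhost:8000/chat", json=payload, timeout=120)
#     print(r.json())  # {"error": False, "content": "..."} on success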