# Author: om4r932
# Commit message: Fix some issues and error handling fix
# Commit: 322e8fa
import json
import traceback
from fastapi import FastAPI, HTTPException
from dotenv import load_dotenv
import os
import re
from huggingface_hub import ChatCompletionInputMessage, ChatCompletionInputTool
import litellm
litellm.ssl_verify = False
from litellm.router import Router
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Literal, Type, Union
load_dotenv()
app = FastAPI()

# Collect every Groq API key from environment variables whose name is
# GROQ_<digits> (GROQ_0, GROQ_1, ...), in os.environ iteration order.
# This runs after load_dotenv() so variables loaded from it are visible.
api_keys = [
    value
    for name, value in os.environ.items()
    if re.match(r'^GROQ_\d+$', name)
]
# Per-model limits, presumably Groq's published rate limits:
# rpm/rpd = requests per minute/day, tpm/tpd = tokens per minute/day.
# None means no value was recorded; the first four entries have no "tpd" key.
# Within this file only the keys (model names) are read; the limit values
# themselves are not used anywhere here.
models_data = {
    "allam-2-7b": dict(rpm=30, rpd=7000, tpm=6000),
    "compound-beta": dict(rpm=15, rpd=200, tpm=70000),
    "compound-beta-mini": dict(rpm=15, rpd=200, tpm=70000),
    "deepseek-r1-distill-llama-70b": dict(rpm=30, rpd=1000, tpm=6000),
    "gemma2-9b-it": dict(rpm=30, rpd=14400, tpm=15000, tpd=500000),
    "llama-3.1-8b-instant": dict(rpm=30, rpd=14400, tpm=6000, tpd=500000),
    "llama-3.3-70b-versatile": dict(rpm=30, rpd=1000, tpm=12000, tpd=100000),
    "llama3-70b-8192": dict(rpm=30, rpd=14400, tpm=6000, tpd=500000),
    "llama3-8b-8192": dict(rpm=30, rpd=14400, tpm=6000, tpd=500000),
    "meta-llama/llama-4-maverick-17b-128e-instruct": dict(rpm=30, rpd=1000, tpm=6000, tpd=None),
    "meta-llama/llama-4-scout-17b-16e-instruct": dict(rpm=30, rpd=1000, tpm=30000, tpd=None),
    "meta-llama/llama-guard-4-12b": dict(rpm=30, rpd=14400, tpm=15000, tpd=500000),
    "meta-llama/llama-prompt-guard-2-22m": dict(rpm=30, rpd=14400, tpm=15000, tpd=None),
    "meta-llama/llama-prompt-guard-2-86m": dict(rpm=30, rpd=14400, tpm=None, tpd=None),
}
# One LiteLLM Router deployment per (model, API key) pair.
# The deployment for key 0 keeps the bare model name so clients can address it
# directly; the other keys get a "_<index>" suffix.
# LiteLLM reads per-deployment settings such as `timeout` from `litellm_params`.
# At the top level of the deployment dict they were ignored. `max_retries` is
# deliberately not set per deployment: retries are handled by the Router's
# `num_retries`, and a client-level retry count would multiply the attempts.
model_list = [
    {
        "model_name": model_name if key_idx == 0 else f"{model_name}_{key_idx}",
        "litellm_params": {
            "model": f"groq/{model_name}",
            "api_key": api_key,
            "timeout": 120,  # request timeout, in seconds
        },
    }
    for model_name in models_data
    for key_idx, api_key in enumerate(api_keys)
]
def generate_fallbacks_per_key():
    """Build Router fallback rules that chain each model across the API keys.

    For every model in ``models_data`` except ``compound-beta`` and
    ``compound-beta-mini``, and for every API key index, emit
    ``{deployment_name: [the same model on every other key]}``.
    The list is empty when only one key is configured.

    Deployment names follow the ``model_list`` scheme: the bare model name for
    key 0, ``<model>_<index>`` for the others.

    Returns:
        A list of single-entry dicts, the shape passed to ``Router(fallbacks=...)``.
    """
    def deployment(name, idx):
        # Key 0 keeps the bare model name; other keys are suffixed with their index.
        return name if idx == 0 else f"{name}_{idx}"

    skipped = {"compound-beta", "compound-beta-mini"}
    key_count = len(api_keys)
    rules = []
    for name in models_data:
        if name in skipped:
            continue
        for idx in range(key_count):
            others = [deployment(name, j) for j in range(key_count) if j != idx]
            rules.append({deployment(name, idx): others})
    return rules
# Fallback rules: a failing deployment falls back to the same model served
# through the other API keys.
fallbacks = generate_fallbacks_per_key()

router = Router(
    # Router-level retry policy. NOTE(review): `retry_after` is presumably the
    # minimum wait in seconds before a retry; confirm against the LiteLLM docs.
    num_retries=5,
    retry_after=10,
    fallbacks=fallbacks,
    model_list=model_list,
)

# CORS: any origin, GET/POST only, any request header, credentials allowed.
# NOTE(review): a wildcard origin combined with credentials is very permissive;
# restrict `allow_origins` if the API ever relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    allow_credentials=True,
)
class ChatRequest(BaseModel):
    """Body of a POST /chat request.

    ``models`` lists candidate model names, tried in order by the endpoint.
    ``messages`` is the conversation to send. Every other field is an optional
    generation parameter that defaults to ``None``.
    """

    models: List[str]
    messages: List[ChatCompletionInputMessage]
    # Optional generation parameters (None when the client leaves them out).
    tools: Optional[List[ChatCompletionInputTool]] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    n: Optional[int] = Field(default=None)
    stream: Optional[bool] = Field(default=None)
    stop: Optional[List[str]] = Field(default=None)
def clean_message(msg) -> dict:
    """Convert a chat message of any supported type into a plain dict.

    Inputs are recognised in this order:

    - objects exposing ``model_dump()`` (Pydantic models);
    - dataclass instances, converted recursively with ``dataclasses.asdict``
      so nested dataclass values become plain dicts too;
    - mappings (``dict``, dict subclasses, any ``collections.abc.Mapping``);
    - other objects with a ``__dict__`` (their instance attributes);
    - anything else ``dict()`` accepts, e.g. an iterable of key/value pairs.

    Keys whose value is ``None`` are dropped at the top level only, whichever
    branch produced the dict.
    """
    import dataclasses
    from collections.abc import Mapping

    if hasattr(msg, "model_dump"):
        data = msg.model_dump()
    elif dataclasses.is_dataclass(msg) and not isinstance(msg, type):
        # Checked before the mapping branch so that a dataclass which also
        # subclasses dict is converted from its declared fields.
        data = dataclasses.asdict(msg)
    elif isinstance(msg, Mapping):
        # Must precede the __dict__ branch: an instance of a dict subclass has
        # a (usually empty) __dict__ that does not contain its items.
        data = msg
    elif hasattr(msg, "__dict__"):
        data = vars(msg)
    else:
        data = dict(msg)
    return {key: value for key, value in data.items() if value is not None}
@app.get("/")
def main_page():
    """Root endpoint: always answers ``{"status": "ok"}`` (usable as a liveness check)."""
    status = {"status": "ok"}
    return status
def _ask_model(model: str, messages: list, options: dict) -> Optional[str]:
    """Send *messages* to *model* through the LiteLLM router and return the reply text.

    *options* are extra keyword arguments forwarded to ``router.completion``.
    The value returned is ``resp.choices[0].message.content`` as-is. Exceptions
    raised by the router propagate to the caller.
    """
    resp = router.completion(model=model, messages=messages, **options)
    print("Asked to", model, ":", messages)
    return resp.choices[0].message.content


@app.post("/chat")
def chat_with_groq(req: ChatRequest):
    """Answer a chat request with the first model in ``req.models`` that succeeds.

    With a single model, an empty or unknown name is rejected with HTTP 400.
    With several models, unknown names are skipped and the rest are tried in order.

    Returns:
        ``{"error": False, "content": <reply text>}`` on success, or
        ``{"error": True, "content": <retry message>}`` when every attempt failed.

    Raises:
        HTTPException: 400 when no model is given, or when the only model
            given is empty or not in ``models_data``.
    """
    models = req.models
    if not models or models == [""]:
        raise HTTPException(400, detail="Empty model field")
    if len(models) == 1 and models[0] not in models_data:
        raise HTTPException(400, detail=f"Unknown model: {models[0]}")

    messages = [clean_message(m) for m in req.messages]
    # Generation parameters the client actually set, computed once for all attempts.
    # "stream" is never forwarded: the reply is read as a single completion
    # (resp.choices[0].message) and returned as one JSON body.
    options = req.model_dump(
        exclude={"models", "messages", "stream"},
        exclude_defaults=True,
        exclude_none=True,
    )

    if len(models) == 1:
        try:
            content = _ask_model(models[0], messages, options)
        except Exception:
            # print_exc() works on every Python 3 version, unlike the
            # single-argument print_exception(e) form (3.10+ only).
            traceback.print_exc()
            return {"error": True, "content": "Aucune clé ne fonctionne avec le modèle sélectionné, patientez ...."}
        return {"error": False, "content": content}

    for model in models:
        if model not in models_data:
            print(f"Erreur: {model} n'existe pas")
            continue
        try:
            content = _ask_model(model, messages, options)
        except Exception:
            traceback.print_exc()
            continue  # try the next candidate model
        return {"error": False, "content": content}
    return {"error": True, "content": "Tous les modèles n'ont pas fonctionné avec les différentes clé, patientez ...."}