|
import dataclasses
import json
import os
import re
import traceback

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
|
from huggingface_hub import ChatCompletionInputMessage, ChatCompletionInputTool |
|
import litellm |
|
# SECURITY: TLS certificate verification for litellm's outgoing requests (which
# carry the Groq API keys) stays disabled unless LITELLM_SSL_VERIFY is set to
# 1/true/yes. This line runs before load_dotenv(), so the variable must come
# from the real process environment, not from the .env file.
litellm.ssl_verify = os.getenv("LITELLM_SSL_VERIFY", "").strip().lower() in {"1", "true", "yes"}
|
from litellm.router import Router |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel, Field |
|
from typing import List, Optional, Literal, Type, Union |
|
|
|
load_dotenv()

app = FastAPI()

# Groq API keys come from environment variables named GROQ_<n> (the .env file
# is loaded above). They are ordered by <n> so that key index 0 -- whose
# deployments keep the bare model names -- is always the lowest-numbered key,
# rather than whichever variable os.environ happens to list first. Variables
# set to an empty string are skipped: they could only yield failing deployments.
_GROQ_KEY_NAME = re.compile(r"^GROQ_\d+$")
api_keys = [
    os.environ[name]
    for name in sorted(
        (name for name in os.environ if _GROQ_KEY_NAME.match(name)),
        key=lambda name: int(name.rsplit("_", 1)[1]),
    )
    if os.environ[name]
]

if not api_keys:
    # Without keys no deployment is created for any model, so completions
    # cannot succeed; make the misconfiguration visible at startup.
    print("Warning: no non-empty GROQ_<n> API key found in the environment.")
|
|
|
# Supported Groq model ids (used as "groq/<id>" for litellm) with their limits.
# The fields are presumably the provider's rate limits -- requests per
# minute/day (rpm/rpd) and tokens per minute/day (tpm/tpd); None means no value
# was recorded. Every entry carries all four keys so they can be indexed
# uniformly. NOTE(review): the code shown here does not read these numbers, and
# they may drift from Groq's published limits -- confirm before relying on them.
models_data = {
    "allam-2-7b": {"rpm": 30, "rpd": 7000, "tpm": 6000, "tpd": None},
    "compound-beta": {"rpm": 15, "rpd": 200, "tpm": 70000, "tpd": None},
    "compound-beta-mini": {"rpm": 15, "rpd": 200, "tpm": 70000, "tpd": None},
    "deepseek-r1-distill-llama-70b": {"rpm": 30, "rpd": 1000, "tpm": 6000, "tpd": None},
    "gemma2-9b-it": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "llama-3.1-8b-instant": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama-3.3-70b-versatile": {"rpm": 30, "rpd": 1000, "tpm": 12000, "tpd": 100000},
    "llama3-70b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama3-8b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "meta-llama/llama-4-maverick-17b-128e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 6000, "tpd": None},
    "meta-llama/llama-4-scout-17b-16e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 30000, "tpd": None},
    "meta-llama/llama-guard-4-12b": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "meta-llama/llama-prompt-guard-2-22m": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": None},
    "meta-llama/llama-prompt-guard-2-86m": {"rpm": 30, "rpd": 14400, "tpm": None, "tpd": None},
}
|
|
|
# One litellm deployment per (model, API key) pair. The deployment for key 0
# keeps the bare model name (the name clients request); the others get an
# "_<key index>" suffix and are reached through the Router's fallbacks.
model_list = [
    {
        "model_name": model_name if key_idx == 0 else f"{model_name}_{key_idx}",
        "litellm_params": {
            "model": f"groq/{model_name}",
            "api_key": api_key,
            # litellm's Router reads per-deployment timeouts from
            # litellm_params; a top-level "timeout" key is not used.
            "timeout": 120,
            # Retries are configured once on the Router (num_retries). A
            # per-deployment "max_retries" here would compound with them.
        },
    }
    for model_name in models_data
    for key_idx, api_key in enumerate(api_keys)
]
|
|
|
def generate_fallbacks_per_key():
    """Build the per-key fallback rules handed to the litellm Router.

    For every model in ``models_data`` except ``compound-beta`` and
    ``compound-beta-mini``, each per-key alias of that model falls back to the
    aliases of the same model on all the other keys, in ascending key order.
    Alias for key 0 is the bare model name; alias for key ``i > 0`` is
    ``"<model>_<i>"``.

    Returns:
        A list of single-entry dicts ``{alias: [other aliases, ...]}``, grouped
        by model and ordered by key index.
    """
    # NOTE(review): the reason these models get no fallbacks is not visible here.
    skipped = {"compound-beta", "compound-beta-mini"}

    def alias(name, idx):
        # Key 0 keeps the plain model name; other keys get a numeric suffix.
        return name if idx == 0 else f"{name}_{idx}"

    key_indices = range(len(api_keys))
    rules = []
    for name in models_data:
        if name in skipped:
            continue
        for idx in key_indices:
            others = [alias(name, other) for other in key_indices if other != idx]
            rules.append({alias(name, idx): others})
    return rules
|
|
|
fallbacks = generate_fallbacks_per_key()

# A single litellm Router serves every (model, API key) deployment declared in
# model_list. num_retries bounds the retries of a failing call; retry_after is
# presumably the minimum wait in seconds before a retry (litellm Router option
# -- confirm). The per-key fallback aliases computed just above are then
# tried in turn.
router = Router(
    fallbacks=fallbacks,
    model_list=model_list,
    retry_after=10,
    num_retries=5,
)
|
|
|
# Allowed browser origins, read from CORS_ALLOW_ORIGINS as a comma-separated
# list; when unset or blank, any origin ("*") is allowed.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    # With a wildcard origin, Starlette answers cookie-bearing requests by
    # echoing the caller's Origin, which would let any website make
    # credentialed calls. Allow credentials only for an explicit origin list.
    allow_credentials="*" not in _cors_origins,
)
|
|
|
class ChatRequest(BaseModel):
    """JSON body of a chat request.

    Only ``models`` and ``messages`` are required; every generation option
    defaults to ``None``, meaning "not provided by the client".
    """

    # Candidate model names.
    models: List[str]
    # Conversation so far, validated as huggingface_hub chat messages.
    messages: List[ChatCompletionInputMessage]
    # Optional tool definitions, validated as huggingface_hub chat tools.
    tools: Union[List[ChatCompletionInputTool], None] = Field(default=None)
    # Optional generation parameters.
    temperature: Union[float, None] = Field(default=None)
    max_tokens: Union[int, None] = Field(default=None)
    n: Union[int, None] = Field(default=None)
    stream: Union[bool, None] = Field(default=None)
    stop: Union[List[str], None] = Field(default=None)
|
|
|
def _to_plain(value):
    """Recursively turn a nested message part into plain data, dropping ``None`` fields.

    Pydantic models (anything with ``model_dump``) and dataclass instances
    become dicts, dicts are filtered recursively, lists/tuples become lists,
    and any other value is returned unchanged.
    """
    if hasattr(value, "model_dump"):
        value = value.model_dump()
    elif dataclasses.is_dataclass(value) and not isinstance(value, type):
        # Read declared fields explicitly so slotted dataclasses work too.
        value = {f.name: getattr(value, f.name) for f in dataclasses.fields(value)}
    if isinstance(value, dict):
        return {k: _to_plain(v) for k, v in value.items() if v is not None}
    if isinstance(value, (list, tuple)):
        return [_to_plain(v) for v in value]
    return value


def clean_message(msg) -> dict:
    """Convert a chat message into a plain dict with no ``None`` fields.

    Accepted inputs, checked in this order: objects exposing ``model_dump``
    (pydantic models), objects with a ``__dict__`` (e.g. dataclass instances),
    plain dicts, and finally anything ``dict()`` accepts (returned as
    ``dict(msg)`` without filtering).

    ``None`` fields are dropped at every nesting level, and nested objects
    (e.g. content parts or tool calls) are converted to plain dicts/lists, so
    unused optional fields are not sent to the provider as ``null``.
    """
    if hasattr(msg, "model_dump"):
        fields = msg.model_dump()
    elif hasattr(msg, "__dict__"):
        fields = msg.__dict__
    elif isinstance(msg, dict):
        fields = msg
    else:
        return dict(msg)
    return {k: _to_plain(v) for k, v in fields.items() if v is not None}
|
|
|
@app.get("/")
def main_page():
    """Liveness probe: always answer with a fixed ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload
|
|
|
@app.post("/chat")
def chat_with_groq(req: ChatRequest):
    """Send a conversation to the first requested Groq model that answers.

    ``req.models`` is tried in order; each name is sent through the litellm
    router. Names not present in ``models_data`` are skipped in a
    multi-model request.

    Returns:
        ``{"error": False, "content": <reply text>}`` on the first success, or
        ``{"error": True, "content": <message>}`` when every attempt failed.

    Raises:
        HTTPException(400): when no model or an empty model name is given,
            when a single-model request names an unknown model, or when
            ``stream=True`` is requested (this endpoint returns the whole reply
            in one JSON response and cannot stream).
    """
    models = req.models
    # An empty list is treated like an empty model name.
    if not models or (len(models) == 1 and models[0] == ""):
        raise HTTPException(400, detail="Empty model field")
    if len(models) == 1 and models[0] not in models_data:
        raise HTTPException(400, detail=f"Unknown model: {models[0]}")
    # The reply is read as a single message below; a streamed response cannot
    # be read that way, so every streamed attempt would spend quota and fail.
    if req.stream:
        raise HTTPException(400, detail="Streaming is not supported by this endpoint")

    messages = [clean_message(m) for m in req.messages]
    # Forward only the generation parameters the client actually set.
    params = req.model_dump(
        exclude={"models", "messages", "stream"},
        exclude_defaults=True,
        exclude_none=True,
    )

    for model in models:
        if model not in models_data:
            print(f"Erreur: {model} n'existe pas")
            continue
        try:
            resp = router.completion(model=model, messages=messages, **params)
            print("Asked to", model, ":", messages)
            return {"error": False, "content": resp.choices[0].message.content}
        except Exception as e:
            # Best effort: log the failure and move on to the next candidate.
            traceback.print_exception(e)

    if len(models) == 1:
        return {"error": True, "content": "Aucune clé ne fonctionne avec le modèle sélectionné, patientez ...."}
    return {"error": True, "content": "Tous les modèles n'ont pas fonctionné avec les différentes clé, patientez ...."}