import os
import re
import traceback

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import ChatCompletionInputMessage, ChatCompletionInputTool
from pydantic import BaseModel
from typing import List, Optional

import litellm
from litellm.router import Router

# Disable TLS verification for LiteLLM's outbound HTTPS calls (useful behind
# an intercepting proxy; avoid in production when possible).
litellm.ssl_verify = False
load_dotenv()
app = FastAPI()
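
# Rotating Groq API keys are read from the environment: one variable per key,
# named GROQ_0, GROQ_1, ... For example, in .env (values are placeholders):
#   GROQ_0=gsk_xxxxxxxxxxxx
#   GROQ_1=gsk_yyyyyyyyyyyy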
api_keys = []
for k, v in os.environ.items():
    if re.match(r'^GROQ_\d+$', k):
        api_keys.append(v)
# Per-model Groq rate limits: requests/minute (rpm), requests/day (rpd),
# tokens/minute (tpm) and, where published, tokens/day (tpd).
models_data = {
    "allam-2-7b": {"rpm": 30, "rpd": 7000, "tpm": 6000},
    "compound-beta": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "compound-beta-mini": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "deepseek-r1-distill-llama-70b": {"rpm": 30, "rpd": 1000, "tpm": 6000},
    "gemma2-9b-it": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "llama-3.1-8b-instant": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama-3.3-70b-versatile": {"rpm": 30, "rpd": 1000, "tpm": 12000, "tpd": 100000},
    "llama3-70b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama3-8b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "meta-llama/llama-4-maverick-17b-128e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 6000, "tpd": None},
    "meta-llama/llama-4-scout-17b-16e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 30000, "tpd": None},
    "meta-llama/llama-guard-4-12b": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "meta-llama/llama-prompt-guard-2-22m": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": None},
    "meta-llama/llama-prompt-guard-2-86m": {"rpm": 30, "rpd": 14400, "tpm": None, "tpd": None},
}
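
# One Router "deployment" is created per (model, API key) pair, so the same
# model can be retried on a different key when one key is rate limited.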
model_list = [
    {
        # Unique deployment name per key: the first key keeps the bare model
        # name, the others are suffixed with the key index.
        "model_name": f"{model_name}_{key_idx}" if key_idx != 0 else f"{model_name}",
        "litellm_params": {
            "model": f"groq/{model_name}",
            "api_key": api_key
        },
        "timeout": 120,
        "max_retries": 5
    }
    for model_name in models_data
    for key_idx, api_key in enumerate(api_keys)
]
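
# Build the LiteLLM fallback map: each deployment of a model lists the same
# model under every other key as its fallbacks, so that a rate-limited key is
# transparently swapped for another one.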
def generate_fallbacks_per_key():
    fallbacks = []  # LiteLLM expects a list of dicts rather than a single dict
    excluded_models = {"compound-beta", "compound-beta-mini"}
    for model_name in models_data:
        if model_name in excluded_models:
            continue
        # For each version of a model, the fallbacks are the other versions of
        # the same model (i.e. the same model under the other API keys).
        for key_idx in range(len(api_keys)):
            current_model = f"{model_name}_{key_idx}" if key_idx != 0 else f"{model_name}"
            fallback_versions = [
                f"{model_name}_{other_key_idx}" if other_key_idx != 0 else f"{model_name}"
                for other_key_idx in range(len(api_keys))
                if other_key_idx != key_idx
            ]
            # The {model: [fallbacks]} shape LiteLLM expects
            fallbacks.append({
                current_model: fallback_versions
            })
    return fallbacks
fallbacks = generate_fallbacks_per_key()
router = Router(
    model_list=model_list,
    fallbacks=fallbacks,
    num_retries=5,
    retry_after=10
)
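
# Wide-open CORS so browser clients on any origin can call the API.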
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"]
)
class ChatRequest(BaseModel):
    models: List[str]
    messages: List[ChatCompletionInputMessage]
    tools: Optional[List[ChatCompletionInputTool]] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None
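
# Note: `stream` is forwarded to LiteLLM as-is, but the /chat endpoint below
# always reads resp.choices[0].message.content, so streaming responses are not
# actually supported end to end.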
def clean_message(msg) -> dict:
    """Convert a message to a dict, handling the different object types."""
    if hasattr(msg, 'model_dump'):
        # Pydantic objects
        return {k: v for k, v in msg.model_dump().items() if v is not None}
    elif hasattr(msg, '__dict__'):
        # Plain objects with attributes
        return {k: v for k, v in msg.__dict__.items() if v is not None}
    elif isinstance(msg, dict):
        # Already a dict
        return {k: v for k, v in msg.items() if v is not None}
    else:
        # Generic conversion
        return dict(msg)
@app.get("/")
def main_page():
    return {"status": "ok"}
@app.post("/chat")
def chat_with_groq(req: ChatRequest):
    models = req.models
    if len(models) == 1 and (models[0] == "" or models[0] not in models_data):
        raise HTTPException(400, detail="Empty or unknown model field")
    messages = [clean_message(m) for m in req.messages]
    if len(models) == 1:
        try:
            resp = router.completion(model=models[0], messages=messages, **req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True))
            print("Asked to", models[0], ":", messages)
            return {"error": False, "content": resp.choices[0].message.content}
        except Exception as e:
            traceback.print_exception(e)
            return {"error": True, "content": "No key works with the selected model, please wait..."}
    else:
        # Try each requested model in order until one succeeds.
        for model in models:
            if model not in models_data:
                print(f"Error: {model} does not exist")
                continue
            try:
                resp = router.completion(model=model, messages=messages, **req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True))
                print("Asked to", model, ":", messages)
                return {"error": False, "content": resp.choices[0].message.content}
            except Exception as e:
                traceback.print_exception(e)
                continue
        return {"error": True, "content": "None of the models worked with any of the keys, please wait..."}