import traceback
from fastapi import FastAPI, HTTPException
from dotenv import load_dotenv
import os
import re
from huggingface_hub import ChatCompletionInputMessage, ChatCompletionInputTool
import litellm
litellm.ssl_verify = False  # NOTE: disables TLS certificate verification for all outbound LiteLLM calls
from litellm.router import Router
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional

load_dotenv()

app = FastAPI()

api_keys = []

for k,v in os.environ.items():
    if re.match(r'^GROQ_\d+$', k):
        api_keys.append(v)
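# The loop above collects every GROQ_<n> environment variable (GROQ_0, GROQ_1, ...)
# as an API key. A .env file for two keys might look like this (placeholder values):
#   GROQ_0=<key for account 0>
#   GROQ_1=<key for account 1>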

# Per-model rate-limit reference: rpm/rpd = requests per minute/day,
# tpm/tpd = tokens per minute/day; None = no figure recorded. Only the model
# names are consumed below; the limit values themselves are not wired into the Router.
models_data = {
    "allam-2-7b": {"rpm": 30, "rpd": 7000, "tpm": 6000},
    "compound-beta": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "compound-beta-mini": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "deepseek-r1-distill-llama-70b": {"rpm": 30, "rpd": 1000, "tpm": 6000},
    "gemma2-9b-it": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "llama-3.1-8b-instant": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama-3.3-70b-versatile": {"rpm": 30, "rpd": 1000, "tpm": 12000, "tpd": 100000},
    "llama3-70b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama3-8b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "meta-llama/llama-4-maverick-17b-128e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 6000, "tpd": None},
    "meta-llama/llama-4-scout-17b-16e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 30000, "tpd": None},
    "meta-llama/llama-guard-4-12b": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "meta-llama/llama-prompt-guard-2-22m": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": None},
    "meta-llama/llama-prompt-guard-2-86m": {"rpm": 30, "rpd": 14400, "tpm": None, "tpd": None},
}

model_list = [
    {
        "model_name": f"{model_name}_{key_idx}" if key_idx != 0 else f"{model_name}",  # Nom unique par clé
        "litellm_params": {
            "model": f"groq/{model_name}",
            "api_key": api_key
        },
        "timeout": 120,
        "max_retries": 5
    }
    for model_name, config in models_data.items()
    for key_idx, api_key in enumerate(api_keys)
]
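# With two keys, the comprehension above yields router entries such as
# "llama-3.1-8b-instant" (key 0) and "llama-3.1-8b-instant_1" (key 1), all
# pointing at the same upstream model "groq/llama-3.1-8b-instant".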

def generate_fallbacks_per_key():
    fallbacks = []  # A list of dicts rather than a single dict
    excluded_models = {"compound-beta", "compound-beta-mini"}
    
    for model_name in models_data.keys():
        if model_name in excluded_models:
            continue
            
        # For each per-key version of a model, the fallbacks are the other versions of the same model
        for key_idx in range(len(api_keys)):
            current_model = f"{model_name}_{key_idx}" if key_idx != 0 else f"{model_name}"
            fallback_versions = [
                f"{model_name}_{other_key_idx}" if other_key_idx != 0 else f"{model_name}"
                for other_key_idx in range(len(api_keys)) 
                if other_key_idx != key_idx
            ]
            
            # Format expected by LiteLLM
            fallbacks.append({
                current_model: fallback_versions
            })
    
    return fallbacks

fallbacks = generate_fallbacks_per_key()
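# Illustrative shape of the generated list for two keys (excluded models omitted):
#   [{"allam-2-7b": ["allam-2-7b_1"]},
#    {"allam-2-7b_1": ["allam-2-7b"]},
#    ...]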

router = Router(
    model_list=model_list,
    fallbacks=fallbacks,
    num_retries=5,
    retry_after=10
)
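# On failure the Router retries the same deployment up to num_retries times,
# then walks the fallbacks generated above (the same model on the other keys).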

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"]
)

class ChatRequest(BaseModel):
    models: List[str]
    messages: List[ChatCompletionInputMessage]
    tools: Optional[List[ChatCompletionInputTool]] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None
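
# An example /chat request body matching this schema (the model name comes from
# models_data above; the prompt content is illustrative):
#   {
#       "models": ["llama-3.1-8b-instant"],
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "temperature": 0.7
#   }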

def clean_message(msg) -> dict:
    """Convert a message into a dict, handling the different object types."""
    if hasattr(msg, 'model_dump'):
        # Pydantic objects
        return {k: v for k, v in msg.model_dump().items() if v is not None}
    elif hasattr(msg, '__dict__'):
        # Plain objects with attributes
        return {k: v for k, v in msg.__dict__.items() if v is not None}
    elif isinstance(msg, dict):
        # Already a dict
        return {k: v for k, v in msg.items() if v is not None}
    else:
        # Generic fallback conversion
        return dict(msg)

@app.get("/")
def main_page():
    return {"status": "ok"}

@app.post("/chat")
def chat_with_groq(req: ChatRequest):
    models = req.models
    if len(models) == 1 and (models[0] == "" or models[0] not in models_data.keys()):
        raise HTTPException(400, detail="Empty or unknown model")
    messages = [clean_message(m) for m in req.messages]
    if len(models) == 1:
        try:
            resp = router.completion(model=models[0], messages=messages, **req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True))
            print("Asked to", models[0], ":", messages)
            return {"error": False, "content": resp.choices[0].message.content}
        except Exception as e:
            traceback.print_exception(e)
            return {"error": True, "content": "Aucune clé ne fonctionne avec le modèle sélectionné, patientez ...."}
    else:
        for model in models:
            if model not in models_data.keys():
                print(f"Erreur: {model} n'existe pas")
                continue
            try:
                resp = router.completion(model=model, messages=messages, **req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True))
                print("Asked to", models[0], ":", messages)
                return {"error": False, "content": resp.choices[0].message.content}
            except Exception as e:
                traceback.print_exception(e)
                continue
        return {"error": True, "content": "Tous les modèles n'ont pas fonctionné avec les différentes clé, patientez ...."}