om4r932 committed
Commit 9afc631 · Parent: ab68772

Changed to LiteLLM

Files changed (1): app.py +103 -48
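
In short, the commit swaps the hand-rolled key-rotation loop over raw requests.post calls to the Groq API for a litellm Router: each (model, API key) pair becomes one deployment, named after the model for the first key and suffixed with the key index for the others, and the router retries and falls back across those aliases. As an illustration (not part of the commit), with two keys loaded into api_keys the generated structures for a single model would be:

    # Illustrative only: assumes api_keys == ["key0", "key1"]; shown for one model.
    example_model_list = [
        {"model_name": "gemma2-9b-it",
         "litellm_params": {"model": "groq/gemma2-9b-it", "api_key": "key0"},
         "timeout": 120, "max_retries": 5},
        {"model_name": "gemma2-9b-it_1",
         "litellm_params": {"model": "groq/gemma2-9b-it", "api_key": "key1"},
         "timeout": 120, "max_retries": 5},
    ]
    example_fallbacks = [
        {"gemma2-9b-it": ["gemma2-9b-it_1"]},  # first-key alias falls back to the second
        {"gemma2-9b-it_1": ["gemma2-9b-it"]},  # and vice versa
    ]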
app.py CHANGED
@@ -1,12 +1,16 @@
 import json
+import traceback
 from fastapi import FastAPI, HTTPException
 from dotenv import load_dotenv
 import os
 import re
-import requests
+from huggingface_hub import ChatCompletionInputMessage, ChatCompletionInputTool
+import litellm
+litellm.ssl_verify = False
+from litellm.router import Router
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
-from typing import List, Optional, Literal, Union
+from typing import List, Optional, Literal, Type, Union
 
 load_dotenv()
 
@@ -18,6 +22,65 @@ for k,v in os.environ.items():
     if re.match(r'^GROQ_\d+$', k):
         api_keys.append(v)
 
+models_data = {
+    "allam-2-7b": {"rpm": 30, "rpd": 7000, "tpm": 6000},
+    "compound-beta": {"rpm": 15, "rpd": 200, "tpm": 70000},
+    "compound-beta-mini": {"rpm": 15, "rpd": 200, "tpm": 70000},
+    "deepseek-r1-distill-llama-70b": {"rpm": 30, "rpd": 1000, "tpm": 6000},
+    "gemma2-9b-it": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
+    "llama-3.1-8b-instant": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
+    "llama-3.3-70b-versatile": {"rpm": 30, "rpd": 1000, "tpm": 12000, "tpd": 100000},
+    "llama3-70b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
+    "llama3-8b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
+}
+
+model_list = [
+    {
+        "model_name": f"{model_name}_{key_idx}" if key_idx != 0 else f"{model_name}",  # unique name per key
+        "litellm_params": {
+            "model": f"groq/{model_name}",
+            "api_key": api_key
+        },
+        "timeout": 120,
+        "max_retries": 5
+    }
+    for model_name, config in models_data.items()
+    for key_idx, api_key in enumerate(api_keys)
+]
+
+def generate_fallbacks_per_key():
+    fallbacks = []  # a list of dicts rather than a single dict
+    excluded_models = {"compound-beta", "compound-beta-mini"}
+
+    for model_name in models_data.keys():
+        if model_name in excluded_models:
+            continue
+
+        # For each version of a model, the fallbacks are the other versions of the same model
+        for key_idx in range(len(api_keys)):
+            current_model = f"{model_name}_{key_idx}" if key_idx != 0 else f"{model_name}"
+            fallback_versions = [
+                f"{model_name}_{other_key_idx}" if other_key_idx != 0 else f"{model_name}"
+                for other_key_idx in range(len(api_keys))
+                if other_key_idx != key_idx
+            ]
+
+            # Format expected by LiteLLM
+            fallbacks.append({
+                current_model: fallback_versions
+            })
+
+    return fallbacks
+
+fallbacks = generate_fallbacks_per_key()
+
+router = Router(
+    model_list=model_list,
+    fallbacks=fallbacks,
+    num_retries=5,
+    retry_after=10
+)
+
 app.add_middleware(
     CORSMiddleware,
     allow_credentials=True,
@@ -26,64 +89,56 @@ app.add_middleware(
     allow_origins=["*"]
 )
 
-class ChatMessage(BaseModel):
-    role: Literal["system", "user", "assistant", "tool"]
-    content: Optional[str]  # None for some messages (e.g. tool calls)
-    name: Optional[str] = None
-    function_call: Optional[dict] = None  # Deprecated
-    tool_call_id: Optional[str] = None
-    tool_calls: Optional[List[dict]] = None
-
 class ChatRequest(BaseModel):
-    models: Optional[List[str]] = []
-    messages: List[ChatMessage]
-    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
-    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
-    n: Optional[int] = Field(default=1, ge=1)
-    stream: Optional[bool] = False
-    stop: Optional[Union[str, List[str]]] = None
+    models: List[str]
+    messages: List[ChatCompletionInputMessage]
+    tools: Optional[List[ChatCompletionInputTool]] = None
+    temperature: Optional[float] = None
     max_tokens: Optional[int] = None
-    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
-    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
-    logit_bias: Optional[dict] = None
-    user: Optional[str] = None
-    tools: Optional[List[dict]] = None
-    tool_choice: Optional[Union[str, dict]] = None
+    n: Optional[int] = None
+    stream: Optional[bool] = None
+    stop: Optional[List[str]] = None
 
-def clean_message(msg: ChatMessage) -> dict:
-    return {k: v for k, v in msg.model_dump().items() if v is not None}
+def clean_message(msg) -> dict:
+    """Convert a message to a dict, handling the different object types"""
+    if hasattr(msg, 'model_dump'):
+        # Pydantic objects
+        return {k: v for k, v in msg.model_dump().items() if v is not None}
+    elif hasattr(msg, '__dict__'):
+        # Objects with attributes
+        return {k: v for k, v in msg.__dict__.items() if v is not None}
+    elif isinstance(msg, dict):
+        # Already a dict
+        return {k: v for k, v in msg.items() if v is not None}
+    else:
+        # Generic conversion
+        return dict(msg)
 
 @app.get("/")
 def main_page():
     return {"status": "ok"}
 
 @app.post("/chat")
-def ask_groq_llm(req: ChatRequest):
+def chat_with_groq(req: ChatRequest):
     models = req.models
     if len(models) == 1 and models[0] == "":
         raise HTTPException(400, detail="Empty model field")
     messages = [clean_message(m) for m in req.messages]
-    looping = True
    if len(models) == 1:
-        while looping:
-            for key in api_keys:
-                resp = requests.post("https://api.groq.com/openai/v1/chat/completions", verify=False, headers={"Content-Type": "application/json", "Authorization": f"Bearer {key}"}, data=json.dumps({"model": models[0], "messages": messages}))
-                if resp.status_code == 200:
-                    respJson = resp.json()
-                    print("Asked to", models[0], "with the key ID", str(api_keys.index(key)+1), ":", messages)
-                    return {"error": False, "content": respJson["choices"]}
-                print(resp.status_code, resp.text)
-            looping = False
-        return {"error": True, "content": "None of the models or keys works, please wait ...."}
+        try:
+            resp = router.completion(model=models[0], messages=messages, **req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True))
+            print("Asked to", models[0], ":", messages)
+            return {"error": False, "content": resp.choices[0].message.content}
+        except Exception as e:
+            traceback.print_exception(e)
+            return {"error": True, "content": "No key works with the selected model, please wait ...."}
    else:
-        while looping:
-            for model in models:
-                for key in api_keys:
-                    resp = requests.post("https://api.groq.com/openai/v1/chat/completions", verify=False, headers={"Content-Type": "application/json", "Authorization": f"Bearer {key}"}, data=json.dumps({"model": model, "messages": messages}))
-                    if resp.status_code == 200:
-                        respJson = resp.json()
-                        print("Asked to", model, "with the key ID", str(api_keys.index(key)+1), ":", messages)
-                        return {"error": False, "content": respJson["choices"]}
-                    print(resp.status_code, resp.text)
-            looping = False
-        return {"error": True, "content": "None of the models or keys works, please wait ...."}
+        for model in models:
+            try:
+                resp = router.completion(model=model, messages=messages, **req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True))
+                print("Asked to", model, ":", messages)
+                return {"error": False, "content": resp.choices[0].message.content}
+            except Exception as e:
+                traceback.print_exception(e)
+                continue
+        return {"error": True, "content": "No key works with the selected model, please wait ...."}