Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,12 +2,41 @@ from fastapi import FastAPI, WebSocket
|
|
2 |
from fastapi.responses import HTMLResponse
|
3 |
import uvicorn
|
4 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
app = FastAPI()
|
7 |
|
8 |
class ConnectionManager:
|
9 |
def __init__(self):
|
10 |
self.active_connections = {} # WebSocket: source
|
|
|
11 |
|
12 |
async def connect(self, websocket: WebSocket):
|
13 |
await websocket.accept()
|
@@ -26,6 +55,15 @@ class ConnectionManager:
|
|
26 |
if websocket in self.active_connections:
|
27 |
del self.active_connections[websocket]
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
manager = ConnectionManager()
|
30 |
|
31 |
@app.get("/")
|
@@ -244,6 +282,43 @@ async def get_proxy():
|
|
244 |
""")
|
245 |
|
246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
@app.websocket("/ws")
|
248 |
async def websocket_endpoint(websocket: WebSocket):
|
249 |
await manager.connect(websocket)
|
@@ -253,13 +328,17 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
253 |
init_msg = json.loads(data)
|
254 |
if 'source' in init_msg:
|
255 |
manager.set_source(websocket, init_msg['source'])
|
256 |
-
|
257 |
# Handle messages
|
258 |
while True:
|
259 |
message = await websocket.receive_text()
|
260 |
msg_data = json.loads(message)
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
263 |
except Exception as e:
|
264 |
manager.remove(websocket)
|
265 |
await websocket.close()
|
|
|
2 |
from fastapi.responses import HTMLResponse
|
3 |
import uvicorn
|
4 |
import json
|
5 |
+
from fastapi import Request, HTTPException
|
6 |
+
from pydantic import BaseModel
|
7 |
+
from typing import List, Optional
|
8 |
+
import uuid
|
9 |
+
import time
|
10 |
+
import asyncio
|
11 |
+
|
12 |
+
class ChatMessage(BaseModel):
    """One turn of a chat conversation, in the OpenAI message shape."""

    # Speaker of the turn; this file looks for "user" on incoming requests
    # and emits "assistant" on replies.
    role: str
    # Plain-text body of the turn.
    content: str
|
15 |
+
|
16 |
+
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``."""

    # Model identifier forwarded to the proxy; falls back to this default
    # when the client omits it.
    model: str = "gemini-2.5-pro-exp-03-25"
    # Conversation turns supplied by the client (required).
    messages: List[ChatMessage]
    # Sampling temperature forwarded to the proxy.
    temperature: Optional[float] = 0.7
    # Accepted as part of the request schema.
    # NOTE(review): the chat_completions handler in this file does not read
    # this flag -- confirm whether streaming is meant to be supported.
    stream: Optional[bool] = False
|
21 |
+
|
22 |
+
class ChatCompletionResponseChoice(BaseModel):
    """A single candidate answer inside a chat completion response."""

    # Position of this choice in the response's ``choices`` list.
    index: int = 0
    # The generated message itself.
    message: ChatMessage
    # Why generation ended; defaults to "stop".
    finish_reason: str = "stop"
|
26 |
+
|
27 |
+
class ChatCompletionResponse(BaseModel):
    """Response body returned by ``POST /v1/chat/completions``.

    Field names (including ``id`` and ``object``) are part of the wire
    format and therefore intentionally shadow Python builtins.
    """

    # Identifier of this completion (the endpoint uses the request UUID).
    id: str
    # Constant discriminator for the payload type.
    object: str = "chat.completion"
    # Creation time as integer seconds (the endpoint passes int(time.time())).
    created: int
    # Model name echoed back from the request.
    model: str
    # Generated answers; the endpoint in this file always supplies one.
    choices: List[ChatCompletionResponseChoice]
|
33 |
|
34 |
app = FastAPI()
|
35 |
|
36 |
class ConnectionManager:
|
37 |
def __init__(self):
|
38 |
self.active_connections = {} # WebSocket: source
|
39 |
+
self.response_queues = {} # request_id: asyncio.Queue
|
40 |
|
41 |
async def connect(self, websocket: WebSocket):
|
42 |
await websocket.accept()
|
|
|
55 |
if websocket in self.active_connections:
|
56 |
del self.active_connections[websocket]
|
57 |
|
58 |
+
async def wait_for_response(self, request_id: str, timeout: int = 30):
    """Block until a reply tagged with ``request_id`` arrives, or time out.

    A fresh single-slot queue is published in ``self.response_queues`` under
    ``request_id`` so that another coroutine can deliver the reply into it.
    The entry is always removed again, whether a reply arrived or not.

    Args:
        request_id: Key under which the reply queue is registered.
        timeout: Maximum number of seconds to wait for the reply.

    Returns:
        The first item put on the queue.

    Raises:
        asyncio.TimeoutError: If nothing arrives within ``timeout`` seconds.
    """
    inbox = asyncio.Queue(maxsize=1)
    self.response_queues[request_id] = inbox
    try:
        reply = await asyncio.wait_for(inbox.get(), timeout=timeout)
    finally:
        # Unregister on success, timeout and cancellation alike so that
        # abandoned requests do not accumulate in the registry.
        self.response_queues.pop(request_id, None)
    return reply
|
65 |
+
|
66 |
+
|
67 |
manager = ConnectionManager()
|
68 |
|
69 |
@app.get("/")
|
|
|
282 |
""")
|
283 |
|
284 |
|
285 |
+
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    """Relay an OpenAI-style chat completion request through the proxy socket.

    The most recent ``user`` message of the request is forwarded to the
    WebSocket client registered with source ``'proxy'``. The reply that the
    WebSocket handler delivers for the same ``request_id`` is wrapped in a
    ``ChatCompletionResponse``.

    Args:
        request: Parsed request body.

    Returns:
        A ``ChatCompletionResponse`` with a single assistant choice.

    Raises:
        HTTPException: 503 if no proxy client is connected, 400 if the
            request contains no user message, 504 if the proxy does not
            answer within 30 seconds.
    """
    request_id = str(uuid.uuid4())

    # Find proxy connection
    proxy_ws = next(
        (ws for ws, src in manager.active_connections.items() if src == 'proxy'),
        None,
    )
    # Compare against None explicitly: truthiness of a connection object is
    # not a statement about whether one was found.
    if proxy_ws is None:
        raise HTTPException(503, "Proxy client not connected")

    # Walk the history backwards: in a multi-turn conversation the newest
    # user turn is the one that needs answering, not the first one.
    user_message = next(
        (m for m in reversed(request.messages) if m.role == "user"),
        None,
    )
    if user_message is None:
        raise HTTPException(400, "No user message found")

    proxy_msg = {
        "request_id": request_id,
        "content": user_message.content,
        "source": "api",
        "destination": "proxy",
        "model": request.model,
        "temperature": request.temperature,
    }

    # Register the reply queue BEFORE sending. The WebSocket handler drops
    # any reply whose request_id has no queue yet, so registering only after
    # the send leaves a window in which a fast reply is lost and the request
    # ends in a spurious 504.
    reply_queue = asyncio.Queue(maxsize=1)
    manager.response_queues[request_id] = reply_queue
    try:
        await proxy_ws.send_text(json.dumps(proxy_msg))
        # Same 30 s window as manager.wait_for_response's default.
        response_content = await asyncio.wait_for(reply_queue.get(), timeout=30)
    except asyncio.TimeoutError:
        raise HTTPException(504, "Proxy response timeout")
    finally:
        # Always unregister, including when the send itself fails.
        manager.response_queues.pop(request_id, None)

    # Return OpenAI-compatible response
    return ChatCompletionResponse(
        id=request_id,
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionResponseChoice(
                message=ChatMessage(role="assistant", content=response_content)
            )
        ],
    )
|
320 |
+
|
321 |
+
|
322 |
@app.websocket("/ws")
|
323 |
async def websocket_endpoint(websocket: WebSocket):
|
324 |
await manager.connect(websocket)
|
|
|
328 |
init_msg = json.loads(data)
|
329 |
if 'source' in init_msg:
|
330 |
manager.set_source(websocket, init_msg['source'])
|
|
|
331 |
# Handle messages
|
332 |
while True:
|
333 |
message = await websocket.receive_text()
|
334 |
msg_data = json.loads(message)
|
335 |
+
# If this is a response to an API request
|
336 |
+
if 'request_id' in msg_data and msg_data.get('destination') == 'api':
|
337 |
+
queue = manager.response_queues.get(msg_data['request_id'])
|
338 |
+
if queue:
|
339 |
+
await queue.put(msg_data['content'])
|
340 |
+
else:
|
341 |
+
await manager.send_to_destination(msg_data['destination'], message)
|
342 |
except Exception as e:
|
343 |
manager.remove(websocket)
|
344 |
await websocket.close()
|