"""WebSocket-bridged chat-completions gateway.

Exposes an OpenAI-style ``POST /v1/chat/completions`` endpoint. Each request
is forwarded over a WebSocket to a connected client that identified itself
with ``source == "proxy"``; the proxy's reply (matched by ``request_id``) is
returned to the HTTP caller. Other WebSocket messages are relayed between
connected clients according to their ``destination`` field.
"""

import asyncio
import contextlib
import json
import logging
import time
import uuid
# typing.Annotated is available from Python 3.9.
from typing import Annotated, Dict, List, Optional

import uvicorn
from fastapi import (
    Depends,
    FastAPI,
    Header,
    HTTPException,
    Request,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class ChatMessage(BaseModel):
    """A single chat turn: who spoke (``role``) and what was said."""

    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``."""

    model: str = "gemini-2.5-pro-exp-03-25"
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    # Accepted for client compatibility; the handler below does not stream.
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    """One completion choice in the response."""

    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"


class ChatCompletionResponse(BaseModel):
    """Response body returned by ``POST /v1/chat/completions``."""

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]


app = FastAPI()


class ConnectionManager:
    """Tracks connected WebSockets and pending API requests awaiting a reply."""

    def __init__(self):
        # WebSocket -> declared source name (None until the client identifies).
        self.active_connections: Dict[WebSocket, Optional[str]] = {}
        # request_id -> single-slot queue that receives the proxy's reply.
        self.response_queues: Dict[str, asyncio.Queue] = {}

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: WebSocket, source: str):
        """Record the source name of an already-tracked connection."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_destination(self, destination: str, message: str):
        """Send ``message`` to every connection whose source is ``destination``."""
        # Iterate over a snapshot: other tasks may connect/disconnect while we
        # are suspended in send_text(), and mutating a dict during iteration
        # raises RuntimeError.
        for ws, src in list(self.active_connections.items()):
            if src == destination:
                await ws.send_text(message)

    def remove(self, websocket: WebSocket):
        """Stop tracking a connection; a no-op if it is not tracked."""
        self.active_connections.pop(websocket, None)

    def register_request(self, request_id: str) -> asyncio.Queue:
        """Create (or return the existing) reply queue for ``request_id``.

        Call this BEFORE sending the request to the proxy so that a reply
        arriving immediately still finds a queue to land in.
        """
        return self.response_queues.setdefault(
            request_id, asyncio.Queue(maxsize=1)
        )

    def discard_request(self, request_id: str):
        """Forget the reply queue for ``request_id``; safe to call repeatedly."""
        self.response_queues.pop(request_id, None)

    async def wait_for_response(self, request_id: str, timeout: int = 30):
        """Wait for the reply to ``request_id`` and return its content.

        Reuses a queue created by ``register_request`` if there is one,
        otherwise registers a new one. The queue is always removed afterwards.

        Raises:
            asyncio.TimeoutError: no reply arrived within ``timeout`` seconds.
        """
        queue = self.register_request(request_id)
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            self.discard_request(request_id)


manager = ConnectionManager()


@app.get("/")
async def get():
    """Serve the chat client page."""
    # NOTE(review): the page body looks like its HTML markup was lost; the
    # literal is reproduced exactly as found -- confirm against the original.
    return HTMLResponse("""

Chat Client

""")


@app.get("/proxy")
async def get_proxy():
    """Serve the proxy client (message gateway) page."""
    # NOTE(review): the page body looks like its HTML markup was lost; the
    # literal is reproduced exactly as found -- confirm against the original.
    return HTMLResponse("""

Proxy Client (Message Gateway)

Message Flow
""")


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
    request: ChatCompletionRequest,
    authorization: Annotated[Optional[str], Header()] = None,
):
    """Forward a chat request to the connected proxy and return its reply.

    Raises:
        HTTPException 401: Authorization header missing or not ``Bearer ...``.
        HTTPException 503: no WebSocket client with source ``proxy`` connected.
        HTTPException 400: the request contains no message with role ``user``.
        HTTPException 504: the proxy did not reply within the wait timeout.
    """
    # Extract and validate the API key.
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid Authorization header",
        )

    api_key = authorization[7:]  # strip the "Bearer " prefix
    # Never write the credential itself to logs/stdout.
    logger.info("Received API key for chat completion request (value redacted)")

    request_id = str(uuid.uuid4())

    proxy_ws = next(
        (ws for ws, src in manager.active_connections.items() if src == 'proxy'),
        None,
    )
    if not proxy_ws:
        raise HTTPException(503, "Proxy client not connected")

    # NOTE(review): this forwards the FIRST user message only; for multi-turn
    # conversations the latest one may be intended -- confirm with the proxy.
    user_message = next((m for m in request.messages if m.role == "user"), None)
    if not user_message:
        raise HTTPException(400, "No user message found")

    # The caller's key is passed through so the proxy can use it.
    proxy_msg = {
        "request_id": request_id,
        "content": user_message.content,
        "source": "api",
        "destination": "proxy",
        "model": request.model,
        "temperature": request.temperature,
        "incomingKey": api_key,
    }

    # Register the reply queue before sending: send_text() yields control, so
    # the proxy's answer can be processed before we start waiting.
    manager.register_request(request_id)
    try:
        await proxy_ws.send_text(json.dumps(proxy_msg))
        response_content = await manager.wait_for_response(request_id)
    except asyncio.TimeoutError:
        raise HTTPException(504, "Proxy response timeout")
    finally:
        # Also covers the case where send_text() itself failed.
        manager.discard_request(request_id)

    return ChatCompletionResponse(
        id=request_id,
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionResponseChoice(
                message=ChatMessage(role="assistant", content=response_content)
            )
        ],
    )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Handle one WebSocket client for its whole lifetime.

    The first frame is a JSON object that may carry a ``source`` name for this
    client. Every later frame is a JSON object that is either a reply to a
    pending API request (``request_id`` present and ``destination == 'api'``)
    or a message relayed to all clients whose source equals ``destination``.
    """
    await manager.connect(websocket)
    try:
        # Initial source identification.
        init_msg = json.loads(await websocket.receive_text())
        if 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])

        while True:
            message = await websocket.receive_text()
            msg_data = json.loads(message)

            if 'request_id' in msg_data and msg_data.get('destination') == 'api':
                # Reply to an API request: hand it to the waiting handler.
                queue = manager.response_queues.get(msg_data['request_id'])
                if queue is not None:
                    # put_nowait: a duplicate reply must not block this
                    # connection's receive loop on the single-slot queue.
                    with contextlib.suppress(asyncio.QueueFull):
                        queue.put_nowait(msg_data['content'])
            else:
                await manager.send_to_destination(msg_data['destination'], message)
    except WebSocketDisconnect:
        # Peer closed the connection; nothing left to close on our side.
        pass
    except Exception:
        logger.exception("WebSocket handler failed; closing connection")
        # The socket may already be closed, in which case close() raises.
        with contextlib.suppress(RuntimeError):
            await websocket.close()
    finally:
        manager.remove(websocket)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)