from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
import uvicorn
import json
from fastapi import Request, HTTPException, Header, Depends
from pydantic import BaseModel
from typing import List, Optional
import uuid
import time
import asyncio
from typing import Optional
# Modern Python (3.10+) with Annotated
from typing import Annotated
class ChatMessage(BaseModel):
    """One turn of a chat conversation."""

    # Speaker of the turn; this file reads "user" turns and writes "assistant" turns.
    role: str
    # Text of the turn.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions`` (OpenAI-style subset)."""

    # Model name forwarded to the proxy; defaults to a Gemini model id.
    model: str = "gemini-2.5-pro-exp-03-25"
    # Conversation turns, oldest first.
    messages: list[ChatMessage]
    # Sampling temperature forwarded to the proxy unchanged.
    temperature: Optional[float] = 0.7
    # Accepted for compatibility; the handler in this file never reads it.
    stream: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
    """A single generated alternative inside a chat completion response."""

    # Position of this choice in ``choices``; this file always produces one choice.
    index: int = 0
    # The generated assistant turn.
    message: ChatMessage
    # Why generation ended; always left at the "stop" default in this file.
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    """Response body returned by ``POST /v1/chat/completions``."""

    # Identifier of the completion (the handler reuses its internal request id).
    id: str
    # Object type tag, fixed to "chat.completion".
    object: str = "chat.completion"
    # Creation time as integer Unix seconds.
    created: int
    # Model name echoed back from the request.
    model: str
    # Generated alternatives.
    choices: list[ChatCompletionResponseChoice]
# The ASGI application that every route and WebSocket handler below attaches to.
app: FastAPI = FastAPI()
class ConnectionManager:
    """Registry of open WebSocket connections and pending API replies.

    ``active_connections`` maps each accepted WebSocket to the ``source`` label
    it announced, or ``None`` until :meth:`set_source` is called.
    ``response_queues`` maps an in-flight ``request_id`` to the single-slot
    queue into which the reply for that request is delivered.
    """

    def __init__(self):
        self.active_connections = {}  # WebSocket -> source label (or None)
        self.response_queues = {}  # request_id -> asyncio.Queue(maxsize=1)

    async def connect(self, websocket: "WebSocket"):
        """Accept the WebSocket handshake and register the socket with no source yet."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: "WebSocket", source: str):
        """Label an already-registered socket with *source*; unknown sockets are ignored."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_destination(self, destination: str, message: str):
        """Send *message* to every registered socket whose source label is *destination*."""
        # Take a snapshot first. Each ``await`` below yields to the event loop,
        # where other handlers may add or remove sockets. Iterating the live
        # dict across those awaits would raise "dictionary changed size during
        # iteration".
        targets = [
            ws for ws, src in self.active_connections.items() if src == destination
        ]
        for ws in targets:
            # Skip sockets that were unregistered while earlier sends were in flight.
            if ws in self.active_connections:
                await ws.send_text(message)

    def remove(self, websocket: "WebSocket"):
        """Unregister *websocket*; a no-op if it is not registered."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: float = 30):
        """Wait up to *timeout* seconds for the reply to *request_id* and return it.

        The reply queue is registered only when this coroutine starts running.
        A reply delivered before that has no queue to land in. Callers that
        trigger the reply must therefore start this wait first.

        Raises:
            asyncio.TimeoutError: if no reply arrives within *timeout* seconds.
        """
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            # Always unregister so late replies are ignored and the dict does not grow.
            self.response_queues.pop(request_id, None)
# Process-wide registry shared by the HTTP route and the WebSocket handler.
manager: ConnectionManager = ConnectionManager()
@app.get("/")
async def get():
    """Serve the chat-client page at ``/``; its markup is just the text "Chat Client"."""
    page = "\nChat Client\n"
    return HTMLResponse(page)
@app.get("/proxy")
async def get_proxy():
    """Serve the proxy (message gateway) page at ``/proxy``; its markup is a single text line."""
    page = "\nProxy Client (Message Gateway)\n"
    return HTMLResponse(page)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
    request: ChatCompletionRequest,
    authorization: Annotated[Optional[str], Header()] = None,
):
    """Relay an OpenAI-style chat completion through the connected proxy client.

    The handler sends one JSON frame to the WebSocket whose source label is
    ``"proxy"``. The frame carries the most recent ``user`` message, the
    requested model and temperature, and the caller's bearer token. The handler
    then waits up to 30 seconds for a reply frame with the same ``request_id``.

    The bearer token is only checked for presence and ``Bearer `` shape here.
    It is forwarded to the proxy as ``incomingKey`` and is not verified in
    this file.

    Raises:
        HTTPException: 401 if the ``Authorization`` header is missing or not a
            bearer token; 503 if no proxy is connected; 400 if the request has
            no ``user`` message; 504 if the proxy does not reply in time.
    """
    from fastapi import status  # not imported at file level

    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid Authorization header",
        )
    # The token is a credential: it is deliberately never printed or logged.
    api_key = authorization[len("Bearer "):]

    request_id = str(uuid.uuid4())
    proxy_ws = next(
        (ws for ws, src in manager.active_connections.items() if src == "proxy"),
        None,
    )
    if proxy_ws is None:
        raise HTTPException(503, "Proxy client not connected")

    # The latest user turn is the prompt to answer; earlier turns are history.
    user_message = next(
        (m for m in reversed(request.messages) if m.role == "user"), None
    )
    if user_message is None:
        raise HTTPException(400, "No user message found")

    proxy_msg = {
        "request_id": request_id,
        "content": user_message.content,
        "source": "api",
        "destination": "proxy",
        "model": request.model,
        "temperature": request.temperature,
        "incomingKey": api_key,  # the proxy receives the caller's key
    }

    # Register the reply queue BEFORE sending. The WebSocket handler delivers
    # the reply in another coroutine, which can run as soon as send_text
    # yields. A reply that finds no registered queue is dropped and this call
    # would then time out.
    queue = asyncio.Queue(maxsize=1)
    manager.response_queues[request_id] = queue
    try:
        await proxy_ws.send_text(json.dumps(proxy_msg))
        try:
            response_content = await asyncio.wait_for(queue.get(), timeout=30)
        except asyncio.TimeoutError:
            raise HTTPException(504, "Proxy response timeout") from None
    finally:
        manager.response_queues.pop(request_id, None)

    return ChatCompletionResponse(
        id=request_id,
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionResponseChoice(
                message=ChatMessage(role="assistant", content=response_content)
            )
        ],
    )
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Relay JSON frames between WebSocket clients and pending API calls.

    Protocol as implemented here:

    * The first frame must be JSON. If it is an object with a ``"source"``
      key, the connection is labelled with that value.
    * Later frames are parsed as JSON objects; unparsable or non-object frames
      are skipped.
    * A frame with ``"destination": "api"`` and a ``"request_id"`` is a reply
      to ``/v1/chat/completions``. Its ``"content"`` is handed to the waiting
      request, if there is one.
    * Any other frame with a ``"destination"`` is forwarded verbatim to every
      connection whose source label equals that destination.

    The connection is always unregistered from ``manager`` on exit.
    """
    from fastapi import WebSocketDisconnect  # not imported at file level

    await manager.connect(websocket)
    try:
        init_msg = json.loads(await websocket.receive_text())
        if isinstance(init_msg, dict) and "source" in init_msg:
            manager.set_source(websocket, init_msg["source"])

        while True:
            message = await websocket.receive_text()
            try:
                msg_data = json.loads(message)
            except json.JSONDecodeError:
                continue  # one malformed frame should not drop the whole link
            if not isinstance(msg_data, dict):
                continue

            destination = msg_data.get("destination")
            if destination == "api" and "request_id" in msg_data:
                queue = manager.response_queues.get(msg_data["request_id"])
                if queue is not None:
                    try:
                        # Non-blocking put. A duplicate reply hitting the full
                        # single-slot queue would otherwise stall this loop.
                        queue.put_nowait(msg_data.get("content", ""))
                    except asyncio.QueueFull:
                        pass  # a reply for this request is already queued
            elif destination is not None:
                await manager.send_to_destination(destination, message)
    except WebSocketDisconnect:
        pass  # the peer closed the socket; there is nothing left to close
    except Exception as exc:
        print(f"WebSocket error: {exc!r}")
        try:
            await websocket.close()
        except Exception:
            pass  # best-effort: the transport may already be gone
    finally:
        manager.remove(websocket)
if __name__ == "__main__":
    # Script entry point: serve the app on all network interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")