# Hugging Face Spaces status banner captured during extraction (not code):
# Spaces: Sleeping
from fastapi import FastAPI, WebSocket | |
from fastapi.responses import HTMLResponse | |
import uvicorn | |
import json | |
from fastapi import Request, HTTPException, Header, Depends | |
from pydantic import BaseModel | |
from typing import List, Optional | |
import uuid | |
import time | |
import asyncio | |
from typing import Optional | |
# Modern Python (3.10+) with Annotated | |
from typing import Annotated | |
class ChatMessage(BaseModel):
    """A single chat turn in OpenAI chat-completions format."""
    role: str  # e.g. "system", "user", or "assistant"
    content: str  # message text
class ChatCompletionRequest(BaseModel):
    """Request body for the OpenAI-compatible /v1/chat/completions endpoint."""
    model: str = "gemini-2.5-pro-exp-03-25"  # default upstream model id
    messages: List[ChatMessage]  # full conversation, oldest first
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False  # NOTE(review): streaming is accepted but never implemented server-side
class ChatCompletionResponseChoice(BaseModel):
    """One completion choice; this server always returns exactly one."""
    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"  # always "stop" — no truncation handling here
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response envelope."""
    id: str  # echoes the internal request_id (a UUID4 string)
    object: str = "chat.completion"
    created: int  # Unix timestamp (seconds)
    model: str  # model id echoed back from the request
    choices: List[ChatCompletionResponseChoice]
# FastAPI application instance; route handlers below are registered on it.
app = FastAPI()
class ConnectionManager:
    """Registry of live WebSocket clients plus per-request response queues.

    Each connection maps to a declared source label ('user', 'proxy', ...).
    HTTP handlers park on a one-slot queue in `response_queues` until the
    proxy posts the answer for their request_id.
    """

    def __init__(self):
        # websocket object -> source label (None until the client identifies itself)
        self.active_connections = {}
        # request_id -> asyncio.Queue(maxsize=1) carrying the response payload
        self.response_queues = {}

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register the socket with no label yet."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: WebSocket, source: str):
        """Attach a source label to an already-registered socket; unknown sockets are ignored."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_destination(self, destination: str, message: str):
        """Deliver *message* to every socket whose label equals *destination*."""
        targets = [sock for sock, label in self.active_connections.items() if label == destination]
        for sock in targets:
            await sock.send_text(message)

    def remove(self, websocket: WebSocket):
        """Drop a socket from the registry; a no-op if it was never registered."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: int = 30):
        """Wait up to *timeout* seconds for the proxy's answer to *request_id*.

        Raises asyncio.TimeoutError on expiry; the queue entry is always
        cleaned up, success or failure.
        """
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            self.response_queues.pop(request_id, None)
# Singleton manager shared by the HTTP and WebSocket handlers below.
manager = ConnectionManager()
@app.get("/")
async def get():
    """Serve the minimal browser chat-client page.

    The page opens a WebSocket to /ws, identifies itself as source
    'user', and sends typed messages addressed to the 'proxy' client.
    Fixes vs. original: the route decorator was missing (the handler was
    never registered), and the WebSocket URL hard-coded wss://, which
    fails when the page is served over plain http.
    NOTE(review): this page sends no `incomingKey`, so the proxy page's
    auth check rejects its messages — confirm whether that is intended.
    """
    return HTMLResponse("""
    <html>
        <body>
            <h1>Chat Client</h1>
            <div id="chat" style="height:300px;overflow-y:scroll"></div>
            <input id="msg" type="text">
            <button onclick="send()">Send</button>
            <script>
                // Match ws/wss to the page scheme.
                const proto = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
                const ws = new WebSocket(proto + window.location.host + '/ws');
                ws.onopen = () => {
                    ws.send(JSON.stringify({ source: 'user' }));
                };
                ws.onmessage = e => {
                    const msg = JSON.parse(e.data);
                    document.getElementById('chat').innerHTML +=
                        `<div>${msg.content}</div>`;
                };
                const send = () => {
                    const message = {
                        content: document.getElementById('msg').value,
                        source: 'user',
                        destination: 'proxy'
                    };
                    ws.send(JSON.stringify(message));
                    document.getElementById('msg').value = '';
                };
            </script>
        </body>
    </html>
    """)
@app.get("/proxy")
async def get_proxy():
    """Serve the proxy/gateway page.

    The page connects to /ws as source 'proxy', authenticates incoming
    requests against the operator-supplied incoming key, forwards the
    message content to the upstream LLM API, and relays the response
    (tagged with the originating request_id) back to the source.

    Fixes vs. original: the route decorator was missing; the auth-failure
    reply leaked both the expected and received keys to the caller; the
    JS called an undefined `showGenerationStatus` during rate-limit waits;
    `showStatus` was defined twice; the WebSocket URL hard-coded wss://.
    """
    return HTMLResponse("""
    <html>
    <body>
        <h1>Proxy Client (Message Gateway)</h1>
        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; height: 80vh;">
            <!-- Connection Panel -->
            <div style="border-right: 1px solid #ccc; padding-right: 20px;">
                <div style="margin-bottom: 20px;">
                    <div style="margin-bottom: 20px;">
                        <input type="password" id="apiKey" placeholder="LLM API Key" style="width: 100%; margin-bottom: 8px;">
                        <input type="password" id="incomingKey" placeholder="Proxy Incoming Key" style="width: 100%;">
                        <button onclick="initializeClient()" style="margin-top: 10px;">Connect</button>
                    </div>
                    <button onclick="initializeClient()" style="margin-top: 10px;">Fetch Models</button>
                </div>
                <select id="modelSelect" style="width: 100%; margin-bottom: 20px;"></select>
                <div id="systemStatus" style="color: #666; font-size: 0.9em;"></div>
            </div>
            <!-- Message Flow Visualization -->
            <div style="display: flex; flex-direction: column; height: 100%;">
                <div id="messageFlow" style="flex: 1; border: 1px solid #eee; padding: 10px; overflow-y: auto; background: #f9f9f9;">
                    <div style="text-align: center; color: #999; margin-bottom: 10px;">Message Flow</div>
                </div>
                <div id="detailedStatus" style="color: #666; font-size: 0.9em; margin-top: 10px;"></div>
            </div>
        </div>
        <style>
            .message-entry {
                margin: 5px 0;
                padding: 8px;
                border-radius: 8px;
                background: white;
                box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                font-family: monospace;
            }
            .incoming { border-left: 4px solid #4CAF50; }
            .outgoing { border-left: 4px solid #2196F3; }
            .system { border-left: 4px solid #9C27B0; }
            .error { border-left: 4px solid #F44336; }
            .message-header {
                display: flex;
                justify-content: space-between;
                font-size: 0.8em;
                color: #666;
                margin-bottom: 4px;
            }
        </style>
        <script>
            let agentClient = null;
            let currentModel = null;
            const systemPrompt = "You are a helpful AI assistant. Respond concisely and accurately.";
            const conversationHistory = [];

            // Status banner + message-flow log entry (single definition; the
            // original declared this function twice).
            function showStatus(message, type = 'info') {
                const statusDiv = document.getElementById('systemStatus');
                statusDiv.innerHTML = `<div style="color: ${type === 'error' ? '#F44336' : '#4CAF50'}">${message}</div>`;
                addMessageEntry('system', 'system', 'proxy', message);
            }

            function initializeClient() {
                const apiKey = document.getElementById('apiKey').value;
                if (!apiKey) {
                    showStatus("Please enter an API key", 'error');
                    return;
                }
                agentClient = new ConversationalAgentClient(apiKey);
                agentClient.populateLLMModels()
                    .then(models => {
                        agentClient.updateModelSelect('modelSelect', models.find(m => m.includes("gemini-2.5")));
                        currentModel = document.getElementById('modelSelect').value;
                        showStatus(`Loaded ${models.length} models. Default: ${currentModel}`);
                    })
                    .catch(error => {
                        showStatus(`Error fetching models: ${error.message}`, 'error');
                    });
            }

            // Model selection change handler
            document.getElementById('modelSelect').addEventListener('change', function() {
                currentModel = this.value;
                showStatus(`Model changed to: ${currentModel}`);
            });

            // --- API Client Classes ---
            class BaseAgentClient {
                constructor(apiKey, apiUrl = 'https://llm.synapse.thalescloud.io/v1/') { this.apiKey = apiKey; this.apiUrl = apiUrl; this.models = []; this.maxCallsPerMinute = 4; this.callTimestamps = []; }
                async fetchLLMModels() { if (!this.apiKey) throw new Error("API Key is not set."); console.log("Fetching models from:", this.apiUrl + 'models'); try { const response = await fetch(this.apiUrl + 'models', { method: 'GET', headers: { 'Authorization': `Bearer ${this.apiKey}` } }); if (!response.ok) { const errorText = await response.text(); console.error("Fetch models error response:", errorText); throw new Error(`HTTP error! Status: ${response.status} - ${errorText}`); } const data = await response.json(); console.log("Models fetched:", data.data); const filteredModels = data.data.map(model => model.id).filter(id => !id.toLowerCase().includes('embed') && !id.toLowerCase().includes('image')); return filteredModels; } catch (error) { console.error('Error fetching LLM models:', error); throw new Error(`Failed to fetch models: ${error.message}`); } }
                async populateLLMModels(defaultModel = "gemini-2.5-pro-exp-03-25") { try { const modelList = await this.fetchLLMModels(); const sortedModels = modelList.sort((a, b) => { if (a === defaultModel) return -1; if (b === defaultModel) return 1; return a.localeCompare(b); }); const finalModels = []; if (sortedModels.includes(defaultModel)) { finalModels.push(defaultModel); sortedModels.forEach(model => { if (model !== defaultModel) finalModels.push(model); }); } else { finalModels.push(defaultModel); finalModels.push(...sortedModels); } this.models = finalModels; console.log("Populated models:", this.models); return this.models; } catch (error) { console.error("Error populating models:", error); this.models = [defaultModel]; throw error; } }
                updateModelSelect(elementId = 'modelSelect', selectedModel = null) { const select = document.getElementById(elementId); if (!select) { console.warn(`Element ID ${elementId} not found.`); return; } const currentSelection = selectedModel || select.value || this.models[0]; select.innerHTML = ''; if (this.models.length === 0 || (this.models.length === 1 && this.models[0] === "gemini-2.5-pro-exp-03-25" && !this.apiKey)) { const option = document.createElement('option'); option.value = ""; option.textContent = "-- Fetch models first --"; option.disabled = true; select.appendChild(option); return; } this.models.forEach(model => { const option = document.createElement('option'); option.value = model; option.textContent = model; if (model === currentSelection) option.selected = true; select.appendChild(option); }); if (!select.value && this.models.length > 0) select.value = this.models[0]; }
                // NOTE: original called an undefined showGenerationStatus() here; use showStatus().
                async rateLimitWait() { const currentTime = Date.now(); this.callTimestamps = this.callTimestamps.filter(ts => currentTime - ts <= 60000); if (this.callTimestamps.length >= this.maxCallsPerMinute) { const waitTime = 60000 - (currentTime - this.callTimestamps[0]); const waitSeconds = Math.ceil(waitTime / 1000); const waitMessage = `Rate limit (${this.maxCallsPerMinute}/min) reached. Waiting ${waitSeconds}s...`; console.log(waitMessage); showStatus(waitMessage, 'warn'); await new Promise(resolve => setTimeout(resolve, waitTime + 100)); showStatus('Resuming after rate limit wait...', 'info'); this.callTimestamps = this.callTimestamps.filter(ts => Date.now() - ts <= 60000); } }
                async callAgent(model, messages, temperature = 0.7) { await this.rateLimitWait(); const startTime = Date.now(); console.log("Calling Agent:", model); try { const response = await fetch(this.apiUrl + 'chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer ${this.apiKey}` }, body: JSON.stringify({ model: model, messages: messages, temperature: temperature }) }); const endTime = Date.now(); this.callTimestamps.push(endTime); console.log(`API call took ${endTime - startTime} ms`); if (!response.ok) { const errorData = await response.json().catch(() => ({ error: { message: response.statusText } })); console.error("API Error:", errorData); throw new Error(errorData.error?.message || `API failed: ${response.status}`); } const data = await response.json(); if (!data.choices || !data.choices[0]?.message) throw new Error("Invalid API response structure"); console.log("API Response received."); return data.choices[0].message.content; } catch (error) { this.callTimestamps.push(Date.now()); console.error('Error calling agent:', error); throw error; } }
                setMaxCallsPerMinute(value) { const parsedValue = parseInt(value, 10); if (!isNaN(parsedValue) && parsedValue > 0) { console.log(`Max calls/min set to: ${parsedValue}`); this.maxCallsPerMinute = parsedValue; return true; } console.warn(`Invalid max calls/min: ${value}`); return false; }
            }

            class ConversationalAgentClient extends BaseAgentClient {
                constructor(apiKey, apiUrl = 'https://llm.synapse.thalescloud.io/v1/') {
                    super(apiKey, apiUrl);
                }
                // Plain conversational call: system prompt + history + new user turn.
                async call(model, userPrompt, systemPrompt, conversationHistory = [], temperature = 0.7) {
                    const messages = [
                        { role: 'system', content: systemPrompt },
                        ...conversationHistory,
                        { role: 'user', content: userPrompt }
                    ];
                    const assistantResponse = await super.callAgent(model, messages, temperature);
                    const updatedHistory = [
                        ...conversationHistory,
                        { role: 'user', content: userPrompt },
                        { role: 'assistant', content: assistantResponse }
                    ];
                    return {
                        response: assistantResponse,
                        history: updatedHistory
                    };
                }
                // Variant that prepends selected code versions as context ahead of the prompt.
                async callWithCodeContext(
                    model,
                    userPrompt,
                    systemPrompt,
                    selectedCodeVersionsData = [],
                    conversationHistory = [],
                    temperature = 0.7
                ) {
                    let codeContext = "";
                    let fullSystemPrompt = systemPrompt || "";
                    if (selectedCodeVersionsData && selectedCodeVersionsData.length > 0) {
                        codeContext = `Code context (chronological):\\n\\n`;
                        selectedCodeVersionsData.forEach((versionData, index) => {
                            if (versionData && typeof versionData.code === 'string') {
                                codeContext += `--- Part ${index + 1} (${versionData.version || '?'}) ---\\n`;
                                codeContext += `${versionData.code}\\n\\n`;
                            } else {
                                console.warn(`Invalid context version data at index ${index}`);
                            }
                        });
                        codeContext += `-------- end context ---\\n\\nUser request based on context:\\n\\n`;
                    }
                    const fullPrompt = codeContext + userPrompt;
                    const messages = [
                        { role: 'system', content: fullSystemPrompt },
                        ...conversationHistory,
                        { role: 'user', content: fullPrompt }
                    ];
                    const assistantResponse = await super.callAgent(model, messages, temperature);
                    const updatedHistory = [
                        ...conversationHistory,
                        { role: 'user', content: fullPrompt },
                        { role: 'assistant', content: assistantResponse }
                    ];
                    return {
                        response: assistantResponse,
                        history: updatedHistory
                    };
                }
            }

            function addMessageEntry(direction, source, destination, content) {
                const flowDiv = document.getElementById('messageFlow');
                const timestamp = new Date().toLocaleTimeString();
                const entry = document.createElement('div');
                entry.className = `message-entry ${direction}`;
                entry.innerHTML = `
                    <div class="message-header">
                        <span>${source} → ${destination}</span>
                        <span>${timestamp}</span>
                    </div>
                    <div style="white-space: pre-wrap;">${content}</div>
                `;
                flowDiv.appendChild(entry);
                flowDiv.scrollTop = flowDiv.scrollHeight;
            }

            // WebSocket handler: match ws/wss to the page scheme.
            const proto = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
            const ws = new WebSocket(proto + window.location.host + '/ws');
            ws.onopen = () => {
                ws.send(JSON.stringify({ source: 'proxy' }));
            };
            ws.onmessage = async e => {
                const msg = JSON.parse(e.data);
                // Only act on messages addressed to this proxy.
                if (msg.destination === 'proxy') {
                    addMessageEntry('incoming', msg.source, 'proxy', msg.content);
                    document.getElementById('detailedStatus').textContent = `Processing ${msg.source} request...`;
                    // Authenticate: the caller must present the operator's incoming key.
                    // SECURITY: never echo the keys back (the original leaked both).
                    const expectedKey = document.getElementById('incomingKey').value;
                    if (!msg.incomingKey || msg.incomingKey !== expectedKey) {
                        ws.send(JSON.stringify({
                            request_id: msg.request_id,
                            content: "Error: Invalid authentication key",
                            source: 'proxy',
                            destination: msg.source
                        }));
                        return;
                    }
                    showStatus("Authentication ok");
                    try {
                        const llmResponse = await agentClient.call(currentModel, msg.content, systemPrompt, conversationHistory);
                        // Display outgoing response
                        addMessageEntry('outgoing', 'proxy', msg.source, llmResponse.response);
                        const responseMsg = {
                            request_id: msg.request_id, // lets the server route this back to the waiting HTTP request
                            content: llmResponse.response,
                            source: 'proxy',
                            destination: msg.source
                        };
                        ws.send(JSON.stringify(responseMsg));
                        document.getElementById('detailedStatus').textContent = `Response sent to ${msg.source}`;
                    } catch (error) {
                        addMessageEntry('error', 'system', 'proxy', `Error: ${error.message}`);
                        const errorResponse = {
                            request_id: msg.request_id,
                            content: `Error: ${error.message}`,
                            source: 'proxy',
                            destination: msg.source
                        };
                        ws.send(JSON.stringify(errorResponse));
                    }
                }
            };
        </script>
    </body>
    </html>
    """)
@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    authorization: Annotated[Optional[str], Header()] = None
):
    """OpenAI-compatible endpoint that relays a chat request to the proxy client.

    The bearer token is forwarded to the browser-side proxy as `incomingKey`;
    the proxy validates it and answers over the WebSocket, keyed by request_id.

    Raises:
        HTTPException 401: missing/malformed Authorization header.
        HTTPException 503: no 'proxy' WebSocket client connected.
        HTTPException 400: request contains no user message.
        HTTPException 504: proxy did not answer within the timeout.
    """
    if not authorization or not authorization.startswith("Bearer "):
        # Original referenced fastapi's `status` without importing it
        # (NameError on this path); 401 is the same code.
        raise HTTPException(
            status_code=401,
            detail="Missing or invalid Authorization header"
        )
    api_key = authorization[7:]  # strip "Bearer " prefix
    # SECURITY: do not log the raw key (the original printed it).
    request_id = str(uuid.uuid4())
    proxy_ws = next((ws for ws, src in manager.active_connections.items() if src == 'proxy'), None)
    if not proxy_ws:
        raise HTTPException(503, "Proxy client not connected")
    # Use the most recent user message: the original took the FIRST, which
    # replays the oldest turn in multi-turn conversations.
    user_message = next((m for m in reversed(request.messages) if m.role == "user"), None)
    if not user_message:
        raise HTTPException(400, "No user message found")
    proxy_msg = {
        "request_id": request_id,
        "content": user_message.content,
        "source": "api",
        "destination": "proxy",
        "model": request.model,
        "temperature": request.temperature,
        "incomingKey": api_key  # proxy-side auth credential
    }
    await proxy_ws.send_text(json.dumps(proxy_msg))
    try:
        response_content = await manager.wait_for_response(request_id)
    except asyncio.TimeoutError:
        raise HTTPException(504, "Proxy response timeout")
    return ChatCompletionResponse(
        id=request_id,
        created=int(time.time()),
        model=request.model,
        choices=[ChatCompletionResponseChoice(
            message=ChatMessage(role="assistant", content=response_content)
        )]
    )
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket hub: clients identify a source label, then exchange messages.

    Protocol: the first frame is JSON `{"source": "<label>"}`. Subsequent
    frames carrying a `request_id` and destination 'api' resolve a pending
    HTTP request's queue; everything else is forwarded to the sockets
    registered under the message's `destination` label.
    Fixes vs. original: the route decorator was missing, and the cleanup
    path could raise from `websocket.close()` on an already-closed socket.
    """
    await manager.connect(websocket)
    try:
        # First frame: source identification.
        data = await websocket.receive_text()
        init_msg = json.loads(data)
        if 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])
        # Relay loop: runs until the client disconnects (raises).
        while True:
            message = await websocket.receive_text()
            msg_data = json.loads(message)
            if 'request_id' in msg_data and msg_data.get('destination') == 'api':
                # Response for a pending HTTP request: wake its waiter.
                queue = manager.response_queues.get(msg_data['request_id'])
                if queue:
                    await queue.put(msg_data['content'])
            else:
                await manager.send_to_destination(msg_data['destination'], message)
    except Exception:
        # Disconnect or malformed frame: unregister, then best-effort close
        # (the socket is usually already closed on a normal disconnect).
        manager.remove(websocket)
        try:
            await websocket.close()
        except Exception:
            pass
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)