# WebSocketChat / app.py
# Almaatla's picture
# Update app.py
# 28ebbb8 verified
from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
import uvicorn
import json
from fastapi import Request, HTTPException, Header, Depends
from pydantic import BaseModel
from typing import List, Optional
import uuid
import time
import asyncio
from typing import Optional
# typing.Annotated is available from Python 3.9 onwards
from typing import Annotated
# One chat turn in the OpenAI-style message format used by this API.
class ChatMessage(BaseModel):
    # Speaker of the turn; the handlers in this file use "user" and "assistant".
    role: str
    # Text of the turn.
    content: str
# Request body accepted by POST /v1/chat/completions.
class ChatCompletionRequest(BaseModel):
    # Model identifier; forwarded to the proxy and echoed in the response.
    model: str = "gemini-2.5-pro-exp-03-25"
    # Conversation turns supplied by the caller.
    messages: List[ChatMessage]
    # Sampling temperature, forwarded to the proxy alongside the prompt.
    temperature: Optional[float] = 0.7
    # Accepted for client compatibility; nothing in this file streams output.
    stream: Optional[bool] = False
# A single completion alternative inside a ChatCompletionResponse.
class ChatCompletionResponseChoice(BaseModel):
    # Position of this choice in the response's `choices` list.
    index: int = 0
    # The generated assistant message.
    message: ChatMessage
    # Why generation ended; defaults to "stop".
    finish_reason: str = "stop"
# Response body returned by POST /v1/chat/completions.
class ChatCompletionResponse(BaseModel):
    # Identifier of this completion (the handler supplies the request UUID).
    id: str
    # Object type tag.
    object: str = "chat.completion"
    # Creation time as integer seconds since the epoch.
    created: int
    # Model name echoed back from the request.
    model: str
    # Generated alternatives.
    choices: List[ChatCompletionResponseChoice]
# ASGI application object; the route decorators below register against it.
app = FastAPI()
class ConnectionManager:
    """Track live WebSocket peers and hand proxy replies back to HTTP requests.

    ``active_connections`` maps each accepted socket to the source name it
    announced (``None`` until it has identified itself).
    ``response_queues`` maps a pending request id to the queue its reply
    will be put on.
    """

    def __init__(self):
        self.active_connections = {}  # WebSocket -> source name (or None)
        self.response_queues = {}  # request_id -> asyncio.Queue

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register the socket as unidentified."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: WebSocket, source: str):
        """Record the source name announced by an already-registered socket."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_destination(self, destination: str, message: str):
        """Send ``message`` to every socket whose source equals ``destination``."""
        # Snapshot the recipients first: each send below is an await point,
        # and a peer connecting or disconnecting meanwhile would otherwise
        # mutate the dict mid-iteration and raise RuntimeError.
        recipients = [
            ws
            for ws, src in self.active_connections.items()
            if src == destination
        ]
        for ws in recipients:
            # Skip peers that were removed while an earlier send was pending.
            if ws in self.active_connections:
                await ws.send_text(message)

    def remove(self, websocket: WebSocket):
        """Forget a socket; unknown sockets are ignored."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: int = 30):
        """Wait for the reply to ``request_id`` and return it.

        Args:
            request_id: Key under which the reply queue is registered.
            timeout: Maximum number of seconds to wait.

        Returns:
            Whatever object was put on the request's queue.

        Raises:
            asyncio.TimeoutError: If no reply arrives within ``timeout``.
        """
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            # Always unregister so timed-out requests do not leak queues.
            self.response_queues.pop(request_id, None)
# Single module-level registry used by the HTTP and WebSocket handlers below.
manager = ConnectionManager()
@app.get("/")
async def get():
    """Serve the minimal browser chat client."""
    # The page picks ws:// or wss:// to match how it was loaded, so it also
    # works when the app is served over plain HTTP (e.g. the local uvicorn
    # run at the bottom of this file). Incoming text is rendered with
    # textContent so message content is never interpreted as HTML.
    return HTMLResponse("""
<html>
<body>
  <h1>Chat Client</h1>
  <div id="chat" style="height:300px;overflow-y:scroll"></div>
  <input id="msg" type="text">
  <button onclick="send()">Send</button>
  <script>
    const wsScheme = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
    const ws = new WebSocket(wsScheme + window.location.host + '/ws');
    ws.onopen = () => {
      ws.send(JSON.stringify({ source: 'user' }));
    };
    ws.onmessage = e => {
      const msg = JSON.parse(e.data);
      const line = document.createElement('div');
      line.textContent = msg.content;
      document.getElementById('chat').appendChild(line);
    };
    const send = () => {
      const message = {
        content: document.getElementById('msg').value,
        source: 'user',
        destination: 'proxy'
      };
      ws.send(JSON.stringify(message));
      document.getElementById('msg').value = '';
    };
  </script>
</body>
</html>
""")
@app.get("/proxy")
async def get_proxy():
    """Serve the proxy gateway page that relays prompts to the LLM backend."""
    # Changes relative to the previous page script:
    #  * the authentication-failure reply no longer echoes the expected key
    #    back to the (unauthenticated) sender;
    #  * showGenerationStatus(), called by rateLimitWait(), is now defined;
    #  * the duplicate showStatus() definition is gone;
    #  * untrusted text is rendered with textContent instead of innerHTML;
    #  * the WebSocket scheme follows the page protocol (ws:// or wss://);
    #  * a request arriving before "Connect" yields a clear error.
    # NOTE: this is a normal (non-raw) Python string, so the "\n" sequences
    # inside the JavaScript template literals reach the browser as real
    # newlines, which template literals accept.
    return HTMLResponse("""
<html>
<body>
  <h1>Proxy Client (Message Gateway)</h1>
  <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; height: 80vh;">
    <!-- Connection Panel -->
    <div style="border-right: 1px solid #ccc; padding-right: 20px;">
      <div style="margin-bottom: 20px;">
        <div style="margin-bottom: 20px;">
          <input type="password" id="apiKey" placeholder="LLM API Key" style="width: 100%; margin-bottom: 8px;">
          <input type="password" id="incomingKey" placeholder="Proxy Incoming Key" style="width: 100%;">
          <button onclick="initializeClient()" style="margin-top: 10px;">Connect</button>
        </div>
        <button onclick="initializeClient()" style="margin-top: 10px;">Fetch Models</button>
      </div>
      <select id="modelSelect" style="width: 100%; margin-bottom: 20px;"></select>
      <div id="systemStatus" style="color: #666; font-size: 0.9em;"></div>
    </div>
    <!-- Message Flow Visualization -->
    <div style="display: flex; flex-direction: column; height: 100%;">
      <div id="messageFlow" style="flex: 1; border: 1px solid #eee; padding: 10px; overflow-y: auto; background: #f9f9f9;">
        <div style="text-align: center; color: #999; margin-bottom: 10px;">Message Flow</div>
      </div>
      <div id="detailedStatus" style="color: #666; font-size: 0.9em; margin-top: 10px;"></div>
    </div>
  </div>
  <style>
    .message-entry {
      margin: 5px 0;
      padding: 8px;
      border-radius: 8px;
      background: white;
      box-shadow: 0 2px 4px rgba(0,0,0,0.1);
      font-family: monospace;
    }
    .incoming { border-left: 4px solid #4CAF50; }
    .outgoing { border-left: 4px solid #2196F3; }
    .system { border-left: 4px solid #9C27B0; }
    .error { border-left: 4px solid #F44336; }
    .message-header {
      display: flex;
      justify-content: space-between;
      font-size: 0.8em;
      color: #666;
      margin-bottom: 4px;
    }
  </style>
  <script>
    let agentClient = null;
    let currentModel = null;
    const systemPrompt = "You are a helpful AI assistant. Respond concisely and accurately.";
    const conversationHistory = [];

    // Show a status line in the connection panel and mirror it in the message flow.
    function showStatus(message, type = 'info') {
      const statusDiv = document.getElementById('systemStatus');
      const line = document.createElement('div');
      line.style.color = type === 'error' ? '#F44336' : '#4CAF50';
      line.textContent = message;
      statusDiv.innerHTML = '';
      statusDiv.appendChild(line);
      addMessageEntry('system', 'system', 'proxy', message);
    }

    // Progress line used by the API client while it waits on the rate limit.
    function showGenerationStatus(message, type = 'info') {
      const detail = document.getElementById('detailedStatus');
      detail.style.color = type === 'warn' ? '#FF9800' : '#666';
      detail.textContent = message;
    }

    function initializeClient() {
      const apiKey = document.getElementById('apiKey').value;
      if (!apiKey) {
        showStatus("Please enter an API key", 'error');
        return;
      }
      agentClient = new ConversationalAgentClient(apiKey);
      agentClient.populateLLMModels()
        .then(models => {
          agentClient.updateModelSelect('modelSelect', models.find(m => m.includes("gemini-2.5")));
          currentModel = document.getElementById('modelSelect').value;
          showStatus(`Loaded ${models.length} models. Default: ${currentModel}`);
        })
        .catch(error => {
          showStatus(`Error fetching models: ${error.message}`, 'error');
        });
    }

    // Model selection change handler
    document.getElementById('modelSelect').addEventListener('change', function() {
      currentModel = this.value;
      showStatus(`Model changed to: ${currentModel}`);
    });

    // --- API Client Classes ---
    class BaseAgentClient {
      constructor(apiKey, apiUrl = 'https://llm.synapse.thalescloud.io/v1/') { this.apiKey = apiKey; this.apiUrl = apiUrl; this.models = []; this.maxCallsPerMinute = 4; this.callTimestamps = []; }
      async fetchLLMModels() { if (!this.apiKey) throw new Error("API Key is not set."); console.log("Fetching models from:", this.apiUrl + 'models'); try { const response = await fetch(this.apiUrl + 'models', { method: 'GET', headers: { 'Authorization': `Bearer ${this.apiKey}` } }); if (!response.ok) { const errorText = await response.text(); console.error("Fetch models error response:", errorText); throw new Error(`HTTP error! Status: ${response.status} - ${errorText}`); } const data = await response.json(); console.log("Models fetched:", data.data); const filteredModels = data.data.map(model => model.id).filter(id => !id.toLowerCase().includes('embed') && !id.toLowerCase().includes('image')); return filteredModels; } catch (error) { console.error('Error fetching LLM models:', error); throw new Error(`Failed to fetch models: ${error.message}`); } }
      async populateLLMModels(defaultModel = "gemini-2.5-pro-exp-03-25") { try { const modelList = await this.fetchLLMModels(); const sortedModels = modelList.sort((a, b) => { if (a === defaultModel) return -1; if (b === defaultModel) return 1; return a.localeCompare(b); }); const finalModels = []; if (sortedModels.includes(defaultModel)) { finalModels.push(defaultModel); sortedModels.forEach(model => { if (model !== defaultModel) finalModels.push(model); }); } else { finalModels.push(defaultModel); finalModels.push(...sortedModels); } this.models = finalModels; console.log("Populated models:", this.models); return this.models; } catch (error) { console.error("Error populating models:", error); this.models = [defaultModel]; throw error; } }
      updateModelSelect(elementId = 'modelSelect', selectedModel = null) { const select = document.getElementById(elementId); if (!select) { console.warn(`Element ID ${elementId} not found.`); return; } const currentSelection = selectedModel || select.value || this.models[0]; select.innerHTML = ''; if (this.models.length === 0 || (this.models.length === 1 && this.models[0] === "gemini-2.5-pro-exp-03-25" && !this.apiKey)) { const option = document.createElement('option'); option.value = ""; option.textContent = "-- Fetch models first --"; option.disabled = true; select.appendChild(option); return; } this.models.forEach(model => { const option = document.createElement('option'); option.value = model; option.textContent = model; if (model === currentSelection) option.selected = true; select.appendChild(option); }); if (!select.value && this.models.length > 0) select.value = this.models[0]; }
      async rateLimitWait() { const currentTime = Date.now(); this.callTimestamps = this.callTimestamps.filter(ts => currentTime - ts <= 60000); if (this.callTimestamps.length >= this.maxCallsPerMinute) { const waitTime = 60000 - (currentTime - this.callTimestamps[0]); const waitSeconds = Math.ceil(waitTime / 1000); const waitMessage = `Rate limit (${this.maxCallsPerMinute}/min) reached. Waiting ${waitSeconds}s...`; console.log(waitMessage); showGenerationStatus(waitMessage, 'warn'); await new Promise(resolve => setTimeout(resolve, waitTime + 100)); showGenerationStatus('Resuming after rate limit wait...', 'info'); this.callTimestamps = this.callTimestamps.filter(ts => Date.now() - ts <= 60000); } }
      async callAgent(model, messages, temperature = 0.7) { await this.rateLimitWait(); const startTime = Date.now(); console.log("Calling Agent:", model); try { const response = await fetch(this.apiUrl + 'chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer ${this.apiKey}` }, body: JSON.stringify({ model: model, messages: messages, temperature: temperature }) }); const endTime = Date.now(); this.callTimestamps.push(endTime); console.log(`API call took ${endTime - startTime} ms`); if (!response.ok) { const errorData = await response.json().catch(() => ({ error: { message: response.statusText } })); console.error("API Error:", errorData); throw new Error(errorData.error?.message || `API failed: ${response.status}`); } const data = await response.json(); if (!data.choices || !data.choices[0]?.message) throw new Error("Invalid API response structure"); console.log("API Response received."); return data.choices[0].message.content; } catch (error) { this.callTimestamps.push(Date.now()); console.error('Error calling agent:', error); throw error; } }
      setMaxCallsPerMinute(value) { const parsedValue = parseInt(value, 10); if (!isNaN(parsedValue) && parsedValue > 0) { console.log(`Max calls/min set to: ${parsedValue}`); this.maxCallsPerMinute = parsedValue; return true; } console.warn(`Invalid max calls/min: ${value}`); return false; }
    }

    class ConversationalAgentClient extends BaseAgentClient {
      constructor(apiKey, apiUrl = 'https://llm.synapse.thalescloud.io/v1/') {
        super(apiKey, apiUrl);
      }
      async call(model, userPrompt, systemPrompt, conversationHistory = [], temperature = 0.7) {
        const messages = [
          { role: 'system', content: systemPrompt },
          ...conversationHistory,
          { role: 'user', content: userPrompt }
        ];
        const assistantResponse = await super.callAgent(model, messages, temperature);
        const updatedHistory = [
          ...conversationHistory,
          { role: 'user', content: userPrompt },
          { role: 'assistant', content: assistantResponse }
        ];
        return {
          response: assistantResponse,
          history: updatedHistory
        };
      }
      async callWithCodeContext(
        model,
        userPrompt,
        systemPrompt,
        selectedCodeVersionsData = [],
        conversationHistory = [],
        temperature = 0.7
      ) {
        let codeContext = "";
        let fullSystemPrompt = systemPrompt || "";
        if (selectedCodeVersionsData && selectedCodeVersionsData.length > 0) {
          codeContext = `Code context (chronological):\n\n`;
          selectedCodeVersionsData.forEach((versionData, index) => {
            if (versionData && typeof versionData.code === 'string') {
              codeContext += `--- Part ${index + 1} (${versionData.version || '?'}) ---\n`;
              codeContext += `${versionData.code}\n\n`;
            } else {
              console.warn(`Invalid context version data at index ${index}`);
            }
          });
          codeContext += `-------- end context ---\n\nUser request based on context:\n\n`;
        }
        const fullPrompt = codeContext + userPrompt;
        const messages = [
          { role: 'system', content: fullSystemPrompt },
          ...conversationHistory,
          { role: 'user', content: fullPrompt }
        ];
        const assistantResponse = await super.callAgent(model, messages, temperature);
        const updatedHistory = [
          ...conversationHistory,
          { role: 'user', content: fullPrompt },
          { role: 'assistant', content: assistantResponse }
        ];
        return {
          response: assistantResponse,
          history: updatedHistory
        };
      }
    }

    // Append one entry to the message-flow panel. Everything is inserted as
    // text nodes so message content cannot inject markup.
    function addMessageEntry(direction, source, destination, content) {
      const flowDiv = document.getElementById('messageFlow');
      const timestamp = new Date().toLocaleTimeString();
      const entry = document.createElement('div');
      entry.className = `message-entry ${direction}`;

      const header = document.createElement('div');
      header.className = 'message-header';
      const route = document.createElement('span');
      route.textContent = `${source} → ${destination}`;
      const when = document.createElement('span');
      when.textContent = timestamp;
      header.appendChild(route);
      header.appendChild(when);

      const body = document.createElement('div');
      body.style.whiteSpace = 'pre-wrap';
      body.textContent = content;

      entry.appendChild(header);
      entry.appendChild(body);
      flowDiv.appendChild(entry);
      flowDiv.scrollTop = flowDiv.scrollHeight;
    }

    // WebSocket handler: match the scheme to the page so plain-HTTP serving works too.
    const wsScheme = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
    const ws = new WebSocket(wsScheme + window.location.host + '/ws');
    ws.onopen = () => {
      ws.send(JSON.stringify({ source: 'proxy' }));
    };
    ws.onmessage = async e => {
      const msg = JSON.parse(e.data);
      // Display incoming messages
      if (msg.destination === 'proxy') {
        addMessageEntry('incoming', msg.source, 'proxy', msg.content);
        document.getElementById('detailedStatus').textContent = `Processing ${msg.source} request...`;
        // Check that the incoming call carries the configured key. The reply
        // must never contain the expected key: the sender is unauthenticated.
        const expectedKey = document.getElementById('incomingKey').value;
        if (!msg.incomingKey || msg.incomingKey !== expectedKey) {
          ws.send(JSON.stringify({
            request_id: msg.request_id,
            content: "Error: Invalid authentication",
            source: 'proxy',
            destination: msg.source
          }));
          return;
        }
        showStatus("Authentication ok");
        try {
          if (!agentClient) {
            throw new Error("Proxy is not connected to an LLM backend yet");
          }
          const llmResponse = await agentClient.call(currentModel, msg.content, systemPrompt, conversationHistory);
          // Display outgoing response
          addMessageEntry('outgoing', 'proxy', msg.source, llmResponse.response);
          const responseMsg = {
            request_id: msg.request_id, // lets the server match the reply to the waiting request
            content: llmResponse.response,
            source: 'proxy',
            destination: msg.source
          };
          ws.send(JSON.stringify(responseMsg));
          document.getElementById('detailedStatus').textContent = `Response sent to ${msg.source}`;
        } catch (error) {
          addMessageEntry('error', 'system', 'proxy', `Error: ${error.message}`);
          const errorResponse = {
            request_id: msg.request_id, // lets the server match the reply to the waiting request
            content: `Error: ${error.message}`,
            source: 'proxy',
            destination: msg.source
          };
          ws.send(JSON.stringify(errorResponse));
        }
      }
    };
  </script>
</body>
</html>
""")
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
    request: ChatCompletionRequest,
    authorization: Annotated[Optional[str], Header()] = None,
):
    """Relay a chat completion request to the connected proxy page."""
    # Error statuses raised below:
    #   401 - Authorization header missing or not a Bearer token
    #   503 - no WebSocket peer has identified itself as "proxy"
    #   400 - the request contains no message with role "user"
    #   504 - the proxy did not reply before the wait timed out
    bearer_prefix = "Bearer "
    if not authorization or not authorization.startswith(bearer_prefix):
        # A literal status code is used (as for the other errors here)
        # because `status` is not among this file's imports.
        raise HTTPException(
            status_code=401,
            detail="Missing or invalid Authorization header",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # The token is a credential: forward it to the proxy but never log it.
    api_key = authorization[len(bearer_prefix):]

    request_id = str(uuid.uuid4())

    proxy_ws = next(
        (ws for ws, src in manager.active_connections.items() if src == 'proxy'),
        None,
    )
    if not proxy_ws:
        raise HTTPException(503, "Proxy client not connected")

    # Clients send the whole conversation; the prompt to answer is the most
    # recent user turn, so search from the end of the list.
    user_message = next(
        (m for m in reversed(request.messages) if m.role == "user"),
        None,
    )
    if not user_message:
        raise HTTPException(400, "No user message found")

    # The proxy page validates "incomingKey" before calling the LLM.
    proxy_msg = {
        "request_id": request_id,
        "content": user_message.content,
        "source": "api",
        "destination": "proxy",
        "model": request.model,
        "temperature": request.temperature,
        "incomingKey": api_key,
    }
    await proxy_ws.send_text(json.dumps(proxy_msg))

    try:
        response_content = await manager.wait_for_response(request_id)
    except asyncio.TimeoutError:
        raise HTTPException(504, "Proxy response timeout")

    return ChatCompletionResponse(
        id=request_id,
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionResponseChoice(
                message=ChatMessage(role="assistant", content=response_content)
            )
        ],
    )
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register a WebSocket peer and route its messages until it goes away.

    The first frame is expected to be JSON carrying the peer's ``source``
    name. Every later frame is JSON that is either a reply to a pending API
    request (``request_id`` present and ``destination`` equal to ``"api"``)
    or a message forwarded verbatim to all peers registered under its
    ``destination``.
    """
    # Imported here because the file-level imports do not include it.
    from fastapi import WebSocketDisconnect

    await manager.connect(websocket)
    try:
        # Handle initial source identification
        init_msg = json.loads(await websocket.receive_text())
        if 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])

        # Handle messages
        while True:
            message = await websocket.receive_text()
            msg_data = json.loads(message)
            # If this is a response to an API request
            if 'request_id' in msg_data and msg_data.get('destination') == 'api':
                queue = manager.response_queues.get(msg_data['request_id'])
                # A missing queue means the request already timed out.
                if queue:
                    await queue.put(msg_data['content'])
            else:
                await manager.send_to_destination(msg_data['destination'], message)
    except WebSocketDisconnect:
        # The peer closed the connection; there is nothing left to close.
        pass
    except Exception as exc:
        # Malformed frame or send failure: report it and drop the peer.
        print(f"WebSocket handler error: {exc!r}")
        try:
            await websocket.close()
        except RuntimeError:
            # The connection was already closed underneath us.
            pass
    finally:
        manager.remove(websocket)
if __name__ == "__main__":
    # Direct execution: serve the app on every interface, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)