WebSocketChat / app.py
Almaatla's picture
Update app.py
c6bbb27 verified
raw
history blame
21.3 kB
import asyncio
import contextlib
import json
import time
import uuid
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, WebSocket
from fastapi import Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.websockets import WebSocketDisconnect, WebSocketState
from pydantic import BaseModel
class ChatMessage(BaseModel):
    """One chat turn, used both in requests and in the assistant reply."""

    role: str  # chat_completions looks for "user"; replies use "assistant"
    content: str  # message text
class ChatCompletionRequest(BaseModel):
    """Body accepted by ``POST /v1/chat/completions`` (OpenAI-style)."""

    # Forwarded to the proxy and echoed back in the response.
    model: str = "gemini-2.5-pro-exp-03-25"
    # Conversation so far; the endpoint extracts a user turn from it.
    messages: List[ChatMessage]
    # Forwarded to the proxy alongside the prompt.
    temperature: Optional[float] = 0.7
    # NOTE(review): accepted but never read by the endpoint; responses are
    # always non-streaming.
    stream: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
    """A single completion choice in an OpenAI-compatible response."""

    index: int = 0  # the endpoint only ever returns one choice
    message: ChatMessage  # the assistant reply
    finish_reason: str = "stop"  # always "stop"; no length/tool reporting
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible ``chat.completion`` response envelope."""

    id: str  # the request_id generated for the relayed request
    object: str = "chat.completion"
    created: int  # Unix timestamp (seconds) when the response was built
    model: str  # echoed from the request
    choices: List[ChatCompletionResponseChoice]
# The single ASGI application; every route in this module registers on it.
app = FastAPI()
class ConnectionManager:
    """Registry of WebSocket clients and of HTTP requests awaiting a reply.

    ``active_connections`` maps each accepted WebSocket to the ``source``
    label it announced (``None`` until it announces one).
    ``response_queues`` maps a request id to the single-slot queue that a
    pending ``/v1/chat/completions`` call is waiting on.
    """

    def __init__(self):
        self.active_connections = {}  # WebSocket -> source label (or None)
        self.response_queues = {}  # request_id -> asyncio.Queue(maxsize=1)

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register the socket without a source."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: WebSocket, source: str):
        """Label a registered socket with ``source``; unknown sockets are ignored."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_destination(self, destination: str, message: str):
        """Send ``message`` to every socket whose source equals ``destination``.

        Sends are sequential; an exception from one send propagates to the
        caller and skips the remaining recipients.
        """
        # Snapshot the recipients first: each send awaits, and other
        # coroutines may connect or remove sockets meanwhile. Iterating the
        # live dict would then raise "dictionary changed size during iteration".
        recipients = [
            ws for ws, src in list(self.active_connections.items())
            if src == destination
        ]
        for ws in recipients:
            await ws.send_text(message)

    def remove(self, websocket: WebSocket):
        """Forget ``websocket``; a no-op if it is not registered."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: float = 30):
        """Wait for the reply routed to ``request_id`` and return it.

        Registers a single-slot queue under ``request_id`` and waits up to
        ``timeout`` seconds for one item. The queue is always deregistered,
        including on timeout or cancellation.

        Raises:
            asyncio.TimeoutError: if nothing arrives within ``timeout``.
        """
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            self.response_queues.pop(request_id, None)
# Module-level registry shared by the HTTP endpoint and the WebSocket handler.
manager = ConnectionManager()
@app.get("/")
async def get():
    """Serve the minimal end-user chat page.

    The page connects to ``/ws``, identifies itself as source ``'user'``,
    sends each typed message to destination ``'proxy'``, and appends the
    ``content`` of every frame it receives to the chat log.
    """
    return HTMLResponse("""
<html>
<body>
<h1>Chat Client</h1>
<div id="chat" style="height:300px;overflow-y:scroll"></div>
<input id="msg" type="text">
<button onclick="send()">Send</button>
<script>
// Follow the page's scheme: ws:// when served over plain http (e.g. the
// local uvicorn entry point), wss:// when served over https.
const wsScheme = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(wsScheme + window.location.host + '/ws');
ws.onopen = () => {
    ws.send(JSON.stringify({ source: 'user' }));
};
ws.onmessage = e => {
    const msg = JSON.parse(e.data);
    // Insert as text, never as HTML: replies are LLM output and must not
    // be able to inject markup or scripts into this page.
    const line = document.createElement('div');
    line.textContent = msg.content;
    const chat = document.getElementById('chat');
    chat.appendChild(line);
    chat.scrollTop = chat.scrollHeight;
};
const send = () => {
    const message = {
        content: document.getElementById('msg').value,
        source: 'user',
        destination: 'proxy'
    };
    ws.send(JSON.stringify(message));
    document.getElementById('msg').value = '';
};
</script>
</body>
</html>
""")
@app.get("/proxy")
async def get_proxy():
    """Serve the proxy (message gateway) page.

    The page keeps the LLM API key in the browser and registers on ``/ws``
    as source ``'proxy'``. For every frame addressed to ``'proxy'`` it calls
    the chat-completions API with the selected model. It then sends the
    reply back to the frame's ``source``, echoing ``request_id`` so the
    server can resolve the waiting ``/v1/chat/completions`` call.
    """
    return HTMLResponse("""
<html>
<body>
<h1>Proxy Client (Message Gateway)</h1>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; height: 80vh;">
<!-- Connection Panel -->
<div style="border-right: 1px solid #ccc; padding-right: 20px;">
<div style="margin-bottom: 20px;">
<input type="password" id="apiKey" placeholder="Enter API Key" style="width: 100%;">
<button onclick="initializeClient()" style="margin-top: 10px;">Fetch Models</button>
</div>
<select id="modelSelect" style="width: 100%; margin-bottom: 20px;"></select>
<div id="systemStatus" style="color: #666; font-size: 0.9em;"></div>
</div>
<!-- Message Flow Visualization -->
<div style="display: flex; flex-direction: column; height: 100%;">
<div id="messageFlow" style="flex: 1; border: 1px solid #eee; padding: 10px; overflow-y: auto; background: #f9f9f9;">
<div style="text-align: center; color: #999; margin-bottom: 10px;">Message Flow</div>
</div>
<div id="detailedStatus" style="color: #666; font-size: 0.9em; margin-top: 10px;"></div>
</div>
</div>
<style>
.message-entry {
margin: 5px 0;
padding: 8px;
border-radius: 8px;
background: white;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
font-family: monospace;
}
.incoming { border-left: 4px solid #4CAF50; }
.outgoing { border-left: 4px solid #2196F3; }
.system { border-left: 4px solid #9C27B0; }
.error { border-left: 4px solid #F44336; }
.message-header {
display: flex;
justify-content: space-between;
font-size: 0.8em;
color: #666;
margin-bottom: 4px;
}
</style>
<script>
let agentClient = null;
let currentModel = null;
const systemPrompt = "You are a helpful AI assistant. Respond concisely and accurately.";
const conversationHistory = [];

// Show a one-line status in the connection panel and log it in the flow.
// Uses textContent because messages may embed API error text.
function showStatus(message, type = 'info') {
    const statusDiv = document.getElementById('systemStatus');
    const line = document.createElement('div');
    line.style.color = type === 'error' ? '#F44336' : '#4CAF50';
    line.textContent = message;
    statusDiv.innerHTML = '';
    statusDiv.appendChild(line);
    addMessageEntry('system', 'system', 'proxy', message);
}

// Progress reporting used by BaseAgentClient.rateLimitWait(). It must exist,
// otherwise reaching the rate limit throws a ReferenceError mid-request.
function showGenerationStatus(message, type = 'info') {
    document.getElementById('detailedStatus').textContent = message;
    addMessageEntry(type === 'info' ? 'system' : 'error', 'system', 'proxy', message);
}

function initializeClient() {
    const apiKey = document.getElementById('apiKey').value;
    if (!apiKey) {
        showStatus("Please enter an API key", 'error');
        return;
    }
    agentClient = new ConversationalAgentClient(apiKey);
    agentClient.populateLLMModels()
        .then(models => {
            agentClient.updateModelSelect('modelSelect', models.find(m => m.includes("gemini-2.5")));
            currentModel = document.getElementById('modelSelect').value;
            showStatus(`Loaded ${models.length} models. Default: ${currentModel}`);
        })
        .catch(error => {
            showStatus(`Error fetching models: ${error.message}`, 'error');
        });
}

// Model selection change handler
document.getElementById('modelSelect').addEventListener('change', function() {
    currentModel = this.value;
    showStatus(`Model changed to: ${currentModel}`);
});

// --- API Client Classes ---
class BaseAgentClient {
constructor(apiKey, apiUrl = 'https://llm.synapse.thalescloud.io/v1/') { this.apiKey = apiKey; this.apiUrl = apiUrl; this.models = []; this.maxCallsPerMinute = 4; this.callTimestamps = []; }
async fetchLLMModels() { if (!this.apiKey) throw new Error("API Key is not set."); console.log("Fetching models from:", this.apiUrl + 'models'); try { const response = await fetch(this.apiUrl + 'models', { method: 'GET', headers: { 'Authorization': `Bearer ${this.apiKey}` } }); if (!response.ok) { const errorText = await response.text(); console.error("Fetch models error response:", errorText); throw new Error(`HTTP error! Status: ${response.status} - ${errorText}`); } const data = await response.json(); console.log("Models fetched:", data.data); const filteredModels = data.data.map(model => model.id).filter(id => !id.toLowerCase().includes('embed') && !id.toLowerCase().includes('image')); return filteredModels; } catch (error) { console.error('Error fetching LLM models:', error); throw new Error(`Failed to fetch models: ${error.message}`); } }
async populateLLMModels(defaultModel = "gemini-2.5-pro-exp-03-25") { try { const modelList = await this.fetchLLMModels(); const sortedModels = modelList.sort((a, b) => { if (a === defaultModel) return -1; if (b === defaultModel) return 1; return a.localeCompare(b); }); const finalModels = []; if (sortedModels.includes(defaultModel)) { finalModels.push(defaultModel); sortedModels.forEach(model => { if (model !== defaultModel) finalModels.push(model); }); } else { finalModels.push(defaultModel); finalModels.push(...sortedModels); } this.models = finalModels; console.log("Populated models:", this.models); return this.models; } catch (error) { console.error("Error populating models:", error); this.models = [defaultModel]; throw error; } }
updateModelSelect(elementId = 'modelSelect', selectedModel = null) { const select = document.getElementById(elementId); if (!select) { console.warn(`Element ID ${elementId} not found.`); return; } const currentSelection = selectedModel || select.value || this.models[0]; select.innerHTML = ''; if (this.models.length === 0 || (this.models.length === 1 && this.models[0] === "gemini-2.5-pro-exp-03-25" && !this.apiKey)) { const option = document.createElement('option'); option.value = ""; option.textContent = "-- Fetch models first --"; option.disabled = true; select.appendChild(option); return; } this.models.forEach(model => { const option = document.createElement('option'); option.value = model; option.textContent = model; if (model === currentSelection) option.selected = true; select.appendChild(option); }); if (!select.value && this.models.length > 0) select.value = this.models[0]; }
async rateLimitWait() { const currentTime = Date.now(); this.callTimestamps = this.callTimestamps.filter(ts => currentTime - ts <= 60000); if (this.callTimestamps.length >= this.maxCallsPerMinute) { const waitTime = 60000 - (currentTime - this.callTimestamps[0]); const waitSeconds = Math.ceil(waitTime / 1000); const waitMessage = `Rate limit (${this.maxCallsPerMinute}/min) reached. Waiting ${waitSeconds}s...`; console.log(waitMessage); showGenerationStatus(waitMessage, 'warn'); await new Promise(resolve => setTimeout(resolve, waitTime + 100)); showGenerationStatus('Resuming after rate limit wait...', 'info'); this.callTimestamps = this.callTimestamps.filter(ts => Date.now() - ts <= 60000); } }
async callAgent(model, messages, temperature = 0.7) { await this.rateLimitWait(); const startTime = Date.now(); console.log("Calling Agent:", model); try { const response = await fetch(this.apiUrl + 'chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer ${this.apiKey}` }, body: JSON.stringify({ model: model, messages: messages, temperature: temperature }) }); const endTime = Date.now(); this.callTimestamps.push(endTime); console.log(`API call took ${endTime - startTime} ms`); if (!response.ok) { const errorData = await response.json().catch(() => ({ error: { message: response.statusText } })); console.error("API Error:", errorData); throw new Error(errorData.error?.message || `API failed: ${response.status}`); } const data = await response.json(); if (!data.choices || !data.choices[0]?.message) throw new Error("Invalid API response structure"); console.log("API Response received."); return data.choices[0].message.content; } catch (error) { this.callTimestamps.push(Date.now()); console.error('Error calling agent:', error); throw error; } }
setMaxCallsPerMinute(value) { const parsedValue = parseInt(value, 10); if (!isNaN(parsedValue) && parsedValue > 0) { console.log(`Max calls/min set to: ${parsedValue}`); this.maxCallsPerMinute = parsedValue; return true; } console.warn(`Invalid max calls/min: ${value}`); return false; }
}
class ConversationalAgentClient extends BaseAgentClient {
constructor(apiKey, apiUrl = 'https://llm.synapse.thalescloud.io/v1/') {
super(apiKey, apiUrl);
}
async call(model, userPrompt, systemPrompt, conversationHistory = [], temperature = 0.7) {
const messages = [
{ role: 'system', content: systemPrompt },
...conversationHistory,
{ role: 'user', content: userPrompt }
];
const assistantResponse = await super.callAgent(model, messages, temperature);
const updatedHistory = [
...conversationHistory,
{ role: 'user', content: userPrompt },
{ role: 'assistant', content: assistantResponse }
];
return {
response: assistantResponse,
history: updatedHistory
};
}
async callWithCodeContext(
model,
userPrompt,
systemPrompt,
selectedCodeVersionsData = [],
conversationHistory = [],
temperature = 0.7
) {
let codeContext = "";
let fullSystemPrompt = systemPrompt || "";
if (selectedCodeVersionsData && selectedCodeVersionsData.length > 0) {
codeContext = `Code context (chronological):\n\n`;
selectedCodeVersionsData.forEach((versionData, index) => {
if (versionData && typeof versionData.code === 'string') {
codeContext += `--- Part ${index + 1} (${versionData.version || '?'}) ---\n`;
codeContext += `${versionData.code}\n\n`;
} else {
console.warn(`Invalid context version data at index ${index}`);
}
});
codeContext += `-------- end context ---\n\nUser request based on context:\n\n`;
}
const fullPrompt = codeContext + userPrompt;
const messages = [
{ role: 'system', content: fullSystemPrompt },
...conversationHistory,
{ role: 'user', content: fullPrompt }
];
const assistantResponse = await super.callAgent(model, messages, temperature);
const updatedHistory = [
...conversationHistory,
{ role: 'user', content: fullPrompt },
{ role: 'assistant', content: assistantResponse }
];
return {
response: assistantResponse,
history: updatedHistory
};
}
}

// Append one entry to the message-flow panel. Built with textContent so that
// user input and LLM output are displayed as text, never parsed as HTML.
function addMessageEntry(direction, source, destination, content) {
    const flowDiv = document.getElementById('messageFlow');
    const entry = document.createElement('div');
    entry.className = `message-entry ${direction}`;

    const header = document.createElement('div');
    header.className = 'message-header';
    const route = document.createElement('span');
    route.textContent = `${source} → ${destination}`;
    const stamp = document.createElement('span');
    stamp.textContent = new Date().toLocaleTimeString();
    header.appendChild(route);
    header.appendChild(stamp);

    const body = document.createElement('div');
    body.style.whiteSpace = 'pre-wrap';
    body.textContent = content;

    entry.appendChild(header);
    entry.appendChild(body);
    flowDiv.appendChild(entry);
    flowDiv.scrollTop = flowDiv.scrollHeight;
}

// Follow the page's scheme: ws:// over plain http, wss:// over https.
const wsScheme = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(wsScheme + window.location.host + '/ws');
ws.onopen = () => {
    ws.send(JSON.stringify({ source: 'proxy' }));
};
ws.onclose = () => {
    showStatus('WebSocket connection closed', 'error');
};
ws.onmessage = async e => {
    const msg = JSON.parse(e.data);
    if (msg.destination !== 'proxy') {
        return;
    }
    addMessageEntry('incoming', msg.source, 'proxy', msg.content);
    document.getElementById('detailedStatus').textContent = `Processing ${msg.source} request...`;
    try {
        if (!agentClient || !currentModel) {
            throw new Error('Proxy not initialized: enter an API key and fetch models first');
        }
        // The model selected on this page is used; msg.model is ignored.
        // Temperature comes from the API request when present.
        const llmResponse = await agentClient.call(
            currentModel, msg.content, systemPrompt, conversationHistory, msg.temperature ?? 0.7);
        addMessageEntry('outgoing', 'proxy', msg.source, llmResponse.response);
        // request_id must be echoed: the server uses it to resolve the
        // pending /v1/chat/completions call (undefined is dropped by
        // JSON.stringify for plain chat users).
        const responseMsg = {
            request_id: msg.request_id,
            content: llmResponse.response,
            source: 'proxy',
            destination: msg.source
        };
        ws.send(JSON.stringify(responseMsg));
        document.getElementById('detailedStatus').textContent = `Response sent to ${msg.source}`;
    } catch (error) {
        addMessageEntry('error', 'system', 'proxy', `Error: ${error.message}`);
        const errorResponse = {
            request_id: msg.request_id,
            content: `Error: ${error.message}`,
            source: 'proxy',
            destination: msg.source
        };
        ws.send(JSON.stringify(errorResponse));
    }
};
</script>
</body>
</html>
""")
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    """Relay an OpenAI-style chat completion through the connected proxy.

    The latest user message is sent, tagged with a fresh ``request_id``, to
    the first client registered as ``'proxy'``. The call then waits, via
    ``manager.wait_for_response``, for the reply carrying that id and wraps
    it in an OpenAI-compatible response. ``request.stream`` is not honoured.

    Raises:
        HTTPException: 503 if no proxy is connected, 400 if ``messages``
            contains no user turn, 504 if the proxy does not reply in time.
    """
    request_id = str(uuid.uuid4())

    proxy_ws = next(
        (ws for ws, src in manager.active_connections.items() if src == 'proxy'),
        None,
    )
    if proxy_ws is None:
        raise HTTPException(503, "Proxy client not connected")

    # Only one message's content is forwarded, so it must be the most recent
    # user turn (the one awaiting an answer), not the first in the history.
    user_message = next(
        (m for m in reversed(request.messages) if m.role == "user"), None
    )
    if user_message is None:
        raise HTTPException(400, "No user message found")

    proxy_msg = {
        "request_id": request_id,
        "content": user_message.content,
        "source": "api",
        "destination": "proxy",
        "model": request.model,
        "temperature": request.temperature,
    }
    await proxy_ws.send_text(json.dumps(proxy_msg))

    try:
        response_content = await manager.wait_for_response(request_id)
    except asyncio.TimeoutError:
        raise HTTPException(504, "Proxy response timeout")

    return ChatCompletionResponse(
        id=request_id,
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionResponseChoice(
                message=ChatMessage(role="assistant", content=response_content)
            )
        ],
    )
def _decode_json_object(raw):
    """Parse ``raw`` as JSON and return it if it is an object, else ``None``."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, dict) else None


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register a client and route its JSON frames.

    The first frame is the client's identification (``{"source": ...}``).
    Every later frame is routed by its ``destination`` field:
    - frames with a ``request_id`` addressed to ``'api'`` complete a pending
      ``/v1/chat/completions`` call;
    - anything else is forwarded verbatim to every client registered under
      that destination.
    Malformed frames (not a JSON object, or no destination) are ignored
    rather than dropping the client.
    """
    await manager.connect(websocket)
    try:
        init_msg = _decode_json_object(await websocket.receive_text())
        if init_msg is not None and 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])

        while True:
            message = await websocket.receive_text()
            msg_data = _decode_json_object(message)
            if msg_data is None:
                continue
            destination = msg_data.get('destination')
            if 'request_id' in msg_data and destination == 'api':
                queue = manager.response_queues.get(msg_data['request_id'])
                if queue is not None:
                    # put_nowait: a duplicate reply must not block this
                    # receive loop on the already-full single-slot queue.
                    with contextlib.suppress(asyncio.QueueFull):
                        queue.put_nowait(msg_data.get('content', ''))
            elif destination is not None:
                await manager.send_to_destination(destination, message)
    except WebSocketDisconnect:
        pass  # normal client departure; cleanup happens below
    finally:
        manager.remove(websocket)
        # Closing a socket the client already closed raises in Starlette,
        # so only close while the client side is still connected.
        if websocket.client_state == WebSocketState.CONNECTED:
            with contextlib.suppress(RuntimeError):
                await websocket.close()
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")