Spaces:
Running
Running
File size: 13,682 Bytes
18a21d2 050938e 18a21d2 050938e 18a21d2 050938e 18a21d2 050938e 18a21d2 050938e 18a21d2 0242cbc 050938e 18a21d2 050938e 18a21d2 050938e 18a21d2 050938e 18a21d2 050938e 120f3d6 050938e 120f3d6 050938e 120f3d6 050938e 120f3d6 050938e 120f3d6 050938e 120f3d6 18a21d2 050938e 18a21d2 050938e 18a21d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import json
import logging

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
# ASGI application object: the page routes and the /ws relay below register
# on it through decorators, and the __main__ guard serves it with uvicorn.
app = FastAPI()
class ConnectionManager:
    """Registry of open WebSockets, each tagged with the source name it announced.

    A socket is stored with tag ``None`` when it connects and only receives
    relayed messages once ``set_source`` has given it a name.
    """

    def __init__(self):
        # Maps each open WebSocket to its source tag (None until announced).
        self.active_connections = {}

    async def connect(self, websocket: "WebSocket"):
        """Accept the handshake and register *websocket* without a source tag."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: "WebSocket", source: str):
        """Tag an already-registered *websocket* with *source*; unknown sockets are ignored."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_destination(self, destination: str, message: str) -> int:
        """Send *message* to every socket whose source tag equals *destination*.

        Args:
            destination: Source name of the intended recipients. ``None`` never
                matches, so untagged sockets cannot be addressed.
            message: Raw text frame, forwarded unchanged.

        Returns:
            The number of sockets the message was delivered to.

        A recipient whose send fails is treated as dead and unregistered; its
        error is not propagated, so it cannot tear down the sender's connection.
        """
        if destination is None:
            return 0
        # Snapshot the recipients first: every await below yields to the event
        # loop, where other handlers may add/remove connections, and iterating
        # the live dict view would then raise "dictionary changed size".
        targets = [ws for ws, src in self.active_connections.items() if src == destination]
        delivered = 0
        for ws in targets:
            try:
                await ws.send_text(message)
            except Exception:
                # Best effort: a broken recipient is dropped, others still get it.
                self.remove(ws)
            else:
                delivered += 1
        return delivered

    def remove(self, websocket: "WebSocket"):
        """Unregister *websocket*; a no-op if it is not registered."""
        self.active_connections.pop(websocket, None)


# Process-wide registry shared by every /ws connection handler.
manager = ConnectionManager()
@app.get("/")
async def get():
    """Serve the end-user chat page.

    The page registers on ``/ws`` as source ``'user'``, sends each typed
    message with destination ``'proxy'``, and appends the ``content`` of every
    message it receives to the chat log.
    """
    return HTMLResponse("""
<html>
<body>
<h1>Chat Client</h1>
<div id="chat" style="height:300px;overflow-y:scroll"></div>
<input id="msg" type="text">
<button onclick="send()">Send</button>
<script>
// Follow the page's own scheme: ws:// when served over plain http (e.g. a
// local uvicorn run), wss:// when served over https.
const wsScheme = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(wsScheme + window.location.host + '/ws');
ws.onopen = () => {
    ws.send(JSON.stringify({ source: 'user' }));
};
ws.onmessage = e => {
    const msg = JSON.parse(e.data);
    // Insert as text, never as HTML: content is model output relayed from
    // whichever client registered as 'proxy'.
    const line = document.createElement('div');
    line.textContent = msg.content;
    document.getElementById('chat').appendChild(line);
};
const send = () => {
    const message = {
        content: document.getElementById('msg').value,
        source: 'user',
        destination: 'proxy'
    };
    ws.send(JSON.stringify(message));
    document.getElementById('msg').value = '';
};
</script>
</body>
</html>
""")
@app.get("/proxy")
async def get_proxy():
    """Serve the proxy (LLM gateway) page.

    The page registers on ``/ws`` as source ``'proxy'``. For each message
    addressed to ``'proxy'`` it calls the selected model through the
    ``chat/completions`` endpoint of the configured API and sends the answer
    (or an error text) back with destination ``'user'``.
    """
    # Raw string on purpose: the embedded JS contains "\n" escapes that must
    # reach the browser verbatim. In a plain literal Python turns them into
    # real newlines inside JS double-quoted strings, which is a JS syntax
    # error that disables the whole <script>.
    return HTMLResponse(r"""
<html>
<body>
<h1>Proxy Client (LLM Gateway)</h1>
<div style="margin-bottom: 20px;">
<input type="password" id="apiKey" placeholder="Enter API Key" style="width: 300px;">
<button onclick="initializeClient()">Fetch Models</button>
</div>
<div style="margin-bottom: 20px;">
<select id="modelSelect" style="width: 300px;">
<option value="" disabled selected>-- Select Model --</option>
</select>
</div>
<div id="status"></div>
<script>
let agentClient = null;
let currentModel = null;
const systemPrompt = "You are a helpful AI assistant. Respond concisely and accurately.";
// Transcript sent with every request; replaced after each successful call by
// the history returned from ConversationalAgentClient.call.
let conversationHistory = [];

function showStatus(message, type = 'info') {
    const statusDiv = document.getElementById('status');
    const line = document.createElement('div');
    line.style.color = type === 'error' ? 'red' : 'orange';
    // textContent, not innerHTML: messages can embed raw API error bodies.
    line.textContent = message;
    statusDiv.innerHTML = '';
    statusDiv.appendChild(line);
}

function initializeClient() {
    const apiKey = document.getElementById('apiKey').value;
    if (!apiKey) {
        showStatus("Please enter an API key", 'error');
        return;
    }
    agentClient = new ConversationalAgentClient(apiKey);
    agentClient.populateLLMModels()
        .then(models => {
            agentClient.updateModelSelect('modelSelect', models.find(m => m.includes("gemini-2.5")));
            currentModel = document.getElementById('modelSelect').value;
            showStatus(`Loaded ${models.length} models. Default: ${currentModel}`);
        })
        .catch(error => {
            showStatus(`Error fetching models: ${error.message}`, 'error');
        });
}

// WebSocket setup: follow the page's scheme (ws:// over http, wss:// over https).
const wsScheme = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(wsScheme + window.location.host + '/ws');
ws.onopen = () => {
    ws.send(JSON.stringify({ source: 'proxy' }));
};

function replyToUser(content) {
    ws.send(JSON.stringify({ content: content, source: 'proxy', destination: 'user' }));
}

ws.onmessage = async e => {
    const msg = JSON.parse(e.data);
    if (msg.destination !== 'proxy') return;
    try {
        if (!agentClient || !currentModel) {
            throw new Error("Proxy is not ready: enter an API key, fetch models and select one.");
        }
        showStatus("Processing user query...");
        const llmResponse = await agentClient.call(
            currentModel,
            msg.content,
            systemPrompt,
            conversationHistory
        );
        // Keep the exchange so the next request carries the conversation.
        conversationHistory = llmResponse.history;
        replyToUser(llmResponse.response);
        showStatus("Response sent successfully");
    } catch (error) {
        console.error("LLM Error:", error);
        replyToUser(`Error processing request: ${error.message}`);
        showStatus(`Error: ${error.message}`, 'error');
    }
};

// Model selection change handler
document.getElementById('modelSelect').addEventListener('change', function() {
    currentModel = this.value;
    showStatus(`Model changed to: ${currentModel}`);
});

// --- API Client Classes ---
class BaseAgentClient {
    constructor(apiKey, apiUrl = 'https://llm.synapse.thalescloud.io/v1/') {
        this.apiKey = apiKey;
        this.apiUrl = apiUrl;
        this.models = [];
        this.maxCallsPerMinute = 4;
        this.callTimestamps = [];
    }

    async fetchLLMModels() {
        if (!this.apiKey) throw new Error("API Key is not set.");
        console.log("Fetching models from:", this.apiUrl + 'models');
        try {
            const response = await fetch(this.apiUrl + 'models', {
                method: 'GET',
                headers: { 'Authorization': `Bearer ${this.apiKey}` }
            });
            if (!response.ok) {
                const errorText = await response.text();
                console.error("Fetch models error response:", errorText);
                throw new Error(`HTTP error! Status: ${response.status} - ${errorText}`);
            }
            const data = await response.json();
            console.log("Models fetched:", data.data);
            return data.data
                .map(model => model.id)
                .filter(id => !id.toLowerCase().includes('embed') && !id.toLowerCase().includes('image'));
        } catch (error) {
            console.error('Error fetching LLM models:', error);
            throw new Error(`Failed to fetch models: ${error.message}`);
        }
    }

    async populateLLMModels(defaultModel = "gemini-2.5-pro-exp-03-25") {
        try {
            const modelList = await this.fetchLLMModels();
            const sortedModels = modelList.sort((a, b) => {
                if (a === defaultModel) return -1;
                if (b === defaultModel) return 1;
                return a.localeCompare(b);
            });
            const finalModels = [];
            if (sortedModels.includes(defaultModel)) {
                finalModels.push(defaultModel);
                sortedModels.forEach(model => { if (model !== defaultModel) finalModels.push(model); });
            } else {
                finalModels.push(defaultModel);
                finalModels.push(...sortedModels);
            }
            this.models = finalModels;
            console.log("Populated models:", this.models);
            return this.models;
        } catch (error) {
            console.error("Error populating models:", error);
            this.models = [defaultModel];
            throw error;
        }
    }

    updateModelSelect(elementId = 'modelSelect', selectedModel = null) {
        const select = document.getElementById(elementId);
        if (!select) {
            console.warn(`Element ID ${elementId} not found.`);
            return;
        }
        const currentSelection = selectedModel || select.value || this.models[0];
        select.innerHTML = '';
        if (this.models.length === 0 || (this.models.length === 1 && this.models[0] === "gemini-2.5-pro-exp-03-25" && !this.apiKey)) {
            const option = document.createElement('option');
            option.value = "";
            option.textContent = "-- Fetch models first --";
            option.disabled = true;
            select.appendChild(option);
            return;
        }
        this.models.forEach(model => {
            const option = document.createElement('option');
            option.value = model;
            option.textContent = model;
            if (model === currentSelection) option.selected = true;
            select.appendChild(option);
        });
        if (!select.value && this.models.length > 0) select.value = this.models[0];
    }

    async rateLimitWait() {
        const currentTime = Date.now();
        this.callTimestamps = this.callTimestamps.filter(ts => currentTime - ts <= 60000);
        if (this.callTimestamps.length >= this.maxCallsPerMinute) {
            const waitTime = 60000 - (currentTime - this.callTimestamps[0]);
            const waitSeconds = Math.ceil(waitTime / 1000);
            const waitMessage = `Rate limit (${this.maxCallsPerMinute}/min) reached. Waiting ${waitSeconds}s...`;
            console.log(waitMessage);
            // showStatus is the page's status helper (there is no showGenerationStatus).
            showStatus(waitMessage, 'warn');
            await new Promise(resolve => setTimeout(resolve, waitTime + 100));
            showStatus('Resuming after rate limit wait...', 'info');
            this.callTimestamps = this.callTimestamps.filter(ts => Date.now() - ts <= 60000);
        }
    }

    async callAgent(model, messages, temperature = 0.7) {
        await this.rateLimitWait();
        const startTime = Date.now();
        console.log("Calling Agent:", model);
        try {
            let response;
            try {
                response = await fetch(this.apiUrl + 'chat/completions', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json', 'Authorization': `Bearer ${this.apiKey}` },
                    body: JSON.stringify({ model: model, messages: messages, temperature: temperature })
                });
            } finally {
                // Count each attempt exactly once toward the rate limit,
                // whether it succeeds, fails at the network, or errors later.
                this.callTimestamps.push(Date.now());
            }
            console.log(`API call took ${Date.now() - startTime} ms`);
            if (!response.ok) {
                const errorData = await response.json().catch(() => ({ error: { message: response.statusText } }));
                console.error("API Error:", errorData);
                throw new Error(errorData.error?.message || `API failed: ${response.status}`);
            }
            const data = await response.json();
            if (!data.choices || !data.choices[0]?.message) throw new Error("Invalid API response structure");
            console.log("API Response received.");
            return data.choices[0].message.content;
        } catch (error) {
            console.error('Error calling agent:', error);
            throw error;
        }
    }

    setMaxCallsPerMinute(value) {
        const parsedValue = parseInt(value, 10);
        if (!isNaN(parsedValue) && parsedValue > 0) {
            console.log(`Max calls/min set to: ${parsedValue}`);
            this.maxCallsPerMinute = parsedValue;
            return true;
        }
        console.warn(`Invalid max calls/min: ${value}`);
        return false;
    }
}

class ConversationalAgentClient extends BaseAgentClient {
    constructor(apiKey, apiUrl = 'https://llm.synapse.thalescloud.io/v1/') {
        super(apiKey, apiUrl);
    }

    async call(model, userPrompt, systemPrompt, conversationHistory = [], temperature = 0.7) {
        const messages = [{ role: 'system', content: systemPrompt }, ...conversationHistory, { role: 'user', content: userPrompt }];
        const assistantResponse = await super.callAgent(model, messages, temperature);
        const updatedHistory = [...conversationHistory, { role: 'user', content: userPrompt }, { role: 'assistant', content: assistantResponse }];
        return { response: assistantResponse, history: updatedHistory };
    }

    async callWithCodeContext(model, userPrompt, systemPrompt, selectedCodeVersionsData = [], conversationHistory = [], temperature = 0.7) {
        let codeContext = "";
        let fullSystemPrompt = systemPrompt || "";
        if (selectedCodeVersionsData && selectedCodeVersionsData.length > 0) {
            codeContext = "Code context (chronological):\n\n";
            selectedCodeVersionsData.forEach((versionData, index) => {
                if (versionData && typeof versionData.code === 'string') codeContext += `--- Part ${index + 1} (${versionData.version || '?'}) ---\n${versionData.code}\n\n`;
                else console.warn(`Invalid context version data at index ${index}`);
            });
            codeContext += "-------- end context ---\n\nUser request based on context:\n\n";
        }
        const fullPrompt = codeContext + userPrompt;
        const messages = [{ role: 'system', content: fullSystemPrompt }, ...conversationHistory, { role: 'user', content: fullPrompt }];
        const assistantResponse = await super.callAgent(model, messages, temperature);
        const updatedHistory = [...conversationHistory, { role: 'user', content: fullPrompt }, { role: 'assistant', content: assistantResponse }];
        return { response: assistantResponse, history: updatedHistory };
    }
}
</script>
</body>
</html>
""")
def _parse_json_object(raw: str):
    """Decode *raw* as JSON and return it if it is an object (dict), else None.

    Malformed frames yield None instead of raising, so a single bad message
    does not tear down the whole WebSocket connection.
    """
    try:
        parsed = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    return parsed if isinstance(parsed, dict) else None


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Relay JSON text frames between peers identified by a ``source`` tag.

    The first frame is expected to be ``{"source": <name>}`` and registers
    this socket under that name. Every later frame must be a JSON object with
    a ``destination`` key; it is forwarded verbatim to all sockets registered
    under that name. Frames that are not JSON objects or lack ``destination``
    are ignored rather than closing the connection.
    """
    await manager.connect(websocket)
    try:
        init_msg = _parse_json_object(await websocket.receive_text())
        if init_msg is not None and 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])
        while True:
            message = await websocket.receive_text()
            msg_data = _parse_json_object(message)
            if msg_data is None or 'destination' not in msg_data:
                continue  # drop the malformed frame, keep the connection
            await manager.send_to_destination(msg_data['destination'], message)
    except WebSocketDisconnect:
        pass  # the peer closed the socket; calling close() again would fail
    except Exception:
        logging.getLogger(__name__).exception("WebSocket relay failed")
        # Best-effort polite close; the socket may already be unusable.
        try:
            await websocket.close()
        except Exception:
            pass
    finally:
        # Always unregister (including on task cancellation) so the manager
        # never keeps a dead socket as a relay target.
        manager.remove(websocket)
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
|