// private-synthid / wasm-demo.js
// Author: jfrery-zama
// Last commit: 20d002d "add retry on network error" (unverified)
import initWasm, {
decrypt_serialized_u64_radix_flat_wasm
} from './concrete-ml-extensions-wasm/concrete_ml_extensions_wasm.js';
// Base URL of the FHE compute server (/add_key, /start_task, task polling).
const SERVER = 'https://api.zama.ai';

// Key material from the keygen worker (clientKey may also be restored from localStorage).
let clientKey;
let serverKey;

// Ciphertext of the tokenized input, and the encrypted result downloaded from the server.
let encTokens;
let encServerResult;

// Web workers for key generation and for client-side encryption.
let keygenWorker;
let encryptWorker;

// Identifiers issued by the server: uid from /add_key, task id from /start_task.
let sessionUid;
let taskId;

// Token count of the current input; drives the processing-time estimate.
let currentTokenCount = 0;

// setInterval handle that advances the progress bar while the server computes.
let progressTimer;
/**
 * Base64-encode a (possibly very large) Uint8Array without building a huge
 * intermediate string by hand: the bytes are wrapped in a Blob, read back as a
 * data: URL, and the payload after the comma is returned.
 * @param {Uint8Array} uint8 - bytes to encode
 * @returns {Promise<string>} the base64 text (without the data: prefix);
 *   rejects with the FileReader error event if reading fails
 */
function uint8ToBase64(uint8) {
  return new Promise((resolve, reject) => {
    const payloadBlob = new Blob([uint8]);
    const dataUrlReader = new FileReader();
    dataUrlReader.onload = () => {
      const [, encoded] = dataUrlReader.result.split(',');
      resolve(encoded);
    };
    dataUrlReader.onerror = reject;
    dataUrlReader.readAsDataURL(payloadBlob);
  });
}
// ——— local key cache helpers ——————————————————————————————
// localStorage key under which the saved { uid: base64(clientKey) } map is stored.
const KEYS_STORAGE_KEY = 'synthid_keys_v1';
/**
 * Decode base64 text into raw bytes.
 * atob() yields a "binary string" holding one char (code 0–255) per byte; each
 * char code is copied into the resulting Uint8Array.
 * @param {string} base64 - base64 text (no data: prefix)
 * @returns {Uint8Array} the decoded bytes
 * @throws {DOMException} from atob() when the input is not valid base64
 */
function base64ToUint8(base64) {
  return Uint8Array.from(atob(base64), (ch) => ch.charCodeAt(0));
}
/**
 * Read the persisted { uid: base64(clientKey) } map from localStorage.
 * Returns an empty object when nothing is stored, or when the stored value is
 * not valid JSON or not a plain object (corrupted / written by something else).
 * Callers can therefore always pass the result to Object.keys().
 * @returns {Record<string, string>}
 */
function getSavedKeys() {
  const raw = localStorage.getItem(KEYS_STORAGE_KEY);
  if (!raw) return {};
  try {
    const parsed = JSON.parse(raw);
    if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
      return parsed;
    }
    console.warn('[Main] Ignoring saved keys: stored value is not an object');
  } catch (err) {
    console.warn('[Main] Ignoring saved keys: stored value is not valid JSON', err);
  }
  return {};
}

/** Overwrite the persisted uid → base64(clientKey) map (setItem may throw, e.g. on quota). */
function saveKeys(map) {
  localStorage.setItem(KEYS_STORAGE_KEY, JSON.stringify(map));
}

/** Add or replace the saved base64 client key for `uid`. */
function saveKeyset(uid, b64) {
  const keys = getSavedKeys();
  keys[uid] = b64;
  saveKeys(keys);
}

/** Remove every saved key set from this browser. */
function clearAllKeys() {
  localStorage.removeItem(KEYS_STORAGE_KEY);
}
/** Look up a DOM element by id (null when there is none). */
const $ = (elementId) => document.getElementById(elementId);
/** Enable (default) or disable the element with this id; evaluates to the new `disabled` value. */
const enable = (elementId, isEnabled = true) => ($(elementId).disabled = !isEnabled);
/** Show (default) or hide the element with this id via `hidden`; evaluates to the new `hidden` value. */
const show = (elementId, isVisible = true) => ($(elementId).hidden = !isVisible);
// Start with the spinners (and the "encrypted" icon) hidden; the handlers below
// reveal them while the corresponding work runs or once it has completed.
for (const spinnerId of ['keygenSpin', 'spin', 'encIcon', 'tokenizerSpin']) {
  show(spinnerId, false);
}
// Initialize WASM
(async () => {
try {
console.log('[Main] Initializing WASM module...');
await initWasm();
console.log('[Main] WASM module initialized successfully');
// Initialize the keygen worker
keygenWorker = new Worker(new URL('./keygen-worker.js', import.meta.url), { type: 'module' });
keygenWorker.onmessage = async function(e) {
if (e.data.type === 'success') {
const res = e.data.result;
console.log('[Main] Key generation successful');
console.log(`[Main] Client key size: ${res.clientKey.length} bytes`);
console.log(`[Main] Server key size: ${res.serverKey.length} bytes`);
clientKey = res.clientKey; serverKey = res.serverKey;
try {
// Initialize encryption worker
initEncryptWorkerWithKey(clientKey);
console.log('[Main] Sending server key to server...');
$('keygenStatus').textContent = 'Keys generated, sending server key...';
show('keygenSpin', true);
const formData = new FormData();
const serverKeyBlob = new Blob([serverKey], { type: 'application/octet-stream' });
const serverKeyFile = new File([serverKeyBlob], "server.key");
formData.append('key', serverKeyFile);
formData.append('task_name', 'synthid');
const addKeyResponse = await fetch(`${SERVER}/add_key`, {
method: 'POST',
body: formData
});
if (!addKeyResponse.ok) {
const errorText = await addKeyResponse.text();
throw new Error(`Server /add_key failed: ${addKeyResponse.status} ${errorText}`);
}
const { uid } = await addKeyResponse.json();
sessionUid = uid;
console.log('[Main] Server key sent and UID received:', sessionUid);
$('keygenStatus').textContent = 'Keys generated & UID received βœ“';
enable('btnEncrypt');
// Persist clientKey ⟷ uid for reuse
uint8ToBase64(clientKey)
.then(b64 => {
saveKeyset(sessionUid, b64);
console.log(`[Main] Saved clientKey for uid ${sessionUid} to localStorage`);
})
.catch(err => console.warn('[Main] Failed to save key:', err));
} catch (error) {
console.error('[Main] Server key submission error:', error);
$('keygenStatus').textContent = `Server key submission failed: ${error.message}`;
enable('btnEncrypt', false);
} finally {
show('keygenSpin', false);
}
} else {
console.error('[Main] Key generation error:', e.data.error);
$('keygenStatus').textContent = `Error generating keys: ${e.data.error}`;
show('keygenSpin', false);
}
};
} catch (e) {
console.error('[Main] Failed to initialize WASM module:', e);
$('keygenStatus').textContent = `Initialization Error: ${e.message}`;
throw e;
}
})();
// Ask the keygen worker for a fresh key pair; the result arrives via keygenWorker.onmessage.
$('btnKeygen').onclick = async () => {
  // A visible spinner means keygen (or the server-key upload after it) is still running.
  if ($('keygenSpin').hidden === false) {
    console.log('[Main] Keygen already in progress, ignoring click');
    return;
  }
  // The worker is only created once the WASM module has initialized (and never if
  // initialization failed); report that instead of surfacing a raw TypeError.
  if (!keygenWorker) {
    console.warn('[Main] Keygen requested before the keygen worker exists');
    $('keygenStatus').textContent = 'Not ready yet: the WASM module is still loading (or failed to load).';
    return;
  }
  show('keygenSpin', true);
  $('keygenStatus').textContent = 'Generating keys…';
  try {
    keygenWorker.postMessage({});
  } catch (e) {
    console.error('[Main] Key generation error:', e);
    $('keygenStatus').textContent = `Error generating keys: ${e.message}`;
    show('keygenSpin', false);
  }
};
// Restore a client key (and the server uid it was registered under) saved earlier.
$('btnLoadSaved').onclick = async () => {
  const saved = getSavedKeys();
  const ids = Object.keys(saved);
  if (!ids.length) {
    alert('No saved keys found on this machine.');
    return;
  }
  // Very lightweight UI: ask which uid to use
  const answer = prompt(
    `Saved key sets:\n${ids.join('\n')}\n\nEnter the uid you want to use:`,
    ids[0]
  );
  // prompt() returns null when the dialog is cancelled: do nothing, it's not an error.
  if (answer === null) return;
  // Tolerate stray whitespace from copy-pasting a uid.
  const uid = answer.trim();
  // Own-property check so inputs like "constructor" or "toString" don't match
  // members inherited from Object.prototype.
  if (!uid || !Object.hasOwn(saved, uid) || !saved[uid]) {
    alert('Invalid or unknown uid.');
    return;
  }
  try {
    // Decode first so a corrupt entry leaves sessionUid/clientKey untouched.
    const restoredKey = base64ToUint8(saved[uid]);
    sessionUid = uid;
    clientKey = restoredKey;
    // Make sure we have an encryption worker ready
    initEncryptWorkerWithKey(clientKey);
    $('keygenStatus').textContent = `Loaded saved keys for ${uid} ✓`;
    enable('btnEncrypt');
  } catch (err) {
    console.error('[Main] Failed to load key:', err);
    alert(`Failed to load saved key: ${err.message}`);
  }
};
// Delete every saved key set, after a confirmation dialog that lists the uids to be removed.
$('btnDeleteKeys').onclick = async () => {
  const uids = Object.keys(getSavedKeys());
  if (uids.length === 0) {
    alert('No saved keys found on this machine.');
    return;
  }
  const question = `Are you sure you want to delete all saved keys?\n\nThis will remove ${uids.length} saved key set(s):\n${uids.join('\n')}\n\nThis action cannot be undone.`;
  if (!confirm(question)) return;
  try {
    clearAllKeys();
    console.log('[Main] All saved keys deleted');
    alert('All saved keys have been deleted.');
  } catch (err) {
    console.error('[Main] Failed to delete keys:', err);
    alert(`Failed to delete keys: ${err.message}`);
  }
};
// Example-text button: fill the input with a sample sentence, then fire an 'input'
// event so the token counter refreshes exactly as if the user had typed it.
$('btnWatermarked').onclick = () => {
  const input = $('tokenInput');
  input.value = 'watermarking is useful for a variety of reasons like authentication, privacy';
  input.dispatchEvent(new Event('input'));
};
// Live token counter. On every edit: re-tokenize the input, show "<count>/<limit>"
// (coloured when too long or very short), refresh the processing-time estimate
// (30 s per token), and enable/disable the Encrypt button per the 16-token limit.
$('tokenInput').addEventListener('input', () => {
  const TOKEN_LIMIT = 16;
  const text = $('tokenInput').value.trim();
  const tokenizerLoaded = typeof llama3Tokenizer !== 'undefined';

  // Nothing to count (empty input, or tokenizer script not loaded): reset the counter.
  if (!text || !tokenizerLoaded) {
    $('encStatus').textContent = '';
    $('encStatus').style.color = '';
    currentTokenCount = 0;
    enable('btnEncrypt', true);
    return;
  }

  try {
    const count = llama3Tokenizer.encode(text).length;
    currentTokenCount = count;

    const totalSeconds = count * 30;
    const mins = Math.floor(totalSeconds / 60);
    const secs = totalSeconds % 60;
    $('estimatedTime').textContent = mins > 0 ? `${mins}m ${secs}s` : `${secs}s`;

    let label = `${count}/${TOKEN_LIMIT} tokens`;
    let color = '';
    let canEncrypt = true;
    if (count > TOKEN_LIMIT) {
      label = `⚠️ ${count}/${TOKEN_LIMIT} tokens - exceeds limit, encryption disabled`;
      color = '#d32f2f';
      canEncrypt = false;
    } else if (count < 10) {
      label = `⚠️ ${count}/${TOKEN_LIMIT} tokens - low reliability`;
      color = '#f57c00';
    }
    $('encStatus').textContent = label;
    $('encStatus').style.color = color;
    enable('btnEncrypt', canEncrypt);
  } catch (e) {
    // Tokenizer might not be ready yet: clear the counter text (the colour is left as-is).
    $('encStatus').textContent = '';
    currentTokenCount = 0;
    enable('btnEncrypt', true);
  }
});
// Tokenize the input, enforce the token limit, and send the token ids to the
// encryption worker; the ciphertext comes back via encryptWorker_onmessage.
$('btnEncrypt').onclick = async () => {
  const TOKEN_LIMIT = 16;
  const text = $('tokenInput').value.trim();
  if (!text) {
    console.error('[Main] No text provided for tokenization/encryption');
    alert('Please enter text to encrypt.');
    return;
  }
  if (!encryptWorker) {
    console.error('[Main] Encryption worker not initialized');
    alert('Encryption worker is not ready. Please generate keys first.');
    return;
  }
  // Tokenize once; the same ids are validated and then encrypted.
  let tokenIds;
  try {
    console.log('[Main] Tokenizing text:', text);
    tokenIds = llama3Tokenizer.encode(text);
  } catch (error) {
    console.error('[Main] Token validation error:', error);
    alert(`Error validating text: ${error.message}`);
    return;
  }
  if (tokenIds.length > TOKEN_LIMIT) {
    console.error(`[Main] Token limit exceeded: ${tokenIds.length}/${TOKEN_LIMIT} tokens`);
    alert(`Text is too long. Maximum ${TOKEN_LIMIT} tokens allowed, but your text has ${tokenIds.length} tokens. Please shorten your text.`);
    return;
  }
  console.log('[Main] Token IDs:', tokenIds);
  show('encryptSpin', true);
  show('encIcon', false);
  enable('btnEncrypt', false);
  // encTokens still holds the previous text's ciphertext; block sending it until this
  // encryption completes (the worker's 'success' message re-enables Send).
  enable('btnSend', false);
  try {
    encryptWorker.postMessage({ type: 'encrypt', tokenIds });
  } catch (error) {
    console.error('[Main] Tokenization or encryption initiation error:', error);
    show('encryptSpin', false);
    enable('btnEncrypt', true);
    // Nothing was started, so restore Send if an earlier ciphertext exists.
    enable('btnSend', Boolean(encTokens));
    alert(`Error during tokenization/encryption: ${error.message}`);
  }
};
// Task statuses reported by /get_task_status, grouped by how the poller reacts.
// Statuses are compared lower-cased, so 'SUCCESS' and 'success' behave the same.
const POLL_SUCCESS_STATUSES = ['success', 'completed'];
const POLL_RUNNING_STATUSES = ['processing', 'running', 'started'];
const POLL_FAILURE_STATUSES = ['failure', 'revoked', 'unknown', 'error'];

/**
 * Delay before the next status request: 5 s × 1.5^retryCount, capped at 30 s.
 * retryCount only grows on network errors, so normal polling runs every 5 s.
 */
function pollDelayMs(retryCount) {
  return Math.min(5000 * 1.5 ** retryCount, 30000);
}

/** Poll again after `delayMs`; if that poll observes success, download the result. */
function schedulePoll(currentTaskId, currentUid, retryCount, maxRetries, delayMs) {
  setTimeout(() => {
    pollTaskStatus(currentTaskId, currentUid, retryCount, maxRetries).then((finalStatus) => {
      if (finalStatus && POLL_SUCCESS_STATUSES.includes(finalStatus.status)) {
        getTaskResult(currentTaskId, currentUid, 'synthid');
      }
    });
  }, delayMs);
}

/**
 * Query the status of a submitted task once and update the status UI.
 *
 * - success/completed: marks the progress bar complete and resolves to the status
 *   payload (with `status` lower-cased); the caller is expected to fetch the result.
 * - failure-like statuses, HTTP errors and malformed payloads: tear down the
 *   polling UI and resolve to null.
 * - anything else (queued, running, …): schedules a follow-up poll, which fetches
 *   the result itself on success, and resolves to null.
 * - network failures (fetch() rejects): retried with backoff up to `maxRetries`
 *   times, then polling gives up.
 *
 * @param {string} currentTaskId - task id returned by /start_task
 * @param {string} currentUid - session uid returned by /add_key
 * @param {number} [retryCount=0] - network failures seen so far for this task
 * @param {number} [maxRetries=10] - network failures tolerated before giving up
 * @returns {Promise<object|null>} the status payload on success, otherwise null
 */
async function pollTaskStatus(currentTaskId, currentUid, retryCount = 0, maxRetries = 10) {
  // URLSearchParams percent-encodes the ids, so '&', '=', spaces… cannot corrupt the query.
  const url = new URL(`${SERVER}/get_task_status`);
  url.searchParams.set('task_id', currentTaskId);
  url.searchParams.set('uid', currentUid);

  let statusResponse;
  try {
    statusResponse = await fetch(url);
  } catch (e) {
    // fetch() rejects only on network-level failures (offline, DNS, connection reset,
    // CORS). Each browser words the error differently ("Failed to fetch",
    // "NetworkError when attempting to fetch resource.", Safari's "Load failed"),
    // so every rejection here is treated as transient instead of matching on text.
    if (retryCount >= maxRetries) {
      console.error(`[Poll] Max retries (${maxRetries}) exceeded. Giving up.`);
      $('srvStatus').textContent = 'Connection failed after multiple attempts. Please try again later.';
      cleanupPolling();
      return null;
    }
    console.warn(`[Poll] Network error (attempt ${retryCount + 1}/${maxRetries}):`, e.message);
    $('srvStatus').textContent = `Connection error (attempt ${retryCount + 1}/${maxRetries}). Retrying...`;
    schedulePoll(currentTaskId, currentUid, retryCount + 1, maxRetries, pollDelayMs(retryCount));
    return null;
  }

  try {
    if (!statusResponse.ok) {
      const errorText = await statusResponse.text();
      console.error(`[Poll] Error fetching status: ${statusResponse.status} ${errorText}`);
      $('srvStatus').textContent = 'Error checking status';
      // Full teardown: hiding only the spinner left the progress interval running
      // and the progress bar on screen.
      cleanupPolling();
      return null;
    }

    const statusData = await statusResponse.json();
    console.log('[Poll] Task status:', statusData);
    if (typeof statusData?.status !== 'string') {
      throw new Error('Status response has no string "status" field');
    }
    const status = statusData.status.toLowerCase();

    if (POLL_SUCCESS_STATUSES.includes(status)) {
      $('srvStatus').textContent = 'Processing complete!';
      $('srvComputing').hidden = true;
      $('progressBar').style.width = '100%';
      if (progressTimer) {
        clearInterval(progressTimer);
        progressTimer = null;
      }
      return { ...statusData, status };
    }

    if (POLL_FAILURE_STATUSES.includes(status)) {
      console.error('[Poll] Task failed or unrecoverable:', statusData);
      $('srvStatus').textContent = 'Task failed. Please try again.';
      cleanupPolling();
      return null;
    }

    const isRunning = POLL_RUNNING_STATUSES.includes(status);
    if (status === 'queued') {
      // When details contains "Position: N/M", show the queue position.
      const positionMatch = typeof statusData.details === 'string'
        ? statusData.details.match(/Position:\s*(\d+)\/(\d+)/)
        : null;
      $('srvStatus').textContent = positionMatch
        ? `Waiting in queue (${positionMatch[1]} of ${positionMatch[2]})`
        : 'Waiting in queue...';
    } else if (isRunning) {
      $('srvStatus').textContent = 'Processing your request...';
      // Start the progress bar the first time the task is seen running.
      show('progressContainer', true);
      if (!window.processingStartTime) {
        window.processingStartTime = performance.now();
      }
      if (!progressTimer) {
        progressTimer = setInterval(updateProgressBar, 1000);
      }
    } else {
      // Fallback for any other status
      $('srvStatus').textContent = `Status: ${statusData.status}`;
    }
    $('srvComputing').hidden = !isRunning;

    schedulePoll(currentTaskId, currentUid, retryCount, maxRetries, pollDelayMs(retryCount));
    return null;
  } catch (e) {
    // Bad JSON, unexpected payload shape, etc.: not worth retrying.
    console.error('[Poll] Non-recoverable error:', e);
    $('srvStatus').textContent = 'An unexpected error occurred. Please try again.';
    cleanupPolling();
    return null;
  }
}
/**
 * Reset the server-processing UI: stop the progress timer, forget when processing
 * started, and hide the spinner, the progress bar and the srvComputing indicator.
 */
function cleanupPolling() {
  const hiddenWhenIdle = ['spin', 'progressContainer'];
  hiddenWhenIdle.forEach((elementId) => show(elementId, false));
  $('srvComputing').hidden = true;
  window.processingStartTime = null;
  if (!progressTimer) return;
  clearInterval(progressTimer);
  progressTimer = null;
}
/**
 * Interval callback: move the progress bar in proportion to the time elapsed since
 * processing started versus the estimated duration. It is capped at 95%; the status
 * poller sets 100% when the server reports completion. Does nothing until both
 * window.processingStartTime and window.expectedDuration are set (non-zero).
 */
function updateProgressBar() {
  const startedAt = window.processingStartTime;
  const expectedMs = window.expectedDuration;
  if (!startedAt || !expectedMs) return;
  const elapsedMs = performance.now() - startedAt;
  const fraction = elapsedMs / expectedMs;
  const percent = Math.min(fraction * 100, 95);
  $('progressBar').style.width = `${percent}%`;
  const elapsedLabel = (elapsedMs / 1000).toFixed(1);
  const expectedLabel = (expectedMs / 1000).toFixed(1);
  console.log(`[Progress] ${percent.toFixed(1)}% (${elapsedLabel}s / ${expectedLabel}s)`);
}
/**
 * Download the encrypted result of a finished task into `encServerResult` and
 * enable the Decrypt button. Always tears down the polling/progress UI afterwards.
 * Failures are logged and reported in the status line, not thrown.
 * @param {string} currentTaskId - task id returned by /start_task
 * @param {string} currentUid - session uid the task was submitted with
 * @param {string} taskName - server task name (e.g. 'synthid')
 */
async function getTaskResult(currentTaskId, currentUid, taskName) {
  $('srvStatus').textContent = 'Retrieving results...';
  // Seconds since submission (window.taskStartTime is set by the Send handler), or 'N/A'.
  const elapsedSeconds = () =>
    window.taskStartTime ? ((performance.now() - window.taskStartTime) / 1000).toFixed(1) : 'N/A';
  try {
    // URLSearchParams percent-encodes the values so ids cannot corrupt the query string.
    const url = new URL(`${SERVER}/get_task_result`);
    url.searchParams.set('task_name', taskName);
    url.searchParams.set('task_id', currentTaskId);
    url.searchParams.set('uid', currentUid);
    const resultResponse = await fetch(url);
    if (!resultResponse.ok) {
      const errorText = await resultResponse.text();
      throw new Error(`Failed to get results: ${resultResponse.status} ${errorText}`);
    }
    const resultArrayBuffer = await resultResponse.arrayBuffer();
    encServerResult = new Uint8Array(resultArrayBuffer);
    console.log(`[Main] Received encrypted result: ${encServerResult.length} bytes`);
    $('srvStatus').textContent = `✓ Complete! (${elapsedSeconds()}s)`;
    enable('btnDecrypt');
  } catch (e) {
    console.error(`[Main] /get_task_result failed after ${elapsedSeconds()}s:`, e);
    $('srvStatus').textContent = 'Failed to retrieve results. Please try again.';
  } finally {
    show('spin', false);
    show('progressContainer', false);
    $('srvComputing').hidden = true;
    window.processingStartTime = null;
    if (progressTimer) {
      clearInterval(progressTimer);
      progressTimer = null;
    }
  }
}
$('btnSend').onclick = async () => {
if ($('spin').hidden === false) {
console.log('[Main] Task submission/polling already in progress, ignoring click');
return;
}
if (!sessionUid || !encTokens) {
alert('Please generate keys and encrypt text first.');
return;
}
show('encIcon', false);
show('spin', true);
show('processingNote', true);
show('progressContainer', false); // Hide initially, show when processing starts
$('progressBar').style.width = '0%'; // Reset progress bar
$('srvStatus').textContent = 'Sending encrypted data...';
$('srvComputing').hidden = true; // Ensure it's hidden initially
window.expectedDuration = currentTokenCount * 30 * 1000; // Convert to milliseconds
window.taskStartTime = performance.now(); // Set start time when task is submitted (for overall duration)
window.processingStartTime = null; // Will be set when processing actually begins
// Clear any existing progress timer
if (progressTimer) {
clearInterval(progressTimer);
progressTimer = null;
}
try {
const formData = new FormData();
formData.append('uid', sessionUid);
formData.append('task_name', 'synthid');
const encryptedInputBlob = new Blob([encTokens], { type: 'application/octet-stream' });
const encryptedInputFile = new File([encryptedInputBlob], "input.fheencrypted");
formData.append('encrypted_input', encryptedInputFile);
const startTaskResponse = await fetch(`${SERVER}/start_task`, {
method: 'POST',
body: formData
});
if (!startTaskResponse.ok) {
const errorText = await startTaskResponse.text();
throw new Error(`Server error: ${startTaskResponse.status} - ${errorText}`);
}
const { task_id: newTaskId } = await startTaskResponse.json();
taskId = newTaskId;
console.log('[Main] Task submitted to server. Task ID:', taskId);
$('srvStatus').textContent = 'Request submitted. Checking status...';
pollTaskStatus(taskId, sessionUid, 0, 10).then(finalStatus => {
if (finalStatus && (finalStatus.status === 'success' || finalStatus.status === 'completed')) {
getTaskResult(taskId, sessionUid, 'synthid');
}
});
} catch (e) {
const duration = window.taskStartTime ? ((performance.now() - window.taskStartTime) / 1000).toFixed(2) : 'N/A';
console.error(`[Main] Task submission failed after ${duration}s:`, e);
$('srvStatus').textContent = 'Failed to submit request. Please try again.';
cleanupPolling();
}
};
// Decrypt the server's encrypted verdict with the client key and display it.
$('btnDecrypt').onclick = () => {
  // Defensive: encServerResult is only set by getTaskResult, and clientKey only by keygen/load.
  if (!encServerResult || !clientKey) {
    $('decResult').textContent = 'Nothing to decrypt yet: send your encrypted text to the server first.';
    return;
  }
  try {
    console.log('[Main] Starting decryption...');
    const dec = decrypt_serialized_u64_radix_flat_wasm(encServerResult, clientKey);
    // Decrypted values are read as [flag, score_scaled, total_g]. They may arrive as
    // BigInt (u64) — TODO confirm against the wasm binding — hence the Number() calls.
    const [flag, score_scaled, total_g] = Array.from(dec);
    // score_scaled carries the score multiplied by 1e6.
    const rawScore = Number(score_scaled) / 1e6;
    console.log('[Main] Decryption successful');
    console.log(`[Main] Result - flag: ${flag}, raw_score: ${rawScore}, total_g: ${total_g}`);
    // The verdict depends on the flag alone; Number() makes the strict comparison
    // work whether the flag is a BigInt or a Number.
    const isWatermarked = Number(flag) === 1;
    const resultText = isWatermarked ? '✅ Text is watermarked' : '❌ Text is not watermarked';
    const resultClass = isWatermarked ? 'watermarked' : 'inconclusive';
    // Both strings are fixed literals, so building this markup with innerHTML is safe.
    $('decResult').innerHTML = `
      <div class="watermark-flag ${resultClass}">${resultText}</div>
    `;
  } catch (e) {
    console.error('[Main] Decryption error:', e);
    $('decResult').textContent = `Decryption failed: ${e.message}`;
  }
};
/**
 * Message handler for the encryption worker.
 * - 'ready': the worker finished initialising with the client key.
 * - 'success': e.data.result is the ciphertext for the submitted token ids; it is
 *   stored in encTokens and the Send button is enabled.
 * - 'error': e.data.error describes the failure; it is shown to the user.
 * Any other message type is logged and ignored.
 * @param {MessageEvent} e
 */
function encryptWorker_onmessage(e) {
  const { type } = e.data;
  if (type === 'ready') {
    console.log('[Main] Encryption worker ready');
  } else if (type === 'success') {
    encTokens = e.data.result;
    console.log(`[Main] Encryption completed: ${encTokens.length} bytes`);
    show('encryptSpin', false);
    show('encIcon', true);
    enable('btnEncrypt', true);
    enable('btnSend');
    // A result decrypted for the previous ciphertext no longer applies.
    enable('btnDecrypt', false);
    $('encStatus').textContent = 'Your text is encrypted 🔒';
    $('decResult').textContent = '';
  } else if (type === 'error') {
    console.error('[Main] Encryption error:', e.data.error);
    show('encryptSpin', false);
    enable('btnEncrypt', true);
    $('encStatus').textContent = `Encryption failed: ${e.data.error}`;
    alert(`Encryption failed: ${e.data.error}`);
  } else {
    console.warn('[Main] Ignoring unexpected message from encryption worker:', e.data);
  }
}
/**
 * (Re)create the encryption worker and initialise it with `keyUint8`.
 *
 * Called whenever the active client key changes: after a fresh keygen and after
 * loading a saved key set. The previous worker is terminated so workers don't pile
 * up and a late reply computed with the old key can't overwrite encTokens.
 * Ciphertext and results produced under the previous key are discarded, because
 * they may not match the new key/uid.
 * @param {Uint8Array} keyUint8 - client key bytes
 */
function initEncryptWorkerWithKey(keyUint8) {
  if (encryptWorker) {
    encryptWorker.terminate();
  }
  encTokens = undefined;
  encServerResult = undefined;
  enable('btnSend', false);
  enable('btnDecrypt', false);
  // An encryption that was in flight on the old worker will never report back.
  show('encryptSpin', false);
  encryptWorker = new Worker(new URL('./encrypt-worker.js', import.meta.url), { type: 'module' });
  encryptWorker.onmessage = encryptWorker_onmessage;
  encryptWorker.postMessage({ type: 'init', clientKey: keyUint8 });
}