import {
  // VAD
  AutoModel,
  // LLM
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,
  // Speech recognition
  Tensor,
  pipeline,
} from "@huggingface/transformers";
import { KokoroTTS, TextSplitterStream } from "kokoro-js";
import {
  MAX_BUFFER_DURATION,
  INPUT_SAMPLE_RATE,
  SPEECH_THRESHOLD,
  EXIT_THRESHOLD,
  SPEECH_PAD_SAMPLES,
  MAX_NUM_PREV_BUFFERS,
  MIN_SILENCE_DURATION_SAMPLES,
  MIN_SPEECH_DURATION_SAMPLES,
} from "./constants";
// WebGPU availability check - fail fast
if (!navigator.gpu) {
  self.postMessage({
    type: "error",
    error: new Error(
      "WebGPU not supported. This app requires Chrome 113+, Edge 113+, or Chrome Canary with WebGPU enabled.",
    ),
  });
  throw new Error("WebGPU not available");
}
// TTS Configuration
const model_id = "onnx-community/Kokoro-82M-v1.0-ONNX";
let voice;
const tts = await KokoroTTS.from_pretrained(model_id, {
  dtype: "fp16", // Keep fp16 for memory efficiency
  device: "webgpu",
}).catch((error) => {
  self.postMessage({ error: new Error(`TTS loading failed: ${error.message}`) });
  throw error;
});
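// Kokoro exposes its available voices via `tts.voices`; these are forwarded to the UI
// in the "ready" message below, and the user's selection comes back via "set_voice".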
const device = "webgpu";
self.postMessage({ type: "info", message: `Using device: "${device}"` });
self.postMessage({
  type: "info",
  message: "Loading models...",
  duration: "until_next",
});
// Load VAD model
const silero_vad = await AutoModel.from_pretrained(
  "onnx-community/silero-vad",
  {
    config: { model_type: "custom" },
    dtype: "fp32",
  },
).catch((error) => {
  self.postMessage({ error: new Error(`VAD loading failed: ${error.message}`) });
  throw error;
});
// Whisper configuration
const DEVICE_DTYPE_CONFIGS = {
  webgpu: {
    encoder_model: "fp32",
    decoder_model_merged: "fp32",
  },
  wasm: {
    encoder_model: "fp32",
    decoder_model_merged: "q8",
  },
};
const transcriber = await pipeline(
  "automatic-speech-recognition",
  "onnx-community/whisper-base",
  {
    device,
    dtype: DEVICE_DTYPE_CONFIGS[device],
  },
).catch((error) => {
  self.postMessage({ error: new Error(`Whisper loading failed: ${error.message}`) });
  throw error;
});
// Warm up the transcriber
await transcriber(new Float32Array(INPUT_SAMPLE_RATE));
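// The warm-up input is one second of silence (INPUT_SAMPLE_RATE zero samples), enough to
// run the full encode/decode path once so the first real utterance is not delayed by
// one-time session initialization.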
// LLM Configuration - Split tokenizer and model sources
const TOKENIZER_MODEL_ID = "Qwen/Qwen3-1.7B"; // Original repo has the tokenizer
const ONNX_MODEL_ID = "onnx-community/Qwen3-1.7B-ONNX"; // ONNX weights
// Load tokenizer from original repo
const tokenizer = await AutoTokenizer.from_pretrained(TOKENIZER_MODEL_ID).catch((error) => {
  self.postMessage({ error: new Error(`Tokenizer loading failed: ${error.message}`) });
  throw error;
});
// Load ONNX model weights
const llm = await AutoModelForCausalLM.from_pretrained(ONNX_MODEL_ID, {
  dtype: "q4f16",
  device: "webgpu",
  // Add model-specific config for Qwen3
  model_config: {
    use_cache: true,
    attention_bias: false,
  },
}).catch((error) => {
  self.postMessage({ error: new Error(`LLM loading failed: ${error.message}`) });
  throw error;
});
// System prompt optimized for conversational AI
const SYSTEM_MESSAGE = {
  role: "system",
  content:
    "You're a helpful and conversational voice assistant. Keep your responses short, clear, and casual. Focus on being natural and engaging in conversation.",
};
// Warm up the LLM
await llm.generate({ ...tokenizer("x"), max_new_tokens: 1 });
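// As with the transcriber, this single-token generation exists only to pay the one-time
// initialization cost up front, before the first user turn.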
// Conversation state
let messages = [SYSTEM_MESSAGE];
let past_key_values_cache;
let stopping_criteria;
const MAX_CONTEXT_MESSAGES = 20; // Prevent unbounded memory growth
// Send ready signal with available voices
self.postMessage({
  type: "status",
  status: "ready",
  message: "Ready!",
  voices: tts.voices,
});
// Audio processing state
const BUFFER = new Float32Array(MAX_BUFFER_DURATION * INPUT_SAMPLE_RATE);
let bufferPointer = 0;
// VAD state
const sr = new Tensor("int64", [INPUT_SAMPLE_RATE], []);
let state = new Tensor("float32", new Float32Array(2 * 1 * 128), [2, 1, 128]);
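// The Silero VAD model is stateful: it takes a float32 recurrent state of shape
// [2, 1, 128] alongside each audio chunk and returns an updated state (`stateN`),
// which is carried over between calls in `vad()` below.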
// Recording state
let isRecording = false;
let isPlaying = false;
/**
 * Perform Voice Activity Detection (VAD)
 * @param {Float32Array} buffer The new audio buffer
 * @returns {Promise<boolean>} `true` if the buffer is speech, `false` otherwise.
 */
async function vad(buffer) {
  const input = new Tensor("float32", buffer, [1, buffer.length]);
  const { stateN, output } = await silero_vad({ input, sr, state });
  state = stateN; // Carry the recurrent state over to the next chunk
  const isSpeech = output.data[0];
  // Hysteresis: starting a speech segment requires crossing SPEECH_THRESHOLD, but an
  // ongoing recording stays alive as long as the score is at or above EXIT_THRESHOLD.
  return (
    isSpeech > SPEECH_THRESHOLD ||
    (isRecording && isSpeech >= EXIT_THRESHOLD)
  );
}
/**
 * Handle speech-to-speech pipeline
 * @param {Float32Array} buffer The audio buffer
 * @param {Object} data Additional timing data
 */
const speechToSpeech = async (buffer, data) => {
  isPlaying = true;
  try {
    // 1. Transcribe audio (specify language/task to avoid auto-detect warnings)
    const transcription = await transcriber(buffer, {
      language: "en",
      task: "transcribe",
    });
    const text = transcription.text?.trim() || "";
    if (!text || text === "[BLANK_AUDIO]") {
      isPlaying = false;
      return;
    }

    // Add user message
    messages.push({ role: "user", content: text });

    // Manage context window
    if (messages.length > MAX_CONTEXT_MESSAGES) {
      messages = [SYSTEM_MESSAGE, ...messages.slice(-(MAX_CONTEXT_MESSAGES - 1))];
      past_key_values_cache = null; // Reset cache when context changes
    }

    // Set up TTS streaming
    const splitter = new TextSplitterStream();
    const stream = tts.stream(splitter, { voice });

    // Stream TTS output
    (async () => {
      try {
        for await (const { text, phonemes, audio } of stream) {
          self.postMessage({ type: "output", text, result: audio });
        }
      } catch (error) {
        console.error("TTS streaming error:", error);
      }
    })();

    // 2. Generate LLM response
    const inputs = tokenizer.apply_chat_template(messages, {
      add_generation_prompt: true,
      return_dict: true,
      // Qwen3 specific - disable thinking mode for conversational use
      enable_thinking: false,
    });

    const streamer = new TextStreamer(tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function: (text) => {
        splitter.push(text);
      },
      token_callback_function: () => {},
    });

    stopping_criteria = new InterruptableStoppingCriteria();

    // Generate with appropriate settings for Qwen3
    const { past_key_values, sequences } = await llm.generate({
      ...inputs,
      past_key_values: past_key_values_cache,
      // Qwen3 recommended sampling settings for non-thinking mode
      do_sample: true,
      temperature: 0.7,
      top_p: 0.8,
      top_k: 20,
      max_new_tokens: 512, // Keep responses concise for voice
      streamer,
      stopping_criteria,
      return_dict_in_generate: true,
      // Ensure proper EOS handling for Qwen3 (<|endoftext|> and <|im_end|>)
      eos_token_id: [151643, 151645],
      pad_token_id: tokenizer.pad_token_id,
    });
    past_key_values_cache = past_key_values;

    // Close the TTS stream
    splitter.close();

    // Decode only the newly generated tokens and store the assistant response
    const decoded = tokenizer.batch_decode(
      sequences.slice(null, [inputs.input_ids.dims[1], null]),
      { skip_special_tokens: true },
    );
    messages.push({ role: "assistant", content: decoded[0] });
  } catch (error) {
    console.error("Speech-to-speech error:", error);
    self.postMessage({
      type: "error",
      error: new Error(`Processing failed: ${error.message}`),
    });
  } finally {
    isPlaying = false;
  }
};
// Audio buffer management
let postSpeechSamples = 0;
let prevBuffers = [];

const resetAfterRecording = (offset = 0) => {
  self.postMessage({
    type: "status",
    status: "recording_end",
    message: "Transcribing...",
    duration: "until_next",
  });
  BUFFER.fill(0, offset);
  bufferPointer = offset;
  isRecording = false;
  postSpeechSamples = 0;
};

const dispatchForTranscriptionAndResetAudioBuffer = (overflow) => {
  const now = Date.now();
  const end =
    now - ((postSpeechSamples + SPEECH_PAD_SAMPLES) / INPUT_SAMPLE_RATE) * 1000;
  const start = end - (bufferPointer / INPUT_SAMPLE_RATE) * 1000;
  const duration = end - start;
  const overflowLength = overflow?.length ?? 0;

  // Prepare padded buffer
  const buffer = BUFFER.slice(0, bufferPointer + SPEECH_PAD_SAMPLES);
  const prevLength = prevBuffers.reduce((acc, b) => acc + b.length, 0);
  const paddedBuffer = new Float32Array(prevLength + buffer.length);
  let offset = 0;
  for (const prev of prevBuffers) {
    paddedBuffer.set(prev, offset);
    offset += prev.length;
  }
  paddedBuffer.set(buffer, offset);

  // Process speech
  speechToSpeech(paddedBuffer, { start, end, duration });

  // Handle overflow
  if (overflow) {
    BUFFER.set(overflow, 0);
  }
  resetAfterRecording(overflowLength);
};
// Message handler
self.onmessage = async (event) => {
  const { type, buffer } = event.data;

  // Block incoming audio while the assistant is speaking
  if (type === "audio" && isPlaying) return;

  switch (type) {
    case "start_call": {
      const name = tts.voices[voice ?? "af_heart"]?.name ?? "Heart";
      greet(`Hey there, my name is ${name}! How can I help you today?`);
      return;
    }
    case "end_call":
      messages = [SYSTEM_MESSAGE];
      past_key_values_cache = null;
    // Fall through to interrupt
    case "interrupt":
      stopping_criteria?.interrupt();
      return;
    case "set_voice":
      voice = event.data.voice;
      return;
    case "playback_ended":
      isPlaying = false;
      return;
  }

  // Process audio buffer
  const wasRecording = isRecording;
  const isSpeech = await vad(buffer);

  if (!wasRecording && !isSpeech) {
    // Queue non-speech buffers so they can be prepended as padding once speech starts
    if (prevBuffers.length >= MAX_NUM_PREV_BUFFERS) {
      prevBuffers.shift();
    }
    prevBuffers.push(buffer);
    return;
  }

  const remaining = BUFFER.length - bufferPointer;
  if (buffer.length >= remaining) {
    // Buffer overflow - trigger transcription and carry the remainder over
    BUFFER.set(buffer.subarray(0, remaining), bufferPointer);
    bufferPointer += remaining;
    const overflow = buffer.subarray(remaining);
    dispatchForTranscriptionAndResetAudioBuffer(overflow);
    return;
  } else {
    // Add to buffer
    BUFFER.set(buffer, bufferPointer);
    bufferPointer += buffer.length;
  }

  if (isSpeech) {
    if (!isRecording) {
      self.postMessage({
        type: "status",
        status: "recording_start",
        message: "Listening...",
        duration: "until_next",
      });
    }
    isRecording = true;
    postSpeechSamples = 0; // Reset the silence counter whenever speech is detected
    return;
  }

  postSpeechSamples += buffer.length;

  // Check for end of speech: require enough trailing silence...
  if (postSpeechSamples < MIN_SILENCE_DURATION_SAMPLES) {
    return;
  }
  // ...and discard segments that are too short to be real speech
  if (bufferPointer < MIN_SPEECH_DURATION_SAMPLES) {
    resetAfterRecording();
    return;
  }

  dispatchForTranscriptionAndResetAudioBuffer();
};
// Greeting function
function greet(text) {
  isPlaying = true;
  const splitter = new TextSplitterStream();
  const stream = tts.stream(splitter, { voice });
  (async () => {
    for await (const { text: chunkText, audio } of stream) {
      self.postMessage({ type: "output", text: chunkText, result: audio });
    }
  })();
  splitter.push(text);
  splitter.close();
  messages.push({ role: "assistant", content: text });
}