import {
// VAD
AutoModel,
// LLM
AutoTokenizer,
AutoModelForCausalLM,
TextStreamer,
InterruptableStoppingCriteria,
// Speech recognition
Tensor,
pipeline,
} from "@huggingface/transformers";
import { KokoroTTS, TextSplitterStream } from "kokoro-js";
import {
MAX_BUFFER_DURATION,
INPUT_SAMPLE_RATE,
SPEECH_THRESHOLD,
EXIT_THRESHOLD,
SPEECH_PAD_SAMPLES,
MAX_NUM_PREV_BUFFERS,
MIN_SILENCE_DURATION_SAMPLES,
MIN_SPEECH_DURATION_SAMPLES,
} from "./constants";
// WebGPU availability check - fail fast
if (!navigator.gpu) {
self.postMessage({
type: "error",
error: new Error("WebGPU not supported. This app requires Chrome 113+, Edge 113+, or Chrome Canary with WebGPU enabled.")
});
throw new Error("WebGPU not available");
}
// TTS Configuration
const model_id = "onnx-community/Kokoro-82M-v1.0-ONNX";
let voice;
const tts = await KokoroTTS.from_pretrained(model_id, {
dtype: "fp16", // Keep fp16 for memory efficiency
device: "webgpu",
}).catch((error) => {
self.postMessage({ type: "error", error: new Error(`TTS loading failed: ${error.message}`) });
throw error;
});
const device = "webgpu";
self.postMessage({ type: "info", message: `Using device: "${device}"` });
self.postMessage({
type: "info",
message: "Loading models...",
duration: "until_next",
});
// Load VAD model
const silero_vad = await AutoModel.from_pretrained(
"onnx-community/silero-vad",
{
config: { model_type: "custom" },
dtype: "fp32",
},
).catch((error) => {
self.postMessage({ type: "error", error: new Error(`VAD loading failed: ${error.message}`) });
throw error;
});
// Whisper configuration
const DEVICE_DTYPE_CONFIGS = {
webgpu: {
encoder_model: "fp32",
decoder_model_merged: "fp32",
},
wasm: {
encoder_model: "fp32",
decoder_model_merged: "q8",
},
};
const transcriber = await pipeline(
"automatic-speech-recognition",
"onnx-community/whisper-base",
{
device,
dtype: DEVICE_DTYPE_CONFIGS[device],
// Note: language/task are per-call generation options; they are passed to
// the transcriber call in speechToSpeech() below.
},
).catch((error) => {
self.postMessage({ type: "error", error: new Error(`Whisper loading failed: ${error.message}`) });
throw error;
});
// Warm up the transcriber
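// (one second of silence at INPUT_SAMPLE_RATE, so the first real request
// doesn't pay the one-time session initialization / shader setup cost)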
await transcriber(new Float32Array(INPUT_SAMPLE_RATE));
// LLM Configuration - Split tokenizer and model sources
const TOKENIZER_MODEL_ID = "Qwen/Qwen3-1.7B"; // Original repo has tokenizer
const ONNX_MODEL_ID = "onnx-community/Qwen3-1.7B-ONNX"; // ONNX weights
// Load tokenizer from original repo
const tokenizer = await AutoTokenizer.from_pretrained(TOKENIZER_MODEL_ID).catch((error) => {
self.postMessage({ type: "error", error: new Error(`Tokenizer loading failed: ${error.message}`) });
throw error;
});
// Load ONNX model weights
const llm = await AutoModelForCausalLM.from_pretrained(ONNX_MODEL_ID, {
dtype: "q4f16",
device: "webgpu",
// Add model-specific config for Qwen3
model_config: {
use_cache: true,
attention_bias: false,
}
}).catch((error) => {
self.postMessage({ type: "error", error: new Error(`LLM loading failed: ${error.message}`) });
throw error;
});
// System prompt optimized for conversational AI
const SYSTEM_MESSAGE = {
role: "system",
content:
"You're a helpful and conversational voice assistant. Keep your responses short, clear, and casual. Focus on being natural and engaging in conversation.",
};
// Warm up the LLM
await llm.generate({ ...tokenizer("x"), max_new_tokens: 1 });
// Conversation state
let messages = [SYSTEM_MESSAGE];
let past_key_values_cache;
let stopping_criteria;
const MAX_CONTEXT_MESSAGES = 20; // Prevent unbounded memory growth
// Send ready signal with available voices
self.postMessage({
type: "status",
status: "ready",
message: "Ready!",
voices: tts.voices,
});
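// Messages posted from this worker to the main thread use the following
// shapes (as emitted throughout this file):
//   { type: "info",   message, duration? }
//   { type: "status", status, message, duration?, voices? }
//   { type: "error",  error }
//   { type: "output", text, result } // streamed TTS audio chunks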
// Audio processing state
const BUFFER = new Float32Array(MAX_BUFFER_DURATION * INPUT_SAMPLE_RATE);
let bufferPointer = 0;
// VAD state
const sr = new Tensor("int64", [BigInt(INPUT_SAMPLE_RATE)], []); // int64 tensors require BigInt values
let state = new Tensor("float32", new Float32Array(2 * 1 * 128), [2, 1, 128]);
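// `sr` is the sample-rate input expected by Silero VAD; `state` is its
// recurrent state (shape [2, 1, 128]), carried across calls and updated
// with the `stateN` output inside vad() below.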
// Recording state
let isRecording = false;
let isPlaying = false;
/**
* Perform Voice Activity Detection (VAD)
* @param {Float32Array} buffer The new audio buffer
* @returns {Promise<boolean>} `true` if the buffer is speech, `false` otherwise.
*/
async function vad(buffer) {
const input = new Tensor("float32", buffer, [1, buffer.length]);
const { stateN, output } = await silero_vad({ input, sr, state });
state = stateN;
const isSpeech = output.data[0];
return (
isSpeech > SPEECH_THRESHOLD ||
(isRecording && isSpeech >= EXIT_THRESHOLD)
);
}
/**
* Handle speech-to-speech pipeline
* @param {Float32Array} buffer The audio buffer
* @param {Object} data Additional timing data
*/
const speechToSpeech = async (buffer, data) => {
isPlaying = true;
try {
// 1. Transcribe audio
const transcription = await transcriber(buffer, { language: "en", task: "transcribe" });
const text = transcription.text?.trim() || "";
if (!text || text === "[BLANK_AUDIO]") {
isPlaying = false;
return;
}
// Add user message
messages.push({ role: "user", content: text });
// Manage context window
if (messages.length > MAX_CONTEXT_MESSAGES) {
messages = [SYSTEM_MESSAGE, ...messages.slice(-(MAX_CONTEXT_MESSAGES - 1))];
past_key_values_cache = null; // Reset cache when context changes
}
// Set up TTS streaming
const splitter = new TextSplitterStream();
const stream = tts.stream(splitter, { voice });
// Stream TTS output
(async () => {
try {
for await (const { text, phonemes, audio } of stream) {
self.postMessage({ type: "output", text, result: audio });
}
} catch (error) {
console.error("TTS streaming error:", error);
}
})();
// 2. Generate LLM response
const inputs = tokenizer.apply_chat_template(messages, {
add_generation_prompt: true,
return_dict: true,
// Qwen3 specific - disable thinking mode for conversational use
enable_thinking: false,
});
const streamer = new TextStreamer(tokenizer, {
skip_prompt: true,
skip_special_tokens: true,
callback_function: (text) => {
splitter.push(text);
},
token_callback_function: () => {},
});
stopping_criteria = new InterruptableStoppingCriteria();
// Generate with appropriate settings for Qwen3
const { past_key_values, sequences } = await llm.generate({
...inputs,
past_key_values: past_key_values_cache,
// Qwen3 optimal settings for non-thinking mode
do_sample: true,
temperature: 0.7,
top_p: 0.8,
top_k: 20,
max_new_tokens: 512, // Keep responses concise for voice
streamer,
stopping_criteria,
return_dict_in_generate: true,
// Ensure proper EOS handling for Qwen3
eos_token_id: [151643, 151645],
pad_token_id: tokenizer.pad_token_id,
});
past_key_values_cache = past_key_values;
// Close the TTS stream
splitter.close();
// Decode and store assistant response
const decoded = tokenizer.batch_decode(
sequences.slice(null, [inputs.input_ids.dims[1], null]),
{ skip_special_tokens: true },
);
messages.push({ role: "assistant", content: decoded[0] });
} catch (error) {
console.error("Speech-to-speech error:", error);
self.postMessage({
type: "error",
error: new Error(`Processing failed: ${error.message}`)
});
// On error, allow audio input again right away. On success, isPlaying is
// cleared when the main thread sends { type: "playback_ended" }.
isPlaying = false;
}
};
// Audio buffer management
let postSpeechSamples = 0;
let prevBuffers = [];
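// prevBuffers holds the most recent non-speech chunks (capped at
// MAX_NUM_PREV_BUFFERS) so a little leading audio can be prepended to each
// detected speech segment before transcription.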
const resetAfterRecording = (offset = 0) => {
self.postMessage({
type: "status",
status: "recording_end",
message: "Transcribing...",
duration: "until_next",
});
BUFFER.fill(0, offset);
bufferPointer = offset;
isRecording = false;
postSpeechSamples = 0;
};
const dispatchForTranscriptionAndResetAudioBuffer = (overflow) => {
const now = Date.now();
const end = now - ((postSpeechSamples + SPEECH_PAD_SAMPLES) / INPUT_SAMPLE_RATE) * 1000;
const start = end - (bufferPointer / INPUT_SAMPLE_RATE) * 1000;
const duration = end - start;
const overflowLength = overflow?.length ?? 0;
// Prepare padded buffer
const buffer = BUFFER.slice(0, bufferPointer + SPEECH_PAD_SAMPLES);
const prevLength = prevBuffers.reduce((acc, b) => acc + b.length, 0);
const paddedBuffer = new Float32Array(prevLength + buffer.length);
let offset = 0;
for (const prev of prevBuffers) {
paddedBuffer.set(prev, offset);
offset += prev.length;
}
paddedBuffer.set(buffer, offset);
// Process speech
speechToSpeech(paddedBuffer, { start, end, duration });
// Handle overflow
if (overflow) {
BUFFER.set(overflow, 0);
}
resetAfterRecording(overflowLength);
};
// Message handler
self.onmessage = async (event) => {
const { type, buffer } = event.data;
// Block audio during playback
if (type === "audio" && isPlaying) return;
switch (type) {
case "start_call": {
const name = tts.voices[voice ?? "af_heart"]?.name ?? "Heart";
greet(`Hey there, my name is ${name}! How can I help you today?`);
return;
}
case "end_call":
messages = [SYSTEM_MESSAGE];
past_key_values_cache = null;
// Fall through to interrupt
case "interrupt":
stopping_criteria?.interrupt();
return;
case "set_voice":
voice = event.data.voice;
return;
case "playback_ended":
isPlaying = false;
return;
}
// Process audio buffer
const wasRecording = isRecording;
const isSpeech = await vad(buffer);
if (!wasRecording && !isSpeech) {
// Queue non-speech buffers for padding
if (prevBuffers.length >= MAX_NUM_PREV_BUFFERS) {
prevBuffers.shift();
}
prevBuffers.push(buffer);
return;
}
const remaining = BUFFER.length - bufferPointer;
if (buffer.length >= remaining) {
// Buffer overflow - trigger transcription
BUFFER.set(buffer.subarray(0, remaining), bufferPointer);
bufferPointer += remaining;
const overflow = buffer.subarray(remaining);
dispatchForTranscriptionAndResetAudioBuffer(overflow);
return;
} else {
// Add to buffer
BUFFER.set(buffer, bufferPointer);
bufferPointer += buffer.length;
}
if (isSpeech) {
if (!isRecording) {
self.postMessage({
type: "status",
status: "recording_start",
message: "Listening...",
duration: "until_next",
});
}
isRecording = true;
postSpeechSamples = 0;
return;
}
postSpeechSamples += buffer.length;
// Check for end of speech
if (postSpeechSamples < MIN_SILENCE_DURATION_SAMPLES) {
return;
}
if (bufferPointer < MIN_SPEECH_DURATION_SAMPLES) {
resetAfterRecording();
return;
}
dispatchForTranscriptionAndResetAudioBuffer();
};
// Greeting function
function greet(text) {
isPlaying = true;
const splitter = new TextSplitterStream();
const stream = tts.stream(splitter, { voice });
(async () => {
try {
for await (const { text: chunkText, audio } of stream) {
self.postMessage({ type: "output", text: chunkText, result: audio });
}
} catch (error) {
console.error("TTS streaming error:", error);
}
})();
splitter.push(text);
splitter.close();
messages.push({ role: "assistant", content: text });
}
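// Rough sketch of the main-thread side, for reference only. The worker URL,
// audio capture, and playback wiring are assumptions (not defined in this
// file); only the message shapes mirror the handler above.
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), { type: "module" });
//   worker.onmessage = ({ data }) => {
//     switch (data.type) {
//       case "status": // "ready" | "recording_start" | "recording_end"; voices included on "ready"
//       case "info":   // loading/progress text
//         break;
//       case "output": // data.text + data.result (audio chunk) -> queue for playback,
//                      // then send { type: "playback_ended" } when playback finishes
//         break;
//       case "error":
//         console.error(data.error);
//         break;
//     }
//   };
//   // Microphone chunks (Float32Array at INPUT_SAMPLE_RATE):
//   worker.postMessage({ type: "audio", buffer });
//   // Other controls: "start_call", "end_call", "interrupt",
//   // { type: "set_voice", voice }, "playback_ended".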