|
|
|
import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@latest' |
|
|
|
/**
 * Lazy singleton wrapper around a transformers.js text-generation pipeline.
 *
 * Creation is attempted on WebGPU first, transparently falling back to WASM
 * when WebGPU is unavailable. The loaded pipeline is cached and reused for
 * subsequent calls with the same model/dtype (the original code rebuilt the
 * pipeline on every call despite declaring a singleton `instance` field).
 */
class MyTextGenerationPipeline {

  // Pipeline task identifier passed to transformers.js.
  static task = 'text-generation'

  // Cached pipeline instance, plus the model/dtype it was created with,
  // so repeated getInstance() calls don't re-download and re-initialize.
  static instance = null
  static loadedModel = null
  static loadedDtype = null

  // AbortController for the generation currently in flight (null when idle).
  static currentGeneration = null

  /**
   * Return the cached pipeline, or create one for the given model/dtype.
   *
   * @param {string} model - Model id (e.g. a Hugging Face Hub repo name).
   * @param {string} [dtype='fp32'] - Weight precision passed to the pipeline.
   * @param {?Function} [progress_callback=null] - Receives load-progress and
   *   fallback-status events.
   * @returns {Promise<*>} The transformers.js pipeline instance.
   * @throws {Error} When both the WebGPU and WASM backends fail to load.
   */
  static async getInstance(model, dtype = 'fp32', progress_callback = null) {

    // BUGFIX: reuse the cached pipeline when the requested model/dtype match;
    // previously every call rebuilt the pipeline from scratch.
    if (this.instance && this.loadedModel === model && this.loadedDtype === dtype) {
      return this.instance
    }

    try {
      this.instance = await pipeline(this.task, model, {
        dtype,
        device: 'webgpu',
        progress_callback
      })
    } catch (webgpuError) {
      // WebGPU isn't available in every browser; report and retry on WASM.
      progress_callback?.({
        status: 'fallback',
        message: 'WebGPU failed, falling back to WASM'
      })

      try {
        this.instance = await pipeline(this.task, model, {
          dtype,
          device: 'wasm',
          progress_callback
        })
      } catch (wasmError) {
        throw new Error(
          `Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
        )
      }
    }

    // Record what is loaded only after a successful creation.
    this.loadedModel = model
    this.loadedDtype = dtype
    return this.instance
  }

  /**
   * Abort the in-flight generation, if any. Safe to call when idle.
   */
  static stopGeneration() {
    if (this.currentGeneration) {
      this.currentGeneration.abort()
      this.currentGeneration = null
    }
  }

}
|
|
|
|
|
// Worker message handler. Supported message types:
//   'stop'     - abort the in-flight generation and report ready
//   'load'     - preload the model, reply with { status: 'ready' }
//   'generate' - run generation from chat `messages` or a raw `prompt`
self.addEventListener('message', async (event) => {

  try {

    const { type, model, dtype, messages, prompt, hasChatTemplate, config } =
      event.data

    if (type === 'stop') {
      MyTextGenerationPipeline.stopGeneration()
      self.postMessage({ status: 'ready' })
      return
    }

    if (!model) {
      self.postMessage({
        status: 'error',
        output: 'No model provided'
      })
      return
    }

    // Load (or reuse) the pipeline, streaming download progress to the UI.
    const generator = await MyTextGenerationPipeline.getInstance(
      model,
      dtype,
      (x) => {
        self.postMessage({ status: 'loading', output: x })
      }
    )

    if (type === 'load') {
      self.postMessage({
        status: 'ready',
        output: `Model ${model}, dtype ${dtype} loaded`
      })
      return
    }

    if (type === 'generate') {

      let inputText = ''

      if (hasChatTemplate && messages && messages.length > 0) {
        // Chat models take the full message history; the pipeline applies
        // the chat template itself.
        inputText = messages
      } else if (!hasChatTemplate && prompt) {
        inputText = prompt
      } else {
        // Nothing to generate from; report idle and bail out.
        self.postMessage({ status: 'ready' })
        return
      }

      // `config` may be omitted by the caller entirely; guard with ?? so the
      // dereferences below can't throw. Use ?? (not ||) for defaults so an
      // explicit, valid `temperature: 0` or `max_new_tokens: 0` isn't
      // silently replaced.
      const cfg = config ?? {}

      const options = {
        max_new_tokens: cfg.max_new_tokens ?? 100,
        temperature: cfg.temperature ?? 0.7,
        do_sample: cfg.do_sample !== false,
        // BUGFIX: the original spread `...(config.top_p && { top_p })`
        // referenced bare identifiers `top_p`/`top_k` that are not in scope,
        // throwing a ReferenceError whenever either option was set. Forward
        // the config values instead; `!= null` lets 0 through.
        ...(cfg.top_p != null && { top_p: cfg.top_p }),
        ...(cfg.top_k != null && { top_k: cfg.top_k })
      }

      // Track the in-flight generation so a 'stop' message can abort it.
      const abortController = new AbortController()
      MyTextGenerationPipeline.currentGeneration = abortController

      try {
        const output = await generator(inputText, {
          ...options,
          signal: abortController.signal
        })

        if (hasChatTemplate) {
          // Chat output is the full message history; return only the newly
          // generated (last) message.
          self.postMessage({
            status: 'output',
            output: output[0].generated_text.slice(-1)[0]
          })
        } else {
          // Plain prompts return a string; wrap it as an assistant message
          // so the UI handles both paths uniformly.
          self.postMessage({
            status: 'output',
            output: {
              role: 'assistant',
              content: output[0].generated_text
            }
          })
        }

        self.postMessage({ status: 'ready' })
      } catch (error) {
        if (error.name === 'AbortError') {
          // User-initiated stop: not an error, just return to idle.
          self.postMessage({ status: 'ready' })
        } else {
          throw error
        }
      } finally {
        MyTextGenerationPipeline.currentGeneration = null
      }
    }

  } catch (error) {
    self.postMessage({
      status: 'error',
      output: error.message || 'An error occurred during text generation'
    })
  }

})
|
|