File size: 4,188 Bytes
22f8eb7 415aaef 22f8eb7 415aaef 22f8eb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
/* eslint-disable no-restricted-globals */
import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@latest'
/**
 * Lazily loaded, cached transformers.js feature-extraction pipeline.
 *
 * The pipeline is loaded once per (model, dtype) pair and reused by later
 * `getInstance` calls with the same pair. Concurrent callers share one
 * in-flight load instead of each starting their own.
 */
class MyFeatureExtractionPipeline {
  static task = 'feature-extraction'
  /** Most recently loaded pipeline, or null before the first successful load. */
  static instance = null
  /**
   * WebGPU is disabled: onnxruntime-web failed for feature-extraction with
   * transformers 3.7.1. Set to true to try WebGPU before falling back to WASM.
   */
  static webgpuEnabled = false
  // Key of the (model, dtype) pair whose load is held in #loading.
  static #cacheKey = null
  // In-flight or resolved load promise shared by all callers for #cacheKey.
  static #loading = null

  /**
   * Return the pipeline for `model`/`dtype`, loading it on first use.
   *
   * @param {string} model - Model id passed to transformers.js `pipeline`.
   * @param {string} [dtype='fp32'] - dtype option passed to `pipeline`.
   * @param {Function|null} [progress_callback=null] - Receives loading events;
   *   only invoked when this call actually triggers a new load.
   * @returns {Promise<Function>} The feature-extraction pipeline.
   * @throws {Error} When no backend could load the model; the message includes
   *   both the WebGPU and the WASM error.
   */
  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
    const key = JSON.stringify([model, dtype])
    if (this.#loading === null || this.#cacheKey !== key) {
      this.#cacheKey = key
      this.#loading = this.#load(model, dtype, progress_callback)
    }
    const loading = this.#loading
    try {
      const extractor = await loading
      // Don't publish a stale pipeline if a different pair was requested meanwhile.
      if (this.#loading === loading) {
        this.instance = extractor
      }
      return extractor
    } catch (error) {
      // Forget the failed load so the next call retries instead of replaying the rejection.
      if (this.#loading === loading) {
        this.#loading = null
        this.#cacheKey = null
      }
      throw error
    }
  }

  /**
   * Load the pipeline, trying WebGPU first (when enabled) and falling back to WASM.
   */
  static async #load(model, dtype, progress_callback) {
    let webgpuError
    try {
      if (!this.webgpuEnabled) {
        throw new Error('onnxruntime-web failed for feature-extraction with transformers 3.7.1')
      }
      return await pipeline(this.task, model, {
        dtype,
        device: 'webgpu',
        progress_callback
      })
    } catch (error) {
      webgpuError = error
    }
    if (progress_callback) {
      progress_callback({
        status: 'fallback',
        message: 'WebGPU failed, falling back to WASM'
      })
    }
    try {
      return await pipeline(this.task, model, {
        dtype,
        device: 'wasm',
        progress_callback
      })
    } catch (wasmError) {
      throw new Error(
        `Both WebGPU and WASM failed. WebGPU error: ${webgpuError?.message}. WASM error: ${wasmError?.message}`,
        { cause: wasmError }
      )
    }
  }
}
/**
 * Convert one feature-extraction result into a plain (possibly nested) array.
 *
 * Accepts anything exposing `tolist()` (e.g. a transformers.js Tensor), a plain
 * array, or an object with an array-like `data` field. When the first element is
 * itself an array, that first element is returned, which drops a leading batch
 * dimension for a single input.
 *
 * @param {*} output - Value returned by the extractor for a single text.
 * @returns {Array} The embedding.
 * @throws {Error} When `output` has none of the supported forms.
 */
function toEmbeddingArray(output) {
  let embedding
  if (typeof output?.tolist === 'function') {
    embedding = output.tolist()
  } else if (Array.isArray(output)) {
    embedding = output
  } else if (output?.data) {
    embedding = Array.from(output.data)
  } else {
    throw new Error('Unexpected output format from feature extraction')
  }
  return Array.isArray(embedding[0]) ? embedding[0] : embedding
}

/**
 * Readable message for any thrown value, including non-Error throws such as null.
 */
function errorMessage(error) {
  return typeof error?.message === 'string' ? error.message : String(error)
}

/**
 * Embed each text sequentially, posting a `progress` message after every text.
 *
 * A failure on one text is recorded on that entry (`embedding: null` plus
 * `error`) and does not stop the remaining texts.
 *
 * @param {Function} extractor - Pipeline returned by MyFeatureExtractionPipeline.getInstance.
 * @param {Array} texts - Inputs to embed (validated as a non-empty array by the caller).
 * @param {*} config - Options forwarded unchanged to the extractor.
 * @returns {Promise<Array<object>>} One `{ text, embedding, index }` entry per input;
 *   failed entries have `embedding: null` and an `error` message.
 */
async function extractEmbeddings(extractor, texts, config) {
  const embeddings = []
  for (const [index, text] of texts.entries()) {
    const progress = { completed: index + 1, total: texts.length, currentText: text }
    try {
      const embedding = toEmbeddingArray(await extractor(text, config))
      embeddings.push({ text, embedding, index })
      self.postMessage({ status: 'progress', output: { ...progress, embedding } })
    } catch (error) {
      const message = errorMessage(error)
      embeddings.push({ text, embedding: null, error: message, index })
      self.postMessage({ status: 'progress', output: { ...progress, error: message } })
    }
  }
  return embeddings
}

// Listen for messages from the main thread.
// Expected message: { type: 'load' | 'extract', model, dtype, texts, config }.
self.addEventListener('message', async (event) => {
  try {
    const { type, model, dtype, texts, config } = event.data
    if (!model) {
      self.postMessage({
        status: 'error',
        output: 'No model provided'
      })
      return
    }
    // Reply to unsupported types instead of loading a model and going silent.
    if (type !== 'load' && type !== 'extract') {
      self.postMessage({
        status: 'error',
        output: `Unknown message type: ${type}`
      })
      return
    }
    // Reject malformed extract requests before paying for a model load.
    if (type === 'extract' && (!Array.isArray(texts) || texts.length === 0)) {
      self.postMessage({
        status: 'error',
        output: 'No texts provided for feature extraction'
      })
      return
    }
    // Get the (cached) pipeline instance, forwarding loading progress to the main thread.
    const extractor = await MyFeatureExtractionPipeline.getInstance(
      model,
      dtype,
      (x) => {
        self.postMessage({ status: 'loading', output: x })
      }
    )
    if (type === 'load') {
      self.postMessage({
        status: 'ready',
        output: `Feature extraction model ${model}, dtype ${dtype} loaded`
      })
      return
    }
    const embeddings = await extractEmbeddings(extractor, texts, config)
    self.postMessage({
      status: 'output',
      output: {
        embeddings,
        completed: true
      }
    })
    self.postMessage({ status: 'ready' })
  } catch (error) {
    self.postMessage({
      status: 'error',
      output: errorMessage(error) || 'An error occurred during feature extraction'
    })
  }
})
|