Spaces:
Runtime error
Runtime error
import { pipeline, env } from "@huggingface/transformers"; | |
// Resolve every model from the Hugging Face Hub instead of first probing for
// locally hosted model files.
env.allowLocalModels = false;
/**
 * Detect whether WebGPU can actually be used in this context.
 *
 * `navigator.gpu.requestAdapter()` resolves to `null` (rather than throwing)
 * when no suitable GPU adapter is available, so the returned adapter must be
 * checked explicitly — merely awaiting the call is not enough.
 *
 * @returns {Promise<boolean>} true only when a GPU adapter was obtained.
 */
async function supportsWebGPU() {
  try {
    if (!navigator.gpu) return false;
    const adapter = await navigator.gpu.requestAdapter();
    return adapter != null;
  } catch {
    // Any failure while probing (API missing, navigator undefined, ...)
    // means the caller should fall back to another backend.
    return false;
  }
}
// Backend handed to every pipeline: the GPU path when usable, else WASM.
const hasWebGPU = await supportsWebGPU();
const device = hasWebGPU ? "webgpu" : "wasm";
/**
 * Lazily creates, caches and runs transformers.js pipelines inside this worker.
 *
 * Pipelines are cached per `${task}:${modelName}` so each model is loaded at
 * most once, even when several requests for it arrive while it is still
 * loading. Inference requests go through a "latest wins" queue: while a
 * request is running, newer requests replace any that are still waiting, so
 * only the most recent input is processed next.
 */
class PipelineManager {
  /** Default model for each task, used when a request does not name one. */
  static defaultConfigs = {
    "text-classification": {
      model: "onnx-community/rubert-tiny-sentiment-balanced-ONNX",
    },
    "image-classification": {
      model: "onnx-community/mobilenet_v2_1.0_224",
    },
  };

  /** Fully loaded pipelines, keyed by `${task}:${modelName}`. */
  static instances = {};

  /** In-flight loads keyed like `instances`, so concurrent callers share one load. */
  static #pending = new Map();

  /** Task/model used for requests that omit them. */
  static currentTask = "text-classification";
  static currentModel = PipelineManager.defaultConfigs["text-classification"].model;

  /** Requests waiting to run; only the most recent one is ever processed. */
  static queue = [];
  static isProcessing = false;

  /**
   * Return the pipeline for `task`/`modelName`, loading and caching it on first use.
   *
   * Posts `{ status: "initiate" }` before and `{ status: "ready" }` after an
   * actual load. Concurrent calls for the same key share a single load (only
   * the first caller's `progress_callback` receives progress). A failed load
   * is not cached, so a later call retries it.
   *
   * @param {string} task - transformers.js task name, e.g. "text-classification".
   * @param {string} modelName - model id passed to `pipeline()`.
   * @param {?Function} [progress_callback=null] - forwarded to `pipeline()`.
   * @returns {Promise<Function>} the pipeline instance.
   * @throws whatever `pipeline()` rejects with.
   */
  static async getInstance(task, modelName, progress_callback = null) {
    const key = `${task}:${modelName}`;
    if (this.instances[key]) return this.instances[key];

    let loading = this.#pending.get(key);
    if (!loading) {
      loading = this.#load(task, modelName, progress_callback);
      this.#pending.set(key, loading);
    }
    try {
      const instance = await loading;
      this.instances[key] = instance;
      return instance;
    } finally {
      // Forget the settled load: on success it now lives in `instances`; on
      // failure it must not stay cached so the next request can retry. The
      // identity check avoids dropping the entry of a newer retry.
      if (this.#pending.get(key) === loading) this.#pending.delete(key);
    }
  }

  /** Load one pipeline, announcing its start and completion to the main thread. */
  static async #load(task, modelName, progress_callback) {
    self.postMessage({ status: "initiate", file: modelName, task });
    const instance = await pipeline(task, modelName, { progress_callback, device });
    self.postMessage({ status: "ready", file: modelName, task });
    return instance;
  }

  /**
   * Run queued inference requests until the queue is empty.
   *
   * Each pass takes only the newest queued request and discards older ones.
   * Results are posted as `{ status: "complete", output, file, task }`,
   * failures as `{ status: "error", error, file, task }`. Calls made while a
   * run is already active return immediately; the active run picks up
   * whatever they queued.
   */
  static async processQueue() {
    if (this.isProcessing || this.queue.length === 0) return;
    this.isProcessing = true;
    try {
      // Loop (rather than recurse) so requests queued while busy are handled
      // without leaving a floating promise behind.
      while (this.queue.length > 0) {
        const { input, task, modelName } = this.queue[this.queue.length - 1];
        this.queue = [];
        await this.#run(input, task, modelName);
      }
    } finally {
      // Always release the flag so an unexpected throw cannot wedge the queue.
      this.isProcessing = false;
    }
  }

  /** Run a single inference request and post its result or error. */
  static async #run(input, task, modelName) {
    try {
      const classifier = await this.getInstance(task, modelName, (x) => {
        self.postMessage({
          ...x,
          status: x.status || "progress",
          file: x.file || modelName,
          name: modelName,
          task,
          loaded: x.loaded,
          total: x.total,
          progress: x.loaded && x.total ? (x.loaded / x.total) * 100 : 0,
        });
      });
      // Speech recognition is called without options; every other task asks
      // for the top 5 labels. (Per the original note, image input is a data
      // URL or Blob.)
      const output =
        task === "automatic-speech-recognition"
          ? await classifier(input)
          : await classifier(input, { top_k: 5 });
      self.postMessage({ status: "complete", output, file: modelName, task });
    } catch (error) {
      self.postMessage({
        status: "error",
        // Non-Error throwables have no `.message`; stringify them instead.
        error: error?.message ?? String(error),
        file: modelName,
        task,
      });
    }
  }
}
/**
 * Handle requests from the main thread.
 *
 * `event.data` is read as `{ input, modelName, task, action }`:
 *  - `action === "load-model"`: make `task`/`modelName` the defaults for later
 *    requests and preload that pipeline (falling back to "text-classification"
 *    and that task's default model when omitted);
 *  - anything else: queue `input` for inference with the given task/model, or
 *    the current defaults when omitted.
 * Load failures are reported as `{ status: "error", error, file, task }`.
 */
self.addEventListener("message", async (event) => {
  const { input, modelName, task, action } = event.data;

  if (action === "load-model") {
    const nextTask = task || "text-classification";
    const nextModel = modelName || PipelineManager.defaultConfigs[nextTask]?.model;
    if (!nextModel) {
      // Unknown task and no explicit model: report it instead of throwing a TypeError.
      self.postMessage({
        status: "error",
        error: `No default model configured for task "${nextTask}"`,
        file: modelName,
        task: nextTask,
      });
      return;
    }
    PipelineManager.currentTask = nextTask;
    PipelineManager.currentModel = nextModel;
    try {
      await PipelineManager.getInstance(nextTask, nextModel, (x) => {
        // Label progress with this request's own model/task (captured above),
        // not the mutable `current*` fields, which a later "load-model"
        // message may change while this download is still running.
        self.postMessage({
          ...x,
          file: nextModel,
          status: x.status || "progress",
          loaded: x.loaded,
          total: x.total,
          task: nextTask,
        });
      });
    } catch (error) {
      // Without this, the listener's rejection goes unhandled and the main
      // thread never learns that the load failed.
      self.postMessage({
        status: "error",
        error: error?.message ?? String(error),
        file: nextModel,
        task: nextTask,
      });
    }
    return;
  }

  PipelineManager.queue.push({
    input,
    task: task || PipelineManager.currentTask,
    modelName: modelName || PipelineManager.currentModel,
  });
  // Not awaited: processQueue catches inference errors and reports them via postMessage.
  void PipelineManager.processQueue();
});