Add WASM fallback when WebGPU initialization fails
Browse files
public/workers/text-classification.js
CHANGED
@@ -6,12 +6,35 @@ class MyTextClassificationPipeline {
|
|
6 |
static instance = null
|
7 |
|
8 |
static async getInstance(model, dtype = 'fp32', progress_callback = null) {
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
}
|
16 |
}
|
17 |
|
|
|
6 |
static instance = null
|
7 |
|
8 |
static async getInstance(model, dtype = 'fp32', progress_callback = null) {
|
9 |
+
try {
|
10 |
+
// Try WebGPU first
|
11 |
+
this.instance = await pipeline(this.task, model, {
|
12 |
+
dtype,
|
13 |
+
device: 'webgpu',
|
14 |
+
progress_callback
|
15 |
+
})
|
16 |
+
return this.instance
|
17 |
+
} catch (webgpuError) {
|
18 |
+
// Fallback to WASM if WebGPU fails
|
19 |
+
if (progress_callback) {
|
20 |
+
progress_callback({
|
21 |
+
status: 'fallback',
|
22 |
+
message: 'WebGPU failed, falling back to WASM'
|
23 |
+
})
|
24 |
+
}
|
25 |
+
try {
|
26 |
+
this.instance = await pipeline(this.task, model, {
|
27 |
+
dtype,
|
28 |
+
device: 'wasm',
|
29 |
+
progress_callback
|
30 |
+
})
|
31 |
+
return this.instance
|
32 |
+
} catch (wasmError) {
|
33 |
+
throw new Error(
|
34 |
+
`Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
|
35 |
+
)
|
36 |
+
}
|
37 |
+
}
|
38 |
}
|
39 |
}
|
40 |
|
public/workers/text-generation.js
CHANGED
@@ -7,12 +7,35 @@ class MyTextGenerationPipeline {
|
|
7 |
static currentGeneration = null
|
8 |
|
9 |
static async getInstance(model, dtype = 'fp32', progress_callback = null) {
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
}
|
17 |
|
18 |
static stopGeneration() {
|
|
|
7 |
static currentGeneration = null
|
8 |
|
9 |
static async getInstance(model, dtype = 'fp32', progress_callback = null) {
|
10 |
+
try {
|
11 |
+
// Try WebGPU first
|
12 |
+
this.instance = await pipeline(this.task, model, {
|
13 |
+
dtype,
|
14 |
+
device: 'webgpu',
|
15 |
+
progress_callback
|
16 |
+
})
|
17 |
+
return this.instance
|
18 |
+
} catch (webgpuError) {
|
19 |
+
// Fallback to WASM if WebGPU fails
|
20 |
+
if (progress_callback) {
|
21 |
+
progress_callback({
|
22 |
+
status: 'fallback',
|
23 |
+
message: 'WebGPU failed, falling back to WASM'
|
24 |
+
})
|
25 |
+
}
|
26 |
+
try {
|
27 |
+
this.instance = await pipeline(this.task, model, {
|
28 |
+
dtype,
|
29 |
+
device: 'wasm',
|
30 |
+
progress_callback
|
31 |
+
})
|
32 |
+
return this.instance
|
33 |
+
} catch (wasmError) {
|
34 |
+
throw new Error(
|
35 |
+
`Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
|
36 |
+
)
|
37 |
+
}
|
38 |
+
}
|
39 |
}
|
40 |
|
41 |
static stopGeneration() {
|
public/workers/zero-shot-classification.js
CHANGED
@@ -6,12 +6,35 @@ class MyZeroShotClassificationPipeline {
|
|
6 |
static instance = null
|
7 |
|
8 |
static async getInstance(model, dtype = 'fp32', progress_callback = null) {
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
}
|
16 |
}
|
17 |
|
@@ -50,7 +73,7 @@ self.addEventListener('message', async (event) => {
|
|
50 |
self.postMessage({ status: 'ready' }) // Nothing to process
|
51 |
return
|
52 |
}
|
53 |
-
|
54 |
const split = text.split('\n')
|
55 |
for (const line of split) {
|
56 |
if (line.trim()) {
|
@@ -58,7 +81,7 @@ self.addEventListener('message', async (event) => {
|
|
58 |
hypothesis_template: 'This text is about {}.',
|
59 |
multi_label: true
|
60 |
})
|
61 |
-
|
62 |
self.postMessage({ status: 'output', output })
|
63 |
}
|
64 |
}
|
|
|
6 |
static instance = null
|
7 |
|
8 |
static async getInstance(model, dtype = 'fp32', progress_callback = null) {
|
9 |
+
try {
|
10 |
+
// Try WebGPU first
|
11 |
+
this.instance = await pipeline(this.task, model, {
|
12 |
+
dtype,
|
13 |
+
device: 'webgpu',
|
14 |
+
progress_callback
|
15 |
+
})
|
16 |
+
return this.instance
|
17 |
+
} catch (webgpuError) {
|
18 |
+
// Fallback to WASM if WebGPU fails
|
19 |
+
if (progress_callback) {
|
20 |
+
progress_callback({
|
21 |
+
status: 'fallback',
|
22 |
+
message: 'WebGPU failed, falling back to WASM'
|
23 |
+
})
|
24 |
+
}
|
25 |
+
try {
|
26 |
+
this.instance = await pipeline(this.task, model, {
|
27 |
+
dtype,
|
28 |
+
device: 'wasm',
|
29 |
+
progress_callback
|
30 |
+
})
|
31 |
+
return this.instance
|
32 |
+
} catch (wasmError) {
|
33 |
+
throw new Error(
|
34 |
+
`Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
|
35 |
+
)
|
36 |
+
}
|
37 |
+
}
|
38 |
}
|
39 |
}
|
40 |
|
|
|
73 |
self.postMessage({ status: 'ready' }) // Nothing to process
|
74 |
return
|
75 |
}
|
76 |
+
|
77 |
const split = text.split('\n')
|
78 |
for (const line of split) {
|
79 |
if (line.trim()) {
|
|
|
81 |
hypothesis_template: 'This text is about {}.',
|
82 |
multi_label: true
|
83 |
})
|
84 |
+
|
85 |
self.postMessage({ status: 'output', output })
|
86 |
}
|
87 |
}
|