Vokturz commited on
Commit
ac2af95
·
1 Parent(s): 2656c1e

Add WASM fallback when WebGPU initialization fails

Browse files
public/workers/text-classification.js CHANGED
@@ -6,12 +6,35 @@ class MyTextClassificationPipeline {
6
  static instance = null
7
 
8
  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
9
- this.instance = pipeline(this.task, model, {
10
- dtype,
11
- device: 'webgpu',
12
- progress_callback
13
- })
14
- return this.instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
16
  }
17
 
 
6
  static instance = null
7
 
8
  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
9
+ try {
10
+ // Try WebGPU first
11
+ this.instance = await pipeline(this.task, model, {
12
+ dtype,
13
+ device: 'webgpu',
14
+ progress_callback
15
+ })
16
+ return this.instance
17
+ } catch (webgpuError) {
18
+ // Fallback to WASM if WebGPU fails
19
+ if (progress_callback) {
20
+ progress_callback({
21
+ status: 'fallback',
22
+ message: 'WebGPU failed, falling back to WASM'
23
+ })
24
+ }
25
+ try {
26
+ this.instance = await pipeline(this.task, model, {
27
+ dtype,
28
+ device: 'wasm',
29
+ progress_callback
30
+ })
31
+ return this.instance
32
+ } catch (wasmError) {
33
+ throw new Error(
34
+ `Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
35
+ )
36
+ }
37
+ }
38
  }
39
  }
40
 
public/workers/text-generation.js CHANGED
@@ -7,12 +7,35 @@ class MyTextGenerationPipeline {
7
  static currentGeneration = null
8
 
9
  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
10
- this.instance = pipeline(this.task, model, {
11
- dtype,
12
- device: 'webgpu',
13
- progress_callback
14
- })
15
- return this.instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  }
17
 
18
  static stopGeneration() {
 
7
  static currentGeneration = null
8
 
9
  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
10
+ try {
11
+ // Try WebGPU first
12
+ this.instance = await pipeline(this.task, model, {
13
+ dtype,
14
+ device: 'webgpu',
15
+ progress_callback
16
+ })
17
+ return this.instance
18
+ } catch (webgpuError) {
19
+ // Fallback to WASM if WebGPU fails
20
+ if (progress_callback) {
21
+ progress_callback({
22
+ status: 'fallback',
23
+ message: 'WebGPU failed, falling back to WASM'
24
+ })
25
+ }
26
+ try {
27
+ this.instance = await pipeline(this.task, model, {
28
+ dtype,
29
+ device: 'wasm',
30
+ progress_callback
31
+ })
32
+ return this.instance
33
+ } catch (wasmError) {
34
+ throw new Error(
35
+ `Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
36
+ )
37
+ }
38
+ }
39
  }
40
 
41
  static stopGeneration() {
public/workers/zero-shot-classification.js CHANGED
@@ -6,12 +6,35 @@ class MyZeroShotClassificationPipeline {
6
  static instance = null
7
 
8
  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
9
- this.instance = pipeline(this.task, model, {
10
- dtype,
11
- device: 'webgpu',
12
- progress_callback
13
- })
14
- return this.instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
16
  }
17
 
@@ -50,7 +73,7 @@ self.addEventListener('message', async (event) => {
50
  self.postMessage({ status: 'ready' }) // Nothing to process
51
  return
52
  }
53
-
54
  const split = text.split('\n')
55
  for (const line of split) {
56
  if (line.trim()) {
@@ -58,7 +81,7 @@ self.addEventListener('message', async (event) => {
58
  hypothesis_template: 'This text is about {}.',
59
  multi_label: true
60
  })
61
-
62
  self.postMessage({ status: 'output', output })
63
  }
64
  }
 
6
  static instance = null
7
 
8
  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
9
+ try {
10
+ // Try WebGPU first
11
+ this.instance = await pipeline(this.task, model, {
12
+ dtype,
13
+ device: 'webgpu',
14
+ progress_callback
15
+ })
16
+ return this.instance
17
+ } catch (webgpuError) {
18
+ // Fallback to WASM if WebGPU fails
19
+ if (progress_callback) {
20
+ progress_callback({
21
+ status: 'fallback',
22
+ message: 'WebGPU failed, falling back to WASM'
23
+ })
24
+ }
25
+ try {
26
+ this.instance = await pipeline(this.task, model, {
27
+ dtype,
28
+ device: 'wasm',
29
+ progress_callback
30
+ })
31
+ return this.instance
32
+ } catch (wasmError) {
33
+ throw new Error(
34
+ `Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
35
+ )
36
+ }
37
+ }
38
  }
39
  }
40
 
 
73
  self.postMessage({ status: 'ready' }) // Nothing to process
74
  return
75
  }
76
+
77
  const split = text.split('\n')
78
  for (const line of split) {
79
  if (line.trim()) {
 
81
  hypothesis_template: 'This text is about {}.',
82
  multi_label: true
83
  })
84
+
85
  self.postMessage({ status: 'output', output })
86
  }
87
  }