Vokturz commited on
Commit
b1c66bb
·
1 Parent(s): 2f35054

fix: enhance zero-shot classification worker and UI components

Browse files
public/workers/zero-shot-classification.js CHANGED
@@ -1,46 +1,73 @@
1
  /* eslint-disable no-restricted-globals */
2
- import { pipeline } from '@huggingface/transformers';
3
 
4
  class MyZeroShotClassificationPipeline {
5
- static task = 'zero-shot-classification';
6
- static instance = null;
7
 
8
- static async getInstance(model, progress_callback = null) {
9
- this.instance ??= pipeline(this.task, model, {
 
 
10
  progress_callback
11
- });
12
-
13
- return this.instance;
14
  }
15
  }
16
 
17
  // Listen for messages from the main thread
18
  self.addEventListener('message', async (event) => {
19
- const { text, labels, model } = event.data;
20
- if (!model) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  self.postMessage({
22
  status: 'error',
23
- output: 'No model provided'
24
- });
25
- return;
26
- }
27
-
28
- // Retrieve the pipeline. When called for the first time,
29
- // this will load the pipeline and save it for future use.
30
- const classifier = await MyZeroShotClassificationPipeline.getInstance(model, (x) => {
31
- // We also add a progress callback to the pipeline so that we can
32
- // track model loading.
33
- self.postMessage({ status: 'progress', output: x });
34
- });
35
- const split = text.split('\n');
36
- for (const line of split) {
37
- const output = await classifier(line, labels, {
38
- hypothesis_template: 'This text is about {}.',
39
- multi_label: true
40
- });
41
- // Send the output back to the main thread
42
- self.postMessage({ status: 'output', output });
43
  }
44
- // Send the output back to the main thread
45
- self.postMessage({ status: 'complete' });
46
- });
 
1
  /* eslint-disable no-restricted-globals */
2
+ import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3'
3
 
4
  class MyZeroShotClassificationPipeline {
5
+ static task = 'zero-shot-classification'
6
+ static instance = null
7
 
8
+ static async getInstance(model, dtype = 'fp32', progress_callback = null) {
9
+ this.instance = pipeline(this.task, model, {
10
+ dtype,
11
+ device: 'webgpu',
12
  progress_callback
13
+ })
14
+ return this.instance
 
15
  }
16
  }
17
 
18
  // Listen for messages from the main thread
19
  self.addEventListener('message', async (event) => {
20
+ try {
21
+ const { type, model, dtype, text, labels } = event.data
22
+
23
+ if (!model) {
24
+ self.postMessage({
25
+ status: 'error',
26
+ output: 'No model provided'
27
+ })
28
+ return
29
+ }
30
+
31
+ // Retrieve the pipeline. This will download the model if not already cached.
32
+ const classifier = await MyZeroShotClassificationPipeline.getInstance(
33
+ model,
34
+ dtype,
35
+ (x) => {
36
+ self.postMessage({ status: 'loading', output: x })
37
+ }
38
+ )
39
+
40
+ if (type === 'load') {
41
+ self.postMessage({
42
+ status: 'ready',
43
+ output: `Model ${model}, dtype ${dtype} loaded`
44
+ })
45
+ return
46
+ }
47
+
48
+ if (type === 'classify') {
49
+ if (!text || !labels) {
50
+ self.postMessage({ status: 'ready' }) // Nothing to process
51
+ return
52
+ }
53
+
54
+ const split = text.split('\n')
55
+ for (const line of split) {
56
+ if (line.trim()) {
57
+ const output = await classifier(line, labels, {
58
+ hypothesis_template: 'This text is about {}.',
59
+ multi_label: true
60
+ })
61
+
62
+ self.postMessage({ status: 'output', output })
63
+ }
64
+ }
65
+ self.postMessage({ status: 'ready' })
66
+ }
67
+ } catch (error) {
68
  self.postMessage({
69
  status: 'error',
70
+ output: error.message || 'An error occurred during processing'
71
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  }
73
+ })
 
 
src/components/ModelLoader.tsx CHANGED
@@ -64,6 +64,8 @@ const ModelLoader = () => {
64
 
65
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
66
  const { status, output } = e.data
 
 
67
  if (status === 'ready') {
68
  setStatus('ready')
69
  if (e.data.output) console.log(e.data.output)
@@ -81,6 +83,7 @@ const ModelLoader = () => {
81
  setStatus('output')
82
  const result = e.data.output!
83
  setResults((prev: any[]) => [...prev, result])
 
84
  // console.log(result)
85
  } else if (status === 'error') {
86
  setStatus('error')
@@ -141,7 +144,6 @@ const ModelLoader = () => {
141
  }
142
  className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
143
  >
144
- <option value="">Select quantization</option>
145
  {modelInfo.supportedQuantizations.map((quant) => (
146
  <option key={quant} value={quant}>
147
  {quant}
 
64
 
65
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
66
  const { status, output } = e.data
67
+ console.log('Received output from worker', e.data)
68
+
69
  if (status === 'ready') {
70
  setStatus('ready')
71
  if (e.data.output) console.log(e.data.output)
 
83
  setStatus('output')
84
  const result = e.data.output!
85
  setResults((prev: any[]) => [...prev, result])
86
+
87
  // console.log(result)
88
  } else if (status === 'error') {
89
  setStatus('error')
 
144
  }
145
  className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
146
  >
 
147
  {modelInfo.supportedQuantizations.map((quant) => (
148
  <option key={quant} value={quant}>
149
  {quant}
src/components/ZeroShotClassification.tsx CHANGED
@@ -1,4 +1,3 @@
1
- // src/App.tsx
2
  import { useState, useRef, useEffect, useCallback } from 'react'
3
  import {
4
  Section,
@@ -49,31 +48,40 @@ function ZeroShotClassification() {
49
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
50
  )
51
 
52
- const { status, setStatus, modelInfo } = useModel()
53
 
54
- // Create a reference to the worker object.
55
- const worker = useRef<Worker | null>(null)
56
-
57
- // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
58
- useEffect(() => {
59
- if (!worker.current) {
60
  return
61
- // Create the worker if it does not yet exist.
62
- // worker.current = new Worker(
63
- // new URL('../workers/zero-shot-classification.js', import.meta.url),
64
- // {
65
- // type: 'module'
66
- // }
67
- // )
68
  }
69
 
70
- // Create a callback function for messages from the worker thread.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
72
  const status = e.data.status
73
- if (status === 'ready') {
74
- setStatus('ready')
75
- } else if (status === 'output') {
76
- setStatus('output')
77
  const { sequence, labels, scores } = e.data.output!
78
 
79
  // Threshold for classification
@@ -89,33 +97,12 @@ function ZeroShotClassification() {
89
  }
90
  return newSections
91
  })
92
- } else if (status === 'error') {
93
- setStatus('error')
94
- console.error(e.data.output)
95
  }
96
  }
97
 
98
- // Attach the callback function as an event listener.
99
- worker.current.addEventListener('message', onMessageReceived)
100
-
101
- // Define a cleanup function for when the component is unmounted.
102
- return () =>
103
- worker.current?.removeEventListener('message', onMessageReceived)
104
- }, [sections])
105
-
106
- const classify = useCallback(() => {
107
- if (!modelInfo) return
108
-
109
- setStatus('loading')
110
- const message: ZeroShotWorkerInput = {
111
- text,
112
- labels: sections
113
- .slice(0, sections.length - 1)
114
- .map((section) => section.title),
115
- model: modelInfo.name
116
- }
117
- worker.current?.postMessage(message)
118
- }, [text, sections, modelInfo])
119
 
120
  const busy: boolean = status !== 'ready'
121
 
@@ -169,11 +156,10 @@ function ZeroShotClassification() {
169
  disabled={busy}
170
  onClick={classify}
171
  >
172
- {!busy
173
- ? 'Categorize'
174
- : status === 'loading'
175
- ? 'Model loading...'
176
- : 'Processing'}
177
  </button>
178
  <div className="flex gap-1">
179
  <button
 
 
1
  import { useState, useRef, useEffect, useCallback } from 'react'
2
  import {
3
  Section,
 
48
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
49
  )
50
 
51
+ const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
52
 
53
+ const classify = useCallback(() => {
54
+ if (!modelInfo || !activeWorker) {
55
+ console.error('Model info or worker is not available')
 
 
 
56
  return
 
 
 
 
 
 
 
57
  }
58
 
59
+ // Clear previous results
60
+ setSections((sections) =>
61
+ sections.map((section) => ({
62
+ ...section,
63
+ items: []
64
+ }))
65
+ )
66
+
67
+ const message: ZeroShotWorkerInput = {
68
+ type: 'classify',
69
+ text,
70
+ labels: sections
71
+ .slice(0, sections.length - 1)
72
+ .map((section) => section.title),
73
+ model: modelInfo.id
74
+ }
75
+ activeWorker.postMessage(message)
76
+ }, [text, sections, modelInfo, activeWorker])
77
+
78
+ // Handle worker messages
79
+ useEffect(() => {
80
+ if (!activeWorker) return
81
+
82
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
83
  const status = e.data.status
84
+ if (status === 'output') {
 
 
 
85
  const { sequence, labels, scores } = e.data.output!
86
 
87
  // Threshold for classification
 
97
  }
98
  return newSections
99
  })
 
 
 
100
  }
101
  }
102
 
103
+ activeWorker.addEventListener('message', onMessageReceived)
104
+ return () => activeWorker.removeEventListener('message', onMessageReceived)
105
+ }, [sections, activeWorker])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  const busy: boolean = status !== 'ready'
108
 
 
156
  disabled={busy}
157
  onClick={classify}
158
  >
159
+ {hasBeenLoaded ? !busy
160
+ ? 'Categorize'
161
+ : 'Processing...'
162
+ : 'Load model first'}
 
163
  </button>
164
  <div className="flex gap-1">
165
  <button
src/types.ts CHANGED
@@ -35,6 +35,7 @@ export interface WorkerMessage {
35
  }
36
 
37
  export interface ZeroShotWorkerInput {
 
38
  text: string
39
  labels: string[]
40
  model: string
 
35
  }
36
 
37
  export interface ZeroShotWorkerInput {
38
+ type: 'classify'
39
  text: string
40
  labels: string[]
41
  model: string