Vokturz commited on
Commit
e7ba29d
·
1 Parent(s): 6ebf2fd

wip: public access to workers

Browse files
public/workers/text-classification.js ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* eslint-disable no-restricted-globals */
2
+ import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3';
3
+
4
+ class MyTextClassificationPipeline {
5
+ static task = 'text-classification'
6
+ static instance = null
7
+
8
+ static async getInstance(model, progress_callback = null) {
9
+ this.instance = pipeline(this.task, model, {
10
+ progress_callback
11
+ })
12
+ return this.instance
13
+ }
14
+ }
15
+
16
+ // Listen for messages from the main thread
17
+ self.addEventListener('message', async (event) => {
18
+ const { type, model, text } = event.data // Destructure 'type'
19
+
20
+ if (!model) {
21
+ self.postMessage({
22
+ status: 'error',
23
+ output: 'No model provided'
24
+ })
25
+ return
26
+ }
27
+
28
+ // Retrieve the pipeline. This will download the model if not already cached.
29
+ const classifier = await MyTextClassificationPipeline.getInstance(
30
+ model,
31
+ (x) => {
32
+ self.postMessage({ status: 'progress', output: x })
33
+ }
34
+ )
35
+
36
+ if (type === 'load') {
37
+ self.postMessage({ status: 'ready' })
38
+ return
39
+ }
40
+
41
+ if (type === 'classify') {
42
+ if (!text) {
43
+ self.postMessage({ status: 'complete' }) // Nothing to process
44
+ return
45
+ }
46
+ const split = text.split('\n')
47
+ for (const line of split) {
48
+ if (line.trim()) {
49
+ const output = await classifier(line)
50
+ self.postMessage({
51
+ status: 'output',
52
+ output: {
53
+ sequence: line,
54
+ labels: [output[0].label],
55
+ scores: [output[0].score]
56
+ }
57
+ })
58
+ }
59
+ }
60
+ self.postMessage({ status: 'complete' })
61
+ }
62
+ })
src/workers/zero-shot.js → public/workers/zero-shot-classification.js RENAMED
File without changes
src/components/ModelInfo.tsx CHANGED
@@ -1,8 +1,19 @@
1
- import { Bot, Heart, Download, Cpu, DatabaseIcon, CheckCircle, XCircle, ExternalLink, ChevronDown } from 'lucide-react'
 
 
 
 
 
 
 
 
 
 
2
  import { getModelSize } from '../lib/huggingface'
3
  import { useModel } from '../contexts/ModelContext'
4
- import { useEffect } from 'react'
5
- import { QuantizationType } from '../types'
 
6
 
7
  const ModelInfo = () => {
8
  const formatNumber = (num: number) => {
@@ -16,14 +27,25 @@ const ModelInfo = () => {
16
  return num.toString()
17
  }
18
 
19
- const { modelInfo, selectedQuantization, setSelectedQuantization } = useModel()
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- // Set default quantization when model changes
22
  useEffect(() => {
23
  if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
24
  const quantizations = modelInfo.supportedQuantizations
25
  let defaultQuant: QuantizationType = 'fp32'
26
-
27
  if (quantizations.includes('int8')) {
28
  defaultQuant = 'int8'
29
  } else if (quantizations.includes('q8')) {
@@ -31,17 +53,72 @@ const ModelInfo = () => {
31
  } else if (quantizations.includes('q4')) {
32
  defaultQuant = 'q4'
33
  }
34
-
35
  setSelectedQuantization(defaultQuant)
36
  }
37
- }, [modelInfo.supportedQuantizations, modelInfo.isCompatible, setSelectedQuantization])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  if (!modelInfo.name) {
40
  return null
41
  }
42
 
43
  return (
44
- <div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-3">
45
  {/* Model Name Row */}
46
  <div className="flex items-center space-x-2">
47
  <Bot className="w-4 h-4 text-blue-600" />
@@ -70,7 +147,7 @@ const ModelInfo = () => {
70
  </div>
71
  )}
72
  </div>
73
-
74
  {/* Base Model Link */}
75
  {modelInfo.baseId && (
76
  <div className="flex items-center space-x-2 ml-6">
@@ -87,7 +164,6 @@ const ModelInfo = () => {
87
  </div>
88
  )}
89
 
90
-
91
  {/* Stats Row */}
92
  <div className="flex items-center justify-self-end space-x-4 text-xs text-gray-600">
93
  {modelInfo.likes > 0 && (
@@ -115,36 +191,62 @@ const ModelInfo = () => {
115
  <div className="flex items-center space-x-1">
116
  <DatabaseIcon className="w-3 h-3 text-purple-500" />
117
  <span>
118
- {`~${getModelSize(modelInfo.parameters, selectedQuantization).toFixed(1)}MB`}
 
 
 
119
  </span>
120
  </div>
121
  )}
122
  </div>
123
 
124
  {/* Separator */}
125
- {modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0 && (
126
- <hr className="border-gray-200" />
127
- )}
128
-
 
129
  {/* Quantization Dropdown */}
130
- {modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0 && (
131
- <div className="flex items-center space-x-2">
132
- <span className="text-xs text-gray-600 font-medium">Quantization:</span>
133
- <div className="relative">
134
- <select
135
- value={selectedQuantization || ''}
136
- onChange={(e) => setSelectedQuantization(e.target.value as QuantizationType)}
137
- className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
138
- >
139
- <option value="">Select quantization</option>
140
- {modelInfo.supportedQuantizations.map((quant) => (
141
- <option key={quant} value={quant}>
142
- {quant}
143
- </option>
144
- ))}
145
- </select>
146
- <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
 
 
 
 
 
 
147
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  </div>
149
  )}
150
 
 
1
+ import {
2
+ Bot,
3
+ Heart,
4
+ Download,
5
+ Cpu,
6
+ DatabaseIcon,
7
+ CheckCircle,
8
+ XCircle,
9
+ ExternalLink,
10
+ ChevronDown
11
+ } from 'lucide-react'
12
  import { getModelSize } from '../lib/huggingface'
13
  import { useModel } from '../contexts/ModelContext'
14
+ import { useEffect, useCallback } from 'react'
15
+ import { QuantizationType, WorkerMessage } from '../types'
16
+ import { getWorker } from '../lib/workerManager'
17
 
18
  const ModelInfo = () => {
19
  const formatNumber = (num: number) => {
 
27
  return num.toString()
28
  }
29
 
30
+ const {
31
+ modelInfo,
32
+ selectedQuantization,
33
+ setSelectedQuantization,
34
+ status,
35
+ setStatus,
36
+ setProgress,
37
+ activeWorker,
38
+ setActiveWorker,
39
+ pipeline,
40
+ workerLoaded,
41
+ setWorkerLoaded
42
+ } = useModel()
43
 
 
44
  useEffect(() => {
45
  if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
46
  const quantizations = modelInfo.supportedQuantizations
47
  let defaultQuant: QuantizationType = 'fp32'
48
+
49
  if (quantizations.includes('int8')) {
50
  defaultQuant = 'int8'
51
  } else if (quantizations.includes('q8')) {
 
53
  } else if (quantizations.includes('q4')) {
54
  defaultQuant = 'q4'
55
  }
56
+
57
  setSelectedQuantization(defaultQuant)
58
  }
59
+ }, [
60
+ modelInfo.supportedQuantizations,
61
+ modelInfo.isCompatible,
62
+ setSelectedQuantization
63
+ ])
64
+
65
+ useEffect(() => {
66
+ const newWorker = getWorker(pipeline)
67
+ if (!newWorker) {
68
+ return
69
+ }
70
+
71
+ setStatus('idle')
72
+ setWorkerLoaded(false)
73
+ setActiveWorker(newWorker)
74
+
75
+ const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
76
+ const { status, output } = e.data
77
+ if (status === 'initiate') {
78
+ setStatus('loading')
79
+ } else if (status === 'ready') {
80
+ setStatus('ready')
81
+ setWorkerLoaded(true)
82
+ } else if (status === 'progress' && output) {
83
+ setStatus('progress')
84
+ if (
85
+ output.progress &&
86
+ typeof output.file === 'string' &&
87
+ output.file.startsWith('onnx')
88
+ ) {
89
+ setProgress(output.progress)
90
+ }
91
+ }
92
+ }
93
+
94
+ newWorker.addEventListener('message', onMessageReceived)
95
+
96
+ return () => {
97
+ newWorker.removeEventListener('message', onMessageReceived)
98
+ // terminateWorker(pipeline);
99
+ }
100
+ }, [pipeline, selectedQuantization, setActiveWorker, setStatus, setProgress, setWorkerLoaded])
101
+
102
+ const loadModel = useCallback(() => {
103
+ if (!modelInfo.name || !selectedQuantization) return
104
+
105
+ setStatus('loading')
106
+ const message = {
107
+ type: 'load',
108
+ model: modelInfo.name,
109
+ quantization: selectedQuantization
110
+ }
111
+ activeWorker?.postMessage(message)
112
+ }, [modelInfo.name, selectedQuantization, setStatus, activeWorker])
113
+
114
+ const busy: boolean = status !== 'idle'
115
 
116
  if (!modelInfo.name) {
117
  return null
118
  }
119
 
120
  return (
121
+ <div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-3">
122
  {/* Model Name Row */}
123
  <div className="flex items-center space-x-2">
124
  <Bot className="w-4 h-4 text-blue-600" />
 
147
  </div>
148
  )}
149
  </div>
150
+
151
  {/* Base Model Link */}
152
  {modelInfo.baseId && (
153
  <div className="flex items-center space-x-2 ml-6">
 
164
  </div>
165
  )}
166
 
 
167
  {/* Stats Row */}
168
  <div className="flex items-center justify-self-end space-x-4 text-xs text-gray-600">
169
  {modelInfo.likes > 0 && (
 
191
  <div className="flex items-center space-x-1">
192
  <DatabaseIcon className="w-3 h-3 text-purple-500" />
193
  <span>
194
+ {`~${getModelSize(
195
+ modelInfo.parameters,
196
+ selectedQuantization
197
+ ).toFixed(1)}MB`}
198
  </span>
199
  </div>
200
  )}
201
  </div>
202
 
203
  {/* Separator */}
204
+ {modelInfo.isCompatible &&
205
+ modelInfo.supportedQuantizations.length > 0 && (
206
+ <hr className="border-gray-200" />
207
+ )}
208
+
209
  {/* Quantization Dropdown */}
210
+ {modelInfo.isCompatible &&
211
+ modelInfo.supportedQuantizations.length > 0 && (
212
+ <div className="flex items-center space-x-2">
213
+ <span className="text-xs text-gray-600 font-medium">
214
+ Quantization:
215
+ </span>
216
+ <div className="relative">
217
+ <select
218
+ value={selectedQuantization || ''}
219
+ onChange={(e) =>
220
+ setSelectedQuantization(e.target.value as QuantizationType)
221
+ }
222
+ className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
223
+ >
224
+ <option value="">Select quantization</option>
225
+ {modelInfo.supportedQuantizations.map((quant) => (
226
+ <option key={quant} value={quant}>
227
+ {quant}
228
+ </option>
229
+ ))}
230
+ </select>
231
+ <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
232
+ </div>
233
  </div>
234
+ )}
235
+
236
+ {/* Load Model Button */}
237
+ {modelInfo.isCompatible && selectedQuantization && (
238
+ <div className="flex justify-center">
239
+ <button
240
+ className="py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm"
241
+ disabled={busy || !selectedQuantization || workerLoaded}
242
+ onClick={loadModel}
243
+ >
244
+ {status === 'loading'
245
+ ? 'Loading Model...'
246
+ : workerLoaded
247
+ ? 'Model Ready'
248
+ : 'Load Model'}
249
+ </button>
250
  </div>
251
  )}
252
 
src/components/ModelSelector.tsx CHANGED
@@ -8,7 +8,7 @@ type SortOption = 'likes' | 'downloads' | 'createdAt' | 'name'
8
 
9
  const ModelSelector: React.FC = () => {
10
  const { models, setModelInfo, modelInfo, pipeline } = useModel()
11
- const [sortBy, setSortBy] = useState<SortOption>('likes')
12
  const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc')
13
 
14
  const formatNumber = (num: number) => {
 
8
 
9
  const ModelSelector: React.FC = () => {
10
  const { models, setModelInfo, modelInfo, pipeline } = useModel()
11
+ const [sortBy, setSortBy] = useState<SortOption>('downloads')
12
  const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc')
13
 
14
  const formatNumber = (num: number) => {
src/components/PipelineSelector.tsx CHANGED
@@ -9,8 +9,8 @@ import {
9
  import { ChevronDown, Check } from 'lucide-react';
10
 
11
  const pipelines = [
12
- 'zero-shot-classification',
13
  'text-classification',
 
14
  'text-generation',
15
  'summarization',
16
  'feature-extraction',
 
9
  import { ChevronDown, Check } from 'lucide-react';
10
 
11
  const pipelines = [
 
12
  'text-classification',
13
+ 'zero-shot-classification',
14
  'text-generation',
15
  'summarization',
16
  'feature-extraction',
src/components/TextClassification.tsx CHANGED
@@ -6,6 +6,7 @@ import {
6
  } from '../types';
7
  import { useModel } from '../contexts/ModelContext';
8
  import { getModelInfo } from '../lib/huggingface';
 
9
 
10
 
11
  const PLACEHOLDER_TEXTS: string[] = [
@@ -24,7 +25,8 @@ const PLACEHOLDER_TEXTS: string[] = [
24
  function TextClassification() {
25
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
26
  const [results, setResults] = useState<ClassificationOutput[]>([])
27
- const { setProgress, status, setStatus, modelInfo, setModelInfo, models, setModels} = useModel()
 
28
 
29
 
30
  useEffect(() => {
@@ -56,43 +58,23 @@ function TextClassification() {
56
  fetchModelInfo()
57
  }, [modelInfo.id, setModelInfo])
58
 
59
- // Create a reference to the worker object.
60
- const worker = useRef<Worker | null>(null)
61
-
62
  // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
63
  useEffect(() => {
64
- if (!worker.current) {
65
- // Create the worker if it does not yet exist.
66
- worker.current = new Worker(
67
- new URL('../workers/text-classification.js', import.meta.url),
68
- {
69
- type: 'module'
70
- }
71
- )
72
  }
73
 
 
74
  // Create a callback function for messages from the worker thread.
75
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
76
  const status = e.data.status
77
- if (status === 'initiate') {
78
- setStatus('loading')
79
- } else if (status === 'ready') {
80
- setStatus('ready')
81
- } else if (status === 'progress') {
82
- setStatus('progress')
83
- if (
84
- e.data.output.progress &&
85
- (e.data.output.file as string).startsWith('onnx')
86
- )
87
- setProgress(e.data.output.progress)
88
- } else if (status === 'output') {
89
  setStatus('output')
90
  const result = e.data.output!
91
  setResults((prevResults) => [...prevResults, result])
92
  console.log(result)
93
  } else if (status === 'complete') {
94
  setStatus('idle')
95
- setProgress(100)
96
  } else if (status === 'error') {
97
  setStatus('error')
98
  console.error(e.data.output)
@@ -100,21 +82,26 @@ function TextClassification() {
100
  }
101
 
102
  // Attach the callback function as an event listener.
103
- worker.current.addEventListener('message', onMessageReceived)
104
 
105
  // Define a cleanup function for when the component is unmounted.
106
  return () =>
107
- worker.current?.removeEventListener('message', onMessageReceived)
108
  }, [])
109
 
110
  const classify = useCallback(() => {
111
  setStatus('processing')
112
  setResults([]) // Clear previous results
113
- const message: TextClassificationWorkerInput = { text, model: modelInfo.id }
114
- worker.current?.postMessage(message)
 
 
 
 
115
  }, [text, modelInfo.id])
116
 
117
- const busy: boolean = status !== 'idle'
 
118
 
119
  const handleClear = (): void => {
120
  setResults([])
@@ -138,14 +125,14 @@ function TextClassification() {
138
  <div className="flex gap-2 mt-4">
139
  <button
140
  className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
141
- disabled={busy}
142
  onClick={classify}
143
  >
144
- {!busy
145
  ? 'Classify Text'
146
  : status === 'loading'
147
  ? 'Model loading...'
148
- : 'Processing...'}
149
  </button>
150
  <button
151
  className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
 
6
  } from '../types';
7
  import { useModel } from '../contexts/ModelContext';
8
  import { getModelInfo } from '../lib/huggingface';
9
+ import { getWorker } from '../lib/workerManager';
10
 
11
 
12
  const PLACEHOLDER_TEXTS: string[] = [
 
25
  function TextClassification() {
26
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
27
  const [results, setResults] = useState<ClassificationOutput[]>([])
28
+ const { setProgress, status, setStatus, modelInfo, setModelInfo, workerLoaded} = useModel()
29
+ const workerRef = useRef<Worker | null>(null)
30
 
31
 
32
  useEffect(() => {
 
58
  fetchModelInfo()
59
  }, [modelInfo.id, setModelInfo])
60
 
 
 
 
61
  // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
62
  useEffect(() => {
63
+ if(!workerRef.current) {
64
+ workerRef.current = getWorker('text-classification')
 
 
 
 
 
 
65
  }
66
 
67
+
68
  // Create a callback function for messages from the worker thread.
69
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
70
  const status = e.data.status
71
+ if (status === 'output') {
 
 
 
 
 
 
 
 
 
 
 
72
  setStatus('output')
73
  const result = e.data.output!
74
  setResults((prevResults) => [...prevResults, result])
75
  console.log(result)
76
  } else if (status === 'complete') {
77
  setStatus('idle')
 
78
  } else if (status === 'error') {
79
  setStatus('error')
80
  console.error(e.data.output)
 
82
  }
83
 
84
  // Attach the callback function as an event listener.
85
+ workerRef.current?.addEventListener('message', onMessageReceived)
86
 
87
  // Define a cleanup function for when the component is unmounted.
88
  return () =>
89
+ workerRef.current?.removeEventListener('message', onMessageReceived)
90
  }, [])
91
 
92
  const classify = useCallback(() => {
93
  setStatus('processing')
94
  setResults([]) // Clear previous results
95
+ const message: TextClassificationWorkerInput = {
96
+ type: 'classify',
97
+ text,
98
+ model: modelInfo.id
99
+ }
100
+ workerRef.current?.postMessage(message)
101
  }, [text, modelInfo.id])
102
 
103
+ const busy: boolean = status !== 'ready'
104
+
105
 
106
  const handleClear = (): void => {
107
  setResults([])
 
125
  <div className="flex gap-2 mt-4">
126
  <button
127
  className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
128
+ disabled={busy || !workerLoaded}
129
  onClick={classify}
130
  >
131
+ {workerLoaded ? (!busy
132
  ? 'Classify Text'
133
  : status === 'loading'
134
  ? 'Model loading...'
135
+ : 'Processing...') : 'Load model first'}
136
  </button>
137
  <button
138
  className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
src/components/ZeroShotClassification.tsx CHANGED
@@ -59,13 +59,14 @@ function ZeroShotClassification() {
59
  // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
60
  useEffect(() => {
61
  if (!worker.current) {
 
62
  // Create the worker if it does not yet exist.
63
- worker.current = new Worker(
64
- new URL('../workers/zero-shot.js', import.meta.url),
65
- {
66
- type: 'module'
67
- }
68
- )
69
  }
70
 
71
  // Create a callback function for messages from the worker thread.
 
59
  // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
60
  useEffect(() => {
61
  if (!worker.current) {
62
+ return
63
  // Create the worker if it does not yet exist.
64
+ // worker.current = new Worker(
65
+ // new URL('../workers/zero-shot-classification.js', import.meta.url),
66
+ // {
67
+ // type: 'module'
68
+ // }
69
+ // )
70
  }
71
 
72
  // Create a callback function for messages from the worker thread.
src/contexts/ModelContext.tsx CHANGED
@@ -1,4 +1,4 @@
1
- import React, { createContext, useContext, useEffect, useState } from 'react'
2
  import { ModelInfo, ModelInfoResponse, QuantizationType } from '../types'
3
 
4
  interface ModelContextType {
@@ -14,6 +14,10 @@ interface ModelContextType {
14
  setModels: (models: ModelInfoResponse[]) => void
15
  selectedQuantization: QuantizationType
16
  setSelectedQuantization: (quantization: QuantizationType) => void
 
 
 
 
17
  }
18
 
19
  const ModelContext = createContext<ModelContextType | undefined>(undefined)
@@ -23,8 +27,11 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
23
  const [status, setStatus] = useState<string>('idle')
24
  const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
25
  const [models, setModels] = useState<ModelInfoResponse[]>([] as ModelInfoResponse[])
26
- const [pipeline, setPipeline] = useState<string>('zero-shot-classification')
27
  const [selectedQuantization, setSelectedQuantization] = useState<QuantizationType>('int8')
 
 
 
28
 
29
  // set progress to 0 when model is changed
30
  useEffect(() => {
@@ -46,6 +53,10 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
46
  setPipeline,
47
  selectedQuantization,
48
  setSelectedQuantization,
 
 
 
 
49
  }}
50
  >
51
  {children}
 
1
+ import React, { createContext, RefObject, useContext, useEffect, useRef, useState } from 'react'
2
  import { ModelInfo, ModelInfoResponse, QuantizationType } from '../types'
3
 
4
  interface ModelContextType {
 
14
  setModels: (models: ModelInfoResponse[]) => void
15
  selectedQuantization: QuantizationType
16
  setSelectedQuantization: (quantization: QuantizationType) => void
17
+ activeWorker: Worker | null
18
+ setActiveWorker: (worker: Worker | null) => void
19
+ workerLoaded: boolean
20
+ setWorkerLoaded: (workerLoaded: boolean) => void
21
  }
22
 
23
  const ModelContext = createContext<ModelContextType | undefined>(undefined)
 
27
  const [status, setStatus] = useState<string>('idle')
28
  const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
29
  const [models, setModels] = useState<ModelInfoResponse[]>([] as ModelInfoResponse[])
30
+ const [pipeline, setPipeline] = useState<string>('text-classification')
31
  const [selectedQuantization, setSelectedQuantization] = useState<QuantizationType>('int8')
32
+ const [activeWorker, setActiveWorker] = useState<Worker | null>(null)
33
+ const [workerLoaded, setWorkerLoaded] = useState<boolean>(false)
34
+
35
 
36
  // set progress to 0 when model is changed
37
  useEffect(() => {
 
53
  setPipeline,
54
  selectedQuantization,
55
  setSelectedQuantization,
56
+ activeWorker,
57
+ setActiveWorker,
58
+ workerLoaded,
59
+ setWorkerLoaded
60
  }}
61
  >
62
  {children}
src/lib/workerManager.ts ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const workers: Record<string, Worker | null> = {};
2
+
3
+ export const getWorker = (pipeline: string) => {
4
+ if (!workers[pipeline]) {
5
+ let workerUrl: string;
6
+
7
+ // Construct the public URL for the worker script.
8
+ // process.env.PUBLIC_URL ensures this works correctly even if the
9
+ // app is hosted in a sub-directory.
10
+ switch (pipeline) {
11
+ case 'text-classification':
12
+ workerUrl = `${process.env.PUBLIC_URL}/workers/text-classification.js`;
13
+ break;
14
+ case 'zero-shot-classification':
15
+ workerUrl = `${process.env.PUBLIC_URL}/workers/zero-shot-classification.js`;
16
+ break;
17
+ // Add other pipeline types here
18
+ default:
19
+ // Return null or throw an error if the pipeline is unknown
20
+ return null;
21
+ }
22
+ workers[pipeline] = new Worker(workerUrl, { type: 'module' });
23
+ }
24
+ return workers[pipeline];
25
+ };
26
+
27
+ export const terminateWorker = (pipeline: string) => {
28
+ const worker = workers[pipeline];
29
+ if (worker) {
30
+ worker.terminate();
31
+ delete workers[pipeline];
32
+ }
33
+ };
src/types.ts CHANGED
@@ -10,7 +10,9 @@ export interface ClassificationOutput {
10
  }
11
 
12
  export interface WorkerMessage {
13
- status: 'initiate' | 'ready' | 'output' | 'complete' | 'progress'
 
 
14
  output?: any
15
  }
16
 
@@ -21,6 +23,7 @@ export interface ZeroShotWorkerInput {
21
  }
22
 
23
  export interface TextClassificationWorkerInput {
 
24
  text: string
25
  model: string
26
  }
 
10
  }
11
 
12
  export interface WorkerMessage {
13
+ status: 'initiate' | 'ready' | 'output' | 'complete' | 'progress' | 'error'
14
+ progress?: number
15
+ error?: string
16
  output?: any
17
  }
18
 
 
23
  }
24
 
25
  export interface TextClassificationWorkerInput {
26
+ type: 'classify'
27
  text: string
28
  model: string
29
  }
src/workers/text-classification.js DELETED
@@ -1,55 +0,0 @@
1
- /* eslint-disable no-restricted-globals */
2
- import { pipeline } from '@huggingface/transformers';
3
-
4
- class MyTextClassificationPipeline {
5
- static task = 'text-classification';
6
- static instance = null;
7
-
8
- static async getInstance(model, progress_callback = null) {
9
- this.instance ??= pipeline(this.task, model, {
10
- progress_callback
11
- });
12
-
13
- return this.instance;
14
- }
15
- }
16
-
17
- // Listen for messages from the main thread
18
- self.addEventListener('message', async (event) => {
19
- const { text, model } = event.data;
20
- if (!model) {
21
- self.postMessage({
22
- status: 'error',
23
- output: 'No model provided'
24
- });
25
- return;
26
- }
27
-
28
- // Retrieve the pipeline. When called for the first time,
29
- // this will load the pipeline and save it for future use.
30
- const classifier = await MyTextClassificationPipeline.getInstance(model, (x) => {
31
- // We also add a progress callback to the pipeline so that we can
32
- // track model loading.
33
- self.postMessage({ status: 'progress', output: x });
34
- });
35
-
36
-
37
-
38
- const split = text.split('\n');
39
- for (const line of split) {
40
- if (line.trim()) {
41
- const output = await classifier(line);
42
- // Send the output back to the main thread
43
- self.postMessage({
44
- status: 'output',
45
- output: {
46
- sequence: line,
47
- labels: [output[0].label],
48
- scores: [output[0].score]
49
- }
50
- });
51
- }
52
- }
53
- // Send the output back to the main thread
54
- self.postMessage({ status: 'complete' });
55
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tsconfig.json CHANGED
@@ -1,7 +1,12 @@
1
  {
2
  "compilerOptions": {
3
- "target": "es5",
4
- "lib": ["dom", "dom.iterable", "esnext"],
 
 
 
 
 
5
  "allowJs": true,
6
  "skipLibCheck": true,
7
  "esModuleInterop": true,
@@ -16,5 +21,5 @@
16
  "noEmit": true,
17
  "jsx": "react-jsx"
18
  },
19
- "include": ["src"]
20
- }
 
1
  {
2
  "compilerOptions": {
3
+ "target": "es2020",
4
+ "lib": [
5
+ "dom",
6
+ "dom.iterable",
7
+ "esnext",
8
+ "WebWorker"
9
+ ],
10
  "allowJs": true,
11
  "skipLibCheck": true,
12
  "esModuleInterop": true,
 
21
  "noEmit": true,
22
  "jsx": "react-jsx"
23
  },
24
+ "include": ["src", "public/workers"]
25
+ }