wip: public access to workers
Browse files- public/workers/text-classification.js +62 -0
- src/workers/zero-shot.js → public/workers/zero-shot-classification.js +0 -0
- src/components/ModelInfo.tsx +135 -33
- src/components/ModelSelector.tsx +1 -1
- src/components/PipelineSelector.tsx +1 -1
- src/components/TextClassification.tsx +20 -33
- src/components/ZeroShotClassification.tsx +7 -6
- src/contexts/ModelContext.tsx +13 -2
- src/lib/workerManager.ts +33 -0
- src/types.ts +4 -1
- src/workers/text-classification.js +0 -55
- tsconfig.json +9 -4
public/workers/text-classification.js
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* eslint-disable no-restricted-globals */
|
2 |
+
import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3';
|
3 |
+
|
4 |
+
class MyTextClassificationPipeline {
|
5 |
+
static task = 'text-classification'
|
6 |
+
static instance = null
|
7 |
+
|
8 |
+
static async getInstance(model, progress_callback = null) {
|
9 |
+
this.instance = pipeline(this.task, model, {
|
10 |
+
progress_callback
|
11 |
+
})
|
12 |
+
return this.instance
|
13 |
+
}
|
14 |
+
}
|
15 |
+
|
16 |
+
// Listen for messages from the main thread
|
17 |
+
self.addEventListener('message', async (event) => {
|
18 |
+
const { type, model, text } = event.data // Destructure 'type'
|
19 |
+
|
20 |
+
if (!model) {
|
21 |
+
self.postMessage({
|
22 |
+
status: 'error',
|
23 |
+
output: 'No model provided'
|
24 |
+
})
|
25 |
+
return
|
26 |
+
}
|
27 |
+
|
28 |
+
// Retrieve the pipeline. This will download the model if not already cached.
|
29 |
+
const classifier = await MyTextClassificationPipeline.getInstance(
|
30 |
+
model,
|
31 |
+
(x) => {
|
32 |
+
self.postMessage({ status: 'progress', output: x })
|
33 |
+
}
|
34 |
+
)
|
35 |
+
|
36 |
+
if (type === 'load') {
|
37 |
+
self.postMessage({ status: 'ready' })
|
38 |
+
return
|
39 |
+
}
|
40 |
+
|
41 |
+
if (type === 'classify') {
|
42 |
+
if (!text) {
|
43 |
+
self.postMessage({ status: 'complete' }) // Nothing to process
|
44 |
+
return
|
45 |
+
}
|
46 |
+
const split = text.split('\n')
|
47 |
+
for (const line of split) {
|
48 |
+
if (line.trim()) {
|
49 |
+
const output = await classifier(line)
|
50 |
+
self.postMessage({
|
51 |
+
status: 'output',
|
52 |
+
output: {
|
53 |
+
sequence: line,
|
54 |
+
labels: [output[0].label],
|
55 |
+
scores: [output[0].score]
|
56 |
+
}
|
57 |
+
})
|
58 |
+
}
|
59 |
+
}
|
60 |
+
self.postMessage({ status: 'complete' })
|
61 |
+
}
|
62 |
+
})
|
src/workers/zero-shot.js → public/workers/zero-shot-classification.js
RENAMED
File without changes
|
src/components/ModelInfo.tsx
CHANGED
@@ -1,8 +1,19 @@
|
|
1 |
-
import {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import { getModelSize } from '../lib/huggingface'
|
3 |
import { useModel } from '../contexts/ModelContext'
|
4 |
-
import { useEffect } from 'react'
|
5 |
-
import { QuantizationType } from '../types'
|
|
|
6 |
|
7 |
const ModelInfo = () => {
|
8 |
const formatNumber = (num: number) => {
|
@@ -16,14 +27,25 @@ const ModelInfo = () => {
|
|
16 |
return num.toString()
|
17 |
}
|
18 |
|
19 |
-
const {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
// Set default quantization when model changes
|
22 |
useEffect(() => {
|
23 |
if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
|
24 |
const quantizations = modelInfo.supportedQuantizations
|
25 |
let defaultQuant: QuantizationType = 'fp32'
|
26 |
-
|
27 |
if (quantizations.includes('int8')) {
|
28 |
defaultQuant = 'int8'
|
29 |
} else if (quantizations.includes('q8')) {
|
@@ -31,17 +53,72 @@ const ModelInfo = () => {
|
|
31 |
} else if (quantizations.includes('q4')) {
|
32 |
defaultQuant = 'q4'
|
33 |
}
|
34 |
-
|
35 |
setSelectedQuantization(defaultQuant)
|
36 |
}
|
37 |
-
}, [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
if (!modelInfo.name) {
|
40 |
return null
|
41 |
}
|
42 |
|
43 |
return (
|
44 |
-
<div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-3">
|
45 |
{/* Model Name Row */}
|
46 |
<div className="flex items-center space-x-2">
|
47 |
<Bot className="w-4 h-4 text-blue-600" />
|
@@ -70,7 +147,7 @@ const ModelInfo = () => {
|
|
70 |
</div>
|
71 |
)}
|
72 |
</div>
|
73 |
-
|
74 |
{/* Base Model Link */}
|
75 |
{modelInfo.baseId && (
|
76 |
<div className="flex items-center space-x-2 ml-6">
|
@@ -87,7 +164,6 @@ const ModelInfo = () => {
|
|
87 |
</div>
|
88 |
)}
|
89 |
|
90 |
-
|
91 |
{/* Stats Row */}
|
92 |
<div className="flex items-center justify-self-end space-x-4 text-xs text-gray-600">
|
93 |
{modelInfo.likes > 0 && (
|
@@ -115,36 +191,62 @@ const ModelInfo = () => {
|
|
115 |
<div className="flex items-center space-x-1">
|
116 |
<DatabaseIcon className="w-3 h-3 text-purple-500" />
|
117 |
<span>
|
118 |
-
{`~${getModelSize(
|
|
|
|
|
|
|
119 |
</span>
|
120 |
</div>
|
121 |
)}
|
122 |
</div>
|
123 |
|
124 |
{/* Separator */}
|
125 |
-
{modelInfo.isCompatible &&
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
129 |
{/* Quantization Dropdown */}
|
130 |
-
{modelInfo.isCompatible &&
|
131 |
-
|
132 |
-
<
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
</div>
|
149 |
)}
|
150 |
|
|
|
1 |
+
import {
|
2 |
+
Bot,
|
3 |
+
Heart,
|
4 |
+
Download,
|
5 |
+
Cpu,
|
6 |
+
DatabaseIcon,
|
7 |
+
CheckCircle,
|
8 |
+
XCircle,
|
9 |
+
ExternalLink,
|
10 |
+
ChevronDown
|
11 |
+
} from 'lucide-react'
|
12 |
import { getModelSize } from '../lib/huggingface'
|
13 |
import { useModel } from '../contexts/ModelContext'
|
14 |
+
import { useEffect, useCallback } from 'react'
|
15 |
+
import { QuantizationType, WorkerMessage } from '../types'
|
16 |
+
import { getWorker } from '../lib/workerManager'
|
17 |
|
18 |
const ModelInfo = () => {
|
19 |
const formatNumber = (num: number) => {
|
|
|
27 |
return num.toString()
|
28 |
}
|
29 |
|
30 |
+
const {
|
31 |
+
modelInfo,
|
32 |
+
selectedQuantization,
|
33 |
+
setSelectedQuantization,
|
34 |
+
status,
|
35 |
+
setStatus,
|
36 |
+
setProgress,
|
37 |
+
activeWorker,
|
38 |
+
setActiveWorker,
|
39 |
+
pipeline,
|
40 |
+
workerLoaded,
|
41 |
+
setWorkerLoaded
|
42 |
+
} = useModel()
|
43 |
|
|
|
44 |
useEffect(() => {
|
45 |
if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
|
46 |
const quantizations = modelInfo.supportedQuantizations
|
47 |
let defaultQuant: QuantizationType = 'fp32'
|
48 |
+
|
49 |
if (quantizations.includes('int8')) {
|
50 |
defaultQuant = 'int8'
|
51 |
} else if (quantizations.includes('q8')) {
|
|
|
53 |
} else if (quantizations.includes('q4')) {
|
54 |
defaultQuant = 'q4'
|
55 |
}
|
56 |
+
|
57 |
setSelectedQuantization(defaultQuant)
|
58 |
}
|
59 |
+
}, [
|
60 |
+
modelInfo.supportedQuantizations,
|
61 |
+
modelInfo.isCompatible,
|
62 |
+
setSelectedQuantization
|
63 |
+
])
|
64 |
+
|
65 |
+
useEffect(() => {
|
66 |
+
const newWorker = getWorker(pipeline)
|
67 |
+
if (!newWorker) {
|
68 |
+
return
|
69 |
+
}
|
70 |
+
|
71 |
+
setStatus('idle')
|
72 |
+
setWorkerLoaded(false)
|
73 |
+
setActiveWorker(newWorker)
|
74 |
+
|
75 |
+
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
76 |
+
const { status, output } = e.data
|
77 |
+
if (status === 'initiate') {
|
78 |
+
setStatus('loading')
|
79 |
+
} else if (status === 'ready') {
|
80 |
+
setStatus('ready')
|
81 |
+
setWorkerLoaded(true)
|
82 |
+
} else if (status === 'progress' && output) {
|
83 |
+
setStatus('progress')
|
84 |
+
if (
|
85 |
+
output.progress &&
|
86 |
+
typeof output.file === 'string' &&
|
87 |
+
output.file.startsWith('onnx')
|
88 |
+
) {
|
89 |
+
setProgress(output.progress)
|
90 |
+
}
|
91 |
+
}
|
92 |
+
}
|
93 |
+
|
94 |
+
newWorker.addEventListener('message', onMessageReceived)
|
95 |
+
|
96 |
+
return () => {
|
97 |
+
newWorker.removeEventListener('message', onMessageReceived)
|
98 |
+
// terminateWorker(pipeline);
|
99 |
+
}
|
100 |
+
}, [pipeline, selectedQuantization, setActiveWorker, setStatus, setProgress, setWorkerLoaded])
|
101 |
+
|
102 |
+
const loadModel = useCallback(() => {
|
103 |
+
if (!modelInfo.name || !selectedQuantization) return
|
104 |
+
|
105 |
+
setStatus('loading')
|
106 |
+
const message = {
|
107 |
+
type: 'load',
|
108 |
+
model: modelInfo.name,
|
109 |
+
quantization: selectedQuantization
|
110 |
+
}
|
111 |
+
activeWorker?.postMessage(message)
|
112 |
+
}, [modelInfo.name, selectedQuantization, setStatus, activeWorker])
|
113 |
+
|
114 |
+
const busy: boolean = status !== 'idle'
|
115 |
|
116 |
if (!modelInfo.name) {
|
117 |
return null
|
118 |
}
|
119 |
|
120 |
return (
|
121 |
+
<div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-3">
|
122 |
{/* Model Name Row */}
|
123 |
<div className="flex items-center space-x-2">
|
124 |
<Bot className="w-4 h-4 text-blue-600" />
|
|
|
147 |
</div>
|
148 |
)}
|
149 |
</div>
|
150 |
+
|
151 |
{/* Base Model Link */}
|
152 |
{modelInfo.baseId && (
|
153 |
<div className="flex items-center space-x-2 ml-6">
|
|
|
164 |
</div>
|
165 |
)}
|
166 |
|
|
|
167 |
{/* Stats Row */}
|
168 |
<div className="flex items-center justify-self-end space-x-4 text-xs text-gray-600">
|
169 |
{modelInfo.likes > 0 && (
|
|
|
191 |
<div className="flex items-center space-x-1">
|
192 |
<DatabaseIcon className="w-3 h-3 text-purple-500" />
|
193 |
<span>
|
194 |
+
{`~${getModelSize(
|
195 |
+
modelInfo.parameters,
|
196 |
+
selectedQuantization
|
197 |
+
).toFixed(1)}MB`}
|
198 |
</span>
|
199 |
</div>
|
200 |
)}
|
201 |
</div>
|
202 |
|
203 |
{/* Separator */}
|
204 |
+
{modelInfo.isCompatible &&
|
205 |
+
modelInfo.supportedQuantizations.length > 0 && (
|
206 |
+
<hr className="border-gray-200" />
|
207 |
+
)}
|
208 |
+
|
209 |
{/* Quantization Dropdown */}
|
210 |
+
{modelInfo.isCompatible &&
|
211 |
+
modelInfo.supportedQuantizations.length > 0 && (
|
212 |
+
<div className="flex items-center space-x-2">
|
213 |
+
<span className="text-xs text-gray-600 font-medium">
|
214 |
+
Quantization:
|
215 |
+
</span>
|
216 |
+
<div className="relative">
|
217 |
+
<select
|
218 |
+
value={selectedQuantization || ''}
|
219 |
+
onChange={(e) =>
|
220 |
+
setSelectedQuantization(e.target.value as QuantizationType)
|
221 |
+
}
|
222 |
+
className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
|
223 |
+
>
|
224 |
+
<option value="">Select quantization</option>
|
225 |
+
{modelInfo.supportedQuantizations.map((quant) => (
|
226 |
+
<option key={quant} value={quant}>
|
227 |
+
{quant}
|
228 |
+
</option>
|
229 |
+
))}
|
230 |
+
</select>
|
231 |
+
<ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
|
232 |
+
</div>
|
233 |
</div>
|
234 |
+
)}
|
235 |
+
|
236 |
+
{/* Load Model Button */}
|
237 |
+
{modelInfo.isCompatible && selectedQuantization && (
|
238 |
+
<div className="flex justify-center">
|
239 |
+
<button
|
240 |
+
className="py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm"
|
241 |
+
disabled={busy || !selectedQuantization || workerLoaded}
|
242 |
+
onClick={loadModel}
|
243 |
+
>
|
244 |
+
{status === 'loading'
|
245 |
+
? 'Loading Model...'
|
246 |
+
: workerLoaded
|
247 |
+
? 'Model Ready'
|
248 |
+
: 'Load Model'}
|
249 |
+
</button>
|
250 |
</div>
|
251 |
)}
|
252 |
|
src/components/ModelSelector.tsx
CHANGED
@@ -8,7 +8,7 @@ type SortOption = 'likes' | 'downloads' | 'createdAt' | 'name'
|
|
8 |
|
9 |
const ModelSelector: React.FC = () => {
|
10 |
const { models, setModelInfo, modelInfo, pipeline } = useModel()
|
11 |
-
const [sortBy, setSortBy] = useState<SortOption>('
|
12 |
const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc')
|
13 |
|
14 |
const formatNumber = (num: number) => {
|
|
|
8 |
|
9 |
const ModelSelector: React.FC = () => {
|
10 |
const { models, setModelInfo, modelInfo, pipeline } = useModel()
|
11 |
+
const [sortBy, setSortBy] = useState<SortOption>('downloads')
|
12 |
const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc')
|
13 |
|
14 |
const formatNumber = (num: number) => {
|
src/components/PipelineSelector.tsx
CHANGED
@@ -9,8 +9,8 @@ import {
|
|
9 |
import { ChevronDown, Check } from 'lucide-react';
|
10 |
|
11 |
const pipelines = [
|
12 |
-
'zero-shot-classification',
|
13 |
'text-classification',
|
|
|
14 |
'text-generation',
|
15 |
'summarization',
|
16 |
'feature-extraction',
|
|
|
9 |
import { ChevronDown, Check } from 'lucide-react';
|
10 |
|
11 |
const pipelines = [
|
|
|
12 |
'text-classification',
|
13 |
+
'zero-shot-classification',
|
14 |
'text-generation',
|
15 |
'summarization',
|
16 |
'feature-extraction',
|
src/components/TextClassification.tsx
CHANGED
@@ -6,6 +6,7 @@ import {
|
|
6 |
} from '../types';
|
7 |
import { useModel } from '../contexts/ModelContext';
|
8 |
import { getModelInfo } from '../lib/huggingface';
|
|
|
9 |
|
10 |
|
11 |
const PLACEHOLDER_TEXTS: string[] = [
|
@@ -24,7 +25,8 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
24 |
function TextClassification() {
|
25 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
26 |
const [results, setResults] = useState<ClassificationOutput[]>([])
|
27 |
-
const { setProgress, status, setStatus, modelInfo, setModelInfo,
|
|
|
28 |
|
29 |
|
30 |
useEffect(() => {
|
@@ -56,43 +58,23 @@ function TextClassification() {
|
|
56 |
fetchModelInfo()
|
57 |
}, [modelInfo.id, setModelInfo])
|
58 |
|
59 |
-
// Create a reference to the worker object.
|
60 |
-
const worker = useRef<Worker | null>(null)
|
61 |
-
|
62 |
// We use the `useEffect` hook to setup the worker as soon as the component is mounted.
|
63 |
useEffect(() => {
|
64 |
-
if
|
65 |
-
|
66 |
-
worker.current = new Worker(
|
67 |
-
new URL('../workers/text-classification.js', import.meta.url),
|
68 |
-
{
|
69 |
-
type: 'module'
|
70 |
-
}
|
71 |
-
)
|
72 |
}
|
73 |
|
|
|
74 |
// Create a callback function for messages from the worker thread.
|
75 |
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
76 |
const status = e.data.status
|
77 |
-
if (status === '
|
78 |
-
setStatus('loading')
|
79 |
-
} else if (status === 'ready') {
|
80 |
-
setStatus('ready')
|
81 |
-
} else if (status === 'progress') {
|
82 |
-
setStatus('progress')
|
83 |
-
if (
|
84 |
-
e.data.output.progress &&
|
85 |
-
(e.data.output.file as string).startsWith('onnx')
|
86 |
-
)
|
87 |
-
setProgress(e.data.output.progress)
|
88 |
-
} else if (status === 'output') {
|
89 |
setStatus('output')
|
90 |
const result = e.data.output!
|
91 |
setResults((prevResults) => [...prevResults, result])
|
92 |
console.log(result)
|
93 |
} else if (status === 'complete') {
|
94 |
setStatus('idle')
|
95 |
-
setProgress(100)
|
96 |
} else if (status === 'error') {
|
97 |
setStatus('error')
|
98 |
console.error(e.data.output)
|
@@ -100,21 +82,26 @@ function TextClassification() {
|
|
100 |
}
|
101 |
|
102 |
// Attach the callback function as an event listener.
|
103 |
-
|
104 |
|
105 |
// Define a cleanup function for when the component is unmounted.
|
106 |
return () =>
|
107 |
-
|
108 |
}, [])
|
109 |
|
110 |
const classify = useCallback(() => {
|
111 |
setStatus('processing')
|
112 |
setResults([]) // Clear previous results
|
113 |
-
const message: TextClassificationWorkerInput = {
|
114 |
-
|
|
|
|
|
|
|
|
|
115 |
}, [text, modelInfo.id])
|
116 |
|
117 |
-
const busy: boolean = status !== '
|
|
|
118 |
|
119 |
const handleClear = (): void => {
|
120 |
setResults([])
|
@@ -138,14 +125,14 @@ function TextClassification() {
|
|
138 |
<div className="flex gap-2 mt-4">
|
139 |
<button
|
140 |
className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
141 |
-
disabled={busy}
|
142 |
onClick={classify}
|
143 |
>
|
144 |
-
{!busy
|
145 |
? 'Classify Text'
|
146 |
: status === 'loading'
|
147 |
? 'Model loading...'
|
148 |
-
: 'Processing...'}
|
149 |
</button>
|
150 |
<button
|
151 |
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
|
|
6 |
} from '../types';
|
7 |
import { useModel } from '../contexts/ModelContext';
|
8 |
import { getModelInfo } from '../lib/huggingface';
|
9 |
+
import { getWorker } from '../lib/workerManager';
|
10 |
|
11 |
|
12 |
const PLACEHOLDER_TEXTS: string[] = [
|
|
|
25 |
function TextClassification() {
|
26 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
27 |
const [results, setResults] = useState<ClassificationOutput[]>([])
|
28 |
+
const { setProgress, status, setStatus, modelInfo, setModelInfo, workerLoaded} = useModel()
|
29 |
+
const workerRef = useRef<Worker | null>(null)
|
30 |
|
31 |
|
32 |
useEffect(() => {
|
|
|
58 |
fetchModelInfo()
|
59 |
}, [modelInfo.id, setModelInfo])
|
60 |
|
|
|
|
|
|
|
61 |
// We use the `useEffect` hook to setup the worker as soon as the component is mounted.
|
62 |
useEffect(() => {
|
63 |
+
if(!workerRef.current) {
|
64 |
+
workerRef.current = getWorker('text-classification')
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
}
|
66 |
|
67 |
+
|
68 |
// Create a callback function for messages from the worker thread.
|
69 |
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
70 |
const status = e.data.status
|
71 |
+
if (status === 'output') {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
setStatus('output')
|
73 |
const result = e.data.output!
|
74 |
setResults((prevResults) => [...prevResults, result])
|
75 |
console.log(result)
|
76 |
} else if (status === 'complete') {
|
77 |
setStatus('idle')
|
|
|
78 |
} else if (status === 'error') {
|
79 |
setStatus('error')
|
80 |
console.error(e.data.output)
|
|
|
82 |
}
|
83 |
|
84 |
// Attach the callback function as an event listener.
|
85 |
+
workerRef.current?.addEventListener('message', onMessageReceived)
|
86 |
|
87 |
// Define a cleanup function for when the component is unmounted.
|
88 |
return () =>
|
89 |
+
workerRef.current?.removeEventListener('message', onMessageReceived)
|
90 |
}, [])
|
91 |
|
92 |
const classify = useCallback(() => {
|
93 |
setStatus('processing')
|
94 |
setResults([]) // Clear previous results
|
95 |
+
const message: TextClassificationWorkerInput = {
|
96 |
+
type: 'classify',
|
97 |
+
text,
|
98 |
+
model: modelInfo.id
|
99 |
+
}
|
100 |
+
workerRef.current?.postMessage(message)
|
101 |
}, [text, modelInfo.id])
|
102 |
|
103 |
+
const busy: boolean = status !== 'ready'
|
104 |
+
|
105 |
|
106 |
const handleClear = (): void => {
|
107 |
setResults([])
|
|
|
125 |
<div className="flex gap-2 mt-4">
|
126 |
<button
|
127 |
className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
128 |
+
disabled={busy || !workerLoaded}
|
129 |
onClick={classify}
|
130 |
>
|
131 |
+
{workerLoaded ? (!busy
|
132 |
? 'Classify Text'
|
133 |
: status === 'loading'
|
134 |
? 'Model loading...'
|
135 |
+
: 'Processing...') : 'Load model first'}
|
136 |
</button>
|
137 |
<button
|
138 |
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
src/components/ZeroShotClassification.tsx
CHANGED
@@ -59,13 +59,14 @@ function ZeroShotClassification() {
|
|
59 |
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
|
60 |
useEffect(() => {
|
61 |
if (!worker.current) {
|
|
|
62 |
// Create the worker if it does not yet exist.
|
63 |
-
worker.current = new Worker(
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
)
|
69 |
}
|
70 |
|
71 |
// Create a callback function for messages from the worker thread.
|
|
|
59 |
// We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
|
60 |
useEffect(() => {
|
61 |
if (!worker.current) {
|
62 |
+
return
|
63 |
// Create the worker if it does not yet exist.
|
64 |
+
// worker.current = new Worker(
|
65 |
+
// new URL('../workers/zero-shot-classification.js', import.meta.url),
|
66 |
+
// {
|
67 |
+
// type: 'module'
|
68 |
+
// }
|
69 |
+
// )
|
70 |
}
|
71 |
|
72 |
// Create a callback function for messages from the worker thread.
|
src/contexts/ModelContext.tsx
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import React, { createContext, useContext, useEffect, useState } from 'react'
|
2 |
import { ModelInfo, ModelInfoResponse, QuantizationType } from '../types'
|
3 |
|
4 |
interface ModelContextType {
|
@@ -14,6 +14,10 @@ interface ModelContextType {
|
|
14 |
setModels: (models: ModelInfoResponse[]) => void
|
15 |
selectedQuantization: QuantizationType
|
16 |
setSelectedQuantization: (quantization: QuantizationType) => void
|
|
|
|
|
|
|
|
|
17 |
}
|
18 |
|
19 |
const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
@@ -23,8 +27,11 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
23 |
const [status, setStatus] = useState<string>('idle')
|
24 |
const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
|
25 |
const [models, setModels] = useState<ModelInfoResponse[]>([] as ModelInfoResponse[])
|
26 |
-
const [pipeline, setPipeline] = useState<string>('
|
27 |
const [selectedQuantization, setSelectedQuantization] = useState<QuantizationType>('int8')
|
|
|
|
|
|
|
28 |
|
29 |
// set progress to 0 when model is changed
|
30 |
useEffect(() => {
|
@@ -46,6 +53,10 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
46 |
setPipeline,
|
47 |
selectedQuantization,
|
48 |
setSelectedQuantization,
|
|
|
|
|
|
|
|
|
49 |
}}
|
50 |
>
|
51 |
{children}
|
|
|
1 |
+
import React, { createContext, RefObject, useContext, useEffect, useRef, useState } from 'react'
|
2 |
import { ModelInfo, ModelInfoResponse, QuantizationType } from '../types'
|
3 |
|
4 |
interface ModelContextType {
|
|
|
14 |
setModels: (models: ModelInfoResponse[]) => void
|
15 |
selectedQuantization: QuantizationType
|
16 |
setSelectedQuantization: (quantization: QuantizationType) => void
|
17 |
+
activeWorker: Worker | null
|
18 |
+
setActiveWorker: (worker: Worker | null) => void
|
19 |
+
workerLoaded: boolean
|
20 |
+
setWorkerLoaded: (workerLoaded: boolean) => void
|
21 |
}
|
22 |
|
23 |
const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
|
|
27 |
const [status, setStatus] = useState<string>('idle')
|
28 |
const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
|
29 |
const [models, setModels] = useState<ModelInfoResponse[]>([] as ModelInfoResponse[])
|
30 |
+
const [pipeline, setPipeline] = useState<string>('text-classification')
|
31 |
const [selectedQuantization, setSelectedQuantization] = useState<QuantizationType>('int8')
|
32 |
+
const [activeWorker, setActiveWorker] = useState<Worker | null>(null)
|
33 |
+
const [workerLoaded, setWorkerLoaded] = useState<boolean>(false)
|
34 |
+
|
35 |
|
36 |
// set progress to 0 when model is changed
|
37 |
useEffect(() => {
|
|
|
53 |
setPipeline,
|
54 |
selectedQuantization,
|
55 |
setSelectedQuantization,
|
56 |
+
activeWorker,
|
57 |
+
setActiveWorker,
|
58 |
+
workerLoaded,
|
59 |
+
setWorkerLoaded
|
60 |
}}
|
61 |
>
|
62 |
{children}
|
src/lib/workerManager.ts
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
const workers: Record<string, Worker | null> = {};
|
2 |
+
|
3 |
+
export const getWorker = (pipeline: string) => {
|
4 |
+
if (!workers[pipeline]) {
|
5 |
+
let workerUrl: string;
|
6 |
+
|
7 |
+
// Construct the public URL for the worker script.
|
8 |
+
// process.env.PUBLIC_URL ensures this works correctly even if the
|
9 |
+
// app is hosted in a sub-directory.
|
10 |
+
switch (pipeline) {
|
11 |
+
case 'text-classification':
|
12 |
+
workerUrl = `${process.env.PUBLIC_URL}/workers/text-classification.js`;
|
13 |
+
break;
|
14 |
+
case 'zero-shot-classification':
|
15 |
+
workerUrl = `${process.env.PUBLIC_URL}/workers/zero-shot-classification.js`;
|
16 |
+
break;
|
17 |
+
// Add other pipeline types here
|
18 |
+
default:
|
19 |
+
// Return null or throw an error if the pipeline is unknown
|
20 |
+
return null;
|
21 |
+
}
|
22 |
+
workers[pipeline] = new Worker(workerUrl, { type: 'module' });
|
23 |
+
}
|
24 |
+
return workers[pipeline];
|
25 |
+
};
|
26 |
+
|
27 |
+
export const terminateWorker = (pipeline: string) => {
|
28 |
+
const worker = workers[pipeline];
|
29 |
+
if (worker) {
|
30 |
+
worker.terminate();
|
31 |
+
delete workers[pipeline];
|
32 |
+
}
|
33 |
+
};
|
src/types.ts
CHANGED
@@ -10,7 +10,9 @@ export interface ClassificationOutput {
|
|
10 |
}
|
11 |
|
12 |
export interface WorkerMessage {
|
13 |
-
status: 'initiate' | 'ready' | 'output' | 'complete' | 'progress'
|
|
|
|
|
14 |
output?: any
|
15 |
}
|
16 |
|
@@ -21,6 +23,7 @@ export interface ZeroShotWorkerInput {
|
|
21 |
}
|
22 |
|
23 |
export interface TextClassificationWorkerInput {
|
|
|
24 |
text: string
|
25 |
model: string
|
26 |
}
|
|
|
10 |
}
|
11 |
|
12 |
export interface WorkerMessage {
|
13 |
+
status: 'initiate' | 'ready' | 'output' | 'complete' | 'progress' | 'error'
|
14 |
+
progress?: number
|
15 |
+
error?: string
|
16 |
output?: any
|
17 |
}
|
18 |
|
|
|
23 |
}
|
24 |
|
25 |
export interface TextClassificationWorkerInput {
|
26 |
+
type: 'classify'
|
27 |
text: string
|
28 |
model: string
|
29 |
}
|
src/workers/text-classification.js
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
/* eslint-disable no-restricted-globals */
|
2 |
-
import { pipeline } from '@huggingface/transformers';
|
3 |
-
|
4 |
-
class MyTextClassificationPipeline {
|
5 |
-
static task = 'text-classification';
|
6 |
-
static instance = null;
|
7 |
-
|
8 |
-
static async getInstance(model, progress_callback = null) {
|
9 |
-
this.instance ??= pipeline(this.task, model, {
|
10 |
-
progress_callback
|
11 |
-
});
|
12 |
-
|
13 |
-
return this.instance;
|
14 |
-
}
|
15 |
-
}
|
16 |
-
|
17 |
-
// Listen for messages from the main thread
|
18 |
-
self.addEventListener('message', async (event) => {
|
19 |
-
const { text, model } = event.data;
|
20 |
-
if (!model) {
|
21 |
-
self.postMessage({
|
22 |
-
status: 'error',
|
23 |
-
output: 'No model provided'
|
24 |
-
});
|
25 |
-
return;
|
26 |
-
}
|
27 |
-
|
28 |
-
// Retrieve the pipeline. When called for the first time,
|
29 |
-
// this will load the pipeline and save it for future use.
|
30 |
-
const classifier = await MyTextClassificationPipeline.getInstance(model, (x) => {
|
31 |
-
// We also add a progress callback to the pipeline so that we can
|
32 |
-
// track model loading.
|
33 |
-
self.postMessage({ status: 'progress', output: x });
|
34 |
-
});
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
const split = text.split('\n');
|
39 |
-
for (const line of split) {
|
40 |
-
if (line.trim()) {
|
41 |
-
const output = await classifier(line);
|
42 |
-
// Send the output back to the main thread
|
43 |
-
self.postMessage({
|
44 |
-
status: 'output',
|
45 |
-
output: {
|
46 |
-
sequence: line,
|
47 |
-
labels: [output[0].label],
|
48 |
-
scores: [output[0].score]
|
49 |
-
}
|
50 |
-
});
|
51 |
-
}
|
52 |
-
}
|
53 |
-
// Send the output back to the main thread
|
54 |
-
self.postMessage({ status: 'complete' });
|
55 |
-
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tsconfig.json
CHANGED
@@ -1,7 +1,12 @@
|
|
1 |
{
|
2 |
"compilerOptions": {
|
3 |
-
"target": "
|
4 |
-
"lib": [
|
|
|
|
|
|
|
|
|
|
|
5 |
"allowJs": true,
|
6 |
"skipLibCheck": true,
|
7 |
"esModuleInterop": true,
|
@@ -16,5 +21,5 @@
|
|
16 |
"noEmit": true,
|
17 |
"jsx": "react-jsx"
|
18 |
},
|
19 |
-
"include": ["src"]
|
20 |
-
}
|
|
|
1 |
{
|
2 |
"compilerOptions": {
|
3 |
+
"target": "es2020",
|
4 |
+
"lib": [
|
5 |
+
"dom",
|
6 |
+
"dom.iterable",
|
7 |
+
"esnext",
|
8 |
+
"WebWorker"
|
9 |
+
],
|
10 |
"allowJs": true,
|
11 |
"skipLibCheck": true,
|
12 |
"esModuleInterop": true,
|
|
|
21 |
"noEmit": true,
|
22 |
"jsx": "react-jsx"
|
23 |
},
|
24 |
+
"include": ["src", "public/workers"]
|
25 |
+
}
|