Enhance text classification pipeline: add device support, improve error handling, and refine message processing logic
Browse files- public/workers/text-classification.js +47 -37
- src/components/ModelLoader.tsx +7 -1
- src/components/TextClassification.tsx +33 -32
- src/types.ts +7 -1
public/workers/text-classification.js
CHANGED
@@ -9,7 +9,7 @@ class MyTextClassificationPipeline {
|
|
9 |
this.instance = pipeline(
|
10 |
this.task,
|
11 |
model,
|
12 |
-
{ dtype, progress_callback },
|
13 |
)
|
14 |
return this.instance
|
15 |
}
|
@@ -17,49 +17,59 @@ class MyTextClassificationPipeline {
|
|
17 |
|
18 |
// Listen for messages from the main thread
|
19 |
self.addEventListener('message', async (event) => {
|
20 |
-
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
}
|
29 |
-
|
30 |
-
// Retrieve the pipeline. This will download the model if not already cached.
|
31 |
-
const classifier = await MyTextClassificationPipeline.getInstance(
|
32 |
-
model,
|
33 |
-
dtype,
|
34 |
-
(x) => {
|
35 |
-
self.postMessage({ status: 'loading', output: x })
|
36 |
}
|
37 |
-
)
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
47 |
return
|
48 |
}
|
49 |
-
|
50 |
-
|
51 |
-
if (
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
}
|
|
|
62 |
}
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
}
|
65 |
})
|
|
|
9 |
this.instance = pipeline(
|
10 |
this.task,
|
11 |
model,
|
12 |
+
{ dtype, device: "webgpu", progress_callback },
|
13 |
)
|
14 |
return this.instance
|
15 |
}
|
|
|
17 |
|
18 |
// Listen for messages from the main thread
|
19 |
self.addEventListener('message', async (event) => {
|
20 |
+
try {
|
21 |
+
const { type, model, dtype, text } = event.data
|
22 |
|
23 |
+
if (!model) {
|
24 |
+
self.postMessage({
|
25 |
+
status: 'error',
|
26 |
+
output: 'No model provided'
|
27 |
+
})
|
28 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
}
|
|
|
30 |
|
31 |
+
// Retrieve the pipeline. This will download the model if not already cached.
|
32 |
+
const classifier = await MyTextClassificationPipeline.getInstance(
|
33 |
+
model,
|
34 |
+
dtype,
|
35 |
+
(x) => {
|
36 |
+
self.postMessage({ status: 'loading', output: x })
|
37 |
+
}
|
38 |
+
)
|
39 |
|
40 |
+
if (type === 'load') {
|
41 |
+
self.postMessage({
|
42 |
+
status: 'ready',
|
43 |
+
output: `Model ${model}, dtype ${dtype} loaded`
|
44 |
+
})
|
45 |
return
|
46 |
}
|
47 |
+
|
48 |
+
if (type === 'classify') {
|
49 |
+
if (!text) {
|
50 |
+
self.postMessage({ status: 'ready' }) // Nothing to process
|
51 |
+
return
|
52 |
+
}
|
53 |
+
const split = text.split('\n')
|
54 |
+
for (const line of split) {
|
55 |
+
if (line.trim()) {
|
56 |
+
const output = await classifier(line)
|
57 |
+
self.postMessage({
|
58 |
+
status: 'output',
|
59 |
+
output: {
|
60 |
+
sequence: line,
|
61 |
+
labels: [output[0].label],
|
62 |
+
scores: [output[0].score]
|
63 |
+
}
|
64 |
+
})
|
65 |
+
}
|
66 |
}
|
67 |
+
self.postMessage({ status: 'ready' })
|
68 |
}
|
69 |
+
} catch (error) {
|
70 |
+
self.postMessage({
|
71 |
+
status: 'error',
|
72 |
+
output: error.message || 'An error occurred during processing'
|
73 |
+
})
|
74 |
}
|
75 |
})
|
src/components/ModelLoader.tsx
CHANGED
@@ -21,6 +21,9 @@ const ModelLoader = () => {
|
|
21 |
setHasBeenLoaded
|
22 |
} = useModel()
|
23 |
|
|
|
|
|
|
|
24 |
|
25 |
useEffect(() => {
|
26 |
if (!modelInfo) return
|
@@ -43,6 +46,8 @@ const ModelLoader = () => {
|
|
43 |
setHasBeenLoaded(false)
|
44 |
}, [modelInfo, setSelectedQuantization, setHasBeenLoaded])
|
45 |
|
|
|
|
|
46 |
useEffect(() => {
|
47 |
if (!modelInfo) return
|
48 |
|
@@ -61,6 +66,7 @@ const ModelLoader = () => {
|
|
61 |
const { status, output } = e.data
|
62 |
if (status === 'ready') {
|
63 |
setStatus('ready')
|
|
|
64 |
setHasBeenLoaded(true)
|
65 |
} else if (status === 'loading' && output && !hasBeenLoaded) {
|
66 |
setStatus('loading')
|
@@ -156,7 +162,7 @@ const ModelLoader = () => {
|
|
156 |
<div className="flex justify-center">
|
157 |
<button
|
158 |
className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
|
159 |
-
disabled={hasBeenLoaded}
|
160 |
onClick={loadModel}
|
161 |
>
|
162 |
{status === 'loading' && !hasBeenLoaded ? (
|
|
|
21 |
setHasBeenLoaded
|
22 |
} = useModel()
|
23 |
|
24 |
+
useEffect(() => {
|
25 |
+
setHasBeenLoaded(false)
|
26 |
+
}, [selectedQuantization])
|
27 |
|
28 |
useEffect(() => {
|
29 |
if (!modelInfo) return
|
|
|
46 |
setHasBeenLoaded(false)
|
47 |
}, [modelInfo, setSelectedQuantization, setHasBeenLoaded])
|
48 |
|
49 |
+
|
50 |
+
|
51 |
useEffect(() => {
|
52 |
if (!modelInfo) return
|
53 |
|
|
|
66 |
const { status, output } = e.data
|
67 |
if (status === 'ready') {
|
68 |
setStatus('ready')
|
69 |
+
if (e.data.output) console.log(e.data.output)
|
70 |
setHasBeenLoaded(true)
|
71 |
} else if (status === 'loading' && output && !hasBeenLoaded) {
|
72 |
setStatus('loading')
|
|
|
162 |
<div className="flex justify-center">
|
163 |
<button
|
164 |
className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
|
165 |
+
disabled={hasBeenLoaded || status === 'loading'}
|
166 |
onClick={loadModel}
|
167 |
>
|
168 |
{status === 'loading' && !hasBeenLoaded ? (
|
src/components/TextClassification.tsx
CHANGED
@@ -24,8 +24,6 @@ function TextClassification() {
|
|
24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
25 |
const { activeWorker, status, setStatus, modelInfo, results, setResults, hasBeenLoaded} = useModel()
|
26 |
|
27 |
-
|
28 |
-
|
29 |
const classify = useCallback(() => {
|
30 |
if (!modelInfo || !activeWorker) {
|
31 |
console.error('Model info or worker is not available')
|
@@ -48,46 +46,49 @@ function TextClassification() {
|
|
48 |
|
49 |
return (
|
50 |
<div className="flex flex-col h-[60vh] max-h-[100vh] w-full p-4">
|
51 |
-
<h1 className="text-2xl font-bold mb-4">Text Classification</h1>
|
52 |
|
53 |
-
<div className="flex flex-col lg:flex-row gap-4 h-
|
54 |
{/* Input Section */}
|
55 |
-
<div className="flex flex-col w-full lg:w-1/2">
|
56 |
-
<label className="text-lg font-medium mb-2">Input Text:</label>
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
81 |
</div>
|
82 |
</div>
|
83 |
|
84 |
{/* Results Section */}
|
85 |
-
<div className="flex flex-col w-full lg:w-1/2">
|
86 |
-
<label className="text-lg font-medium mb-2">
|
87 |
Classification Results ({results.length}):
|
88 |
</label>
|
89 |
|
90 |
-
<div className="border border-gray-300 rounded p-3 flex-
|
91 |
{results.length === 0 ? (
|
92 |
<div className="text-gray-500 text-center py-8">
|
93 |
No results yet. Click "Classify Text" to analyze your input.
|
|
|
24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
25 |
const { activeWorker, status, setStatus, modelInfo, results, setResults, hasBeenLoaded} = useModel()
|
26 |
|
|
|
|
|
27 |
const classify = useCallback(() => {
|
28 |
if (!modelInfo || !activeWorker) {
|
29 |
console.error('Model info or worker is not available')
|
|
|
46 |
|
47 |
return (
|
48 |
<div className="flex flex-col h-[60vh] max-h-[100vh] w-full p-4">
|
49 |
+
<h1 className="text-2xl font-bold mb-4 flex-shrink-0">Text Classification</h1>
|
50 |
|
51 |
+
<div className="flex flex-col lg:flex-row gap-4 flex-1 min-h-0">
|
52 |
{/* Input Section */}
|
53 |
+
<div className="flex flex-col w-full lg:w-1/2 min-h-0">
|
54 |
+
<label className="text-lg font-medium mb-2 flex-shrink-0">Input Text:</label>
|
55 |
+
|
56 |
+
<div className="flex flex-col flex-1 min-h-0">
|
57 |
+
<textarea
|
58 |
+
className="border border-gray-300 rounded p-3 flex-1 resize-none min-h-[200px]"
|
59 |
+
value={text}
|
60 |
+
onChange={(e) => setText(e.target.value)}
|
61 |
+
placeholder="Enter text to classify (one per line)..."
|
62 |
+
/>
|
63 |
|
64 |
+
<div className="flex gap-2 mt-4 flex-shrink-0">
|
65 |
+
<button
|
66 |
+
className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
67 |
+
disabled={busy}
|
68 |
+
onClick={classify}
|
69 |
+
>
|
70 |
+
{hasBeenLoaded ? !busy
|
71 |
+
? 'Classify Text'
|
72 |
+
: 'Processing...'
|
73 |
+
: 'Load model first'}
|
74 |
+
</button>
|
75 |
+
<button
|
76 |
+
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
77 |
+
onClick={handleClear}
|
78 |
+
>
|
79 |
+
Clear Results
|
80 |
+
</button>
|
81 |
+
</div>
|
82 |
</div>
|
83 |
</div>
|
84 |
|
85 |
{/* Results Section */}
|
86 |
+
<div className="flex flex-col w-full lg:w-1/2 min-h-0">
|
87 |
+
<label className="text-lg font-medium mb-2 flex-shrink-0">
|
88 |
Classification Results ({results.length}):
|
89 |
</label>
|
90 |
|
91 |
+
<div className="border border-gray-300 rounded p-3 flex-1 overflow-y-auto min-h-[200px]">
|
92 |
{results.length === 0 ? (
|
93 |
<div className="text-gray-500 text-center py-8">
|
94 |
No results yet. Click "Classify Text" to analyze your input.
|
src/types.ts
CHANGED
@@ -9,7 +9,13 @@ export interface ClassificationOutput {
|
|
9 |
scores: number[]
|
10 |
}
|
11 |
|
12 |
-
export type WorkerStatus =
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
export interface WorkerMessage {
|
15 |
status: WorkerStatus
|
|
|
9 |
scores: number[]
|
10 |
}
|
11 |
|
12 |
+
export type WorkerStatus =
|
13 |
+
| 'initiate'
|
14 |
+
| 'ready'
|
15 |
+
| 'output'
|
16 |
+
| 'loading'
|
17 |
+
| 'error'
|
18 |
+
| 'disposed'
|
19 |
|
20 |
export interface WorkerMessage {
|
21 |
status: WorkerStatus
|