Vokturz commited on
Commit
bd915ca
·
1 Parent(s): 4d810fa

Enhance text classification pipeline: add device support, improve error handling, and refine message processing logic

Browse files
public/workers/text-classification.js CHANGED
@@ -9,7 +9,7 @@ class MyTextClassificationPipeline {
9
  this.instance = pipeline(
10
  this.task,
11
  model,
12
- { dtype, progress_callback },
13
  )
14
  return this.instance
15
  }
@@ -17,49 +17,59 @@ class MyTextClassificationPipeline {
17
 
18
  // Listen for messages from the main thread
19
  self.addEventListener('message', async (event) => {
20
- const { type, model, dtype, text } = event.data
 
21
 
22
- if (!model) {
23
- self.postMessage({
24
- status: 'error',
25
- output: 'No model provided'
26
- })
27
- return
28
- }
29
-
30
- // Retrieve the pipeline. This will download the model if not already cached.
31
- const classifier = await MyTextClassificationPipeline.getInstance(
32
- model,
33
- dtype,
34
- (x) => {
35
- self.postMessage({ status: 'loading', output: x })
36
  }
37
- )
38
 
39
- if (type === 'load') {
40
- self.postMessage({ status: 'ready' })
41
- return
42
- }
 
 
 
 
43
 
44
- if (type === 'classify') {
45
- if (!text) {
46
- self.postMessage({ status: 'ready' }) // Nothing to process
 
 
47
  return
48
  }
49
- const split = text.split('\n')
50
- for (const line of split) {
51
- if (line.trim()) {
52
- const output = await classifier(line)
53
- self.postMessage({
54
- status: 'output',
55
- output: {
56
- sequence: line,
57
- labels: [output[0].label],
58
- scores: [output[0].score]
59
- }
60
- })
 
 
 
 
 
 
 
61
  }
 
62
  }
63
- self.postMessage({ status: 'ready' })
 
 
 
 
64
  }
65
  })
 
9
  this.instance = pipeline(
10
  this.task,
11
  model,
12
+ { dtype, device: "webgpu", progress_callback },
13
  )
14
  return this.instance
15
  }
 
17
 
18
  // Listen for messages from the main thread
19
  self.addEventListener('message', async (event) => {
20
+ try {
21
+ const { type, model, dtype, text } = event.data
22
 
23
+ if (!model) {
24
+ self.postMessage({
25
+ status: 'error',
26
+ output: 'No model provided'
27
+ })
28
+ return
 
 
 
 
 
 
 
 
29
  }
 
30
 
31
+ // Retrieve the pipeline. This will download the model if not already cached.
32
+ const classifier = await MyTextClassificationPipeline.getInstance(
33
+ model,
34
+ dtype,
35
+ (x) => {
36
+ self.postMessage({ status: 'loading', output: x })
37
+ }
38
+ )
39
 
40
+ if (type === 'load') {
41
+ self.postMessage({
42
+ status: 'ready',
43
+ output: `Model ${model}, dtype ${dtype} loaded`
44
+ })
45
  return
46
  }
47
+
48
+ if (type === 'classify') {
49
+ if (!text) {
50
+ self.postMessage({ status: 'ready' }) // Nothing to process
51
+ return
52
+ }
53
+ const split = text.split('\n')
54
+ for (const line of split) {
55
+ if (line.trim()) {
56
+ const output = await classifier(line)
57
+ self.postMessage({
58
+ status: 'output',
59
+ output: {
60
+ sequence: line,
61
+ labels: [output[0].label],
62
+ scores: [output[0].score]
63
+ }
64
+ })
65
+ }
66
  }
67
+ self.postMessage({ status: 'ready' })
68
  }
69
+ } catch (error) {
70
+ self.postMessage({
71
+ status: 'error',
72
+ output: error.message || 'An error occurred during processing'
73
+ })
74
  }
75
  })
src/components/ModelLoader.tsx CHANGED
@@ -21,6 +21,9 @@ const ModelLoader = () => {
21
  setHasBeenLoaded
22
  } = useModel()
23
 
 
 
 
24
 
25
  useEffect(() => {
26
  if (!modelInfo) return
@@ -43,6 +46,8 @@ const ModelLoader = () => {
43
  setHasBeenLoaded(false)
44
  }, [modelInfo, setSelectedQuantization, setHasBeenLoaded])
45
 
 
 
46
  useEffect(() => {
47
  if (!modelInfo) return
48
 
@@ -61,6 +66,7 @@ const ModelLoader = () => {
61
  const { status, output } = e.data
62
  if (status === 'ready') {
63
  setStatus('ready')
 
64
  setHasBeenLoaded(true)
65
  } else if (status === 'loading' && output && !hasBeenLoaded) {
66
  setStatus('loading')
@@ -156,7 +162,7 @@ const ModelLoader = () => {
156
  <div className="flex justify-center">
157
  <button
158
  className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
159
- disabled={hasBeenLoaded}
160
  onClick={loadModel}
161
  >
162
  {status === 'loading' && !hasBeenLoaded ? (
 
21
  setHasBeenLoaded
22
  } = useModel()
23
 
24
+ useEffect(() => {
25
+ setHasBeenLoaded(false)
26
+ }, [selectedQuantization])
27
 
28
  useEffect(() => {
29
  if (!modelInfo) return
 
46
  setHasBeenLoaded(false)
47
  }, [modelInfo, setSelectedQuantization, setHasBeenLoaded])
48
 
49
+
50
+
51
  useEffect(() => {
52
  if (!modelInfo) return
53
 
 
66
  const { status, output } = e.data
67
  if (status === 'ready') {
68
  setStatus('ready')
69
+ if (e.data.output) console.log(e.data.output)
70
  setHasBeenLoaded(true)
71
  } else if (status === 'loading' && output && !hasBeenLoaded) {
72
  setStatus('loading')
 
162
  <div className="flex justify-center">
163
  <button
164
  className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
165
+ disabled={hasBeenLoaded || status === 'loading'}
166
  onClick={loadModel}
167
  >
168
  {status === 'loading' && !hasBeenLoaded ? (
src/components/TextClassification.tsx CHANGED
@@ -24,8 +24,6 @@ function TextClassification() {
24
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
25
  const { activeWorker, status, setStatus, modelInfo, results, setResults, hasBeenLoaded} = useModel()
26
 
27
-
28
-
29
  const classify = useCallback(() => {
30
  if (!modelInfo || !activeWorker) {
31
  console.error('Model info or worker is not available')
@@ -48,46 +46,49 @@ function TextClassification() {
48
 
49
  return (
50
  <div className="flex flex-col h-[60vh] max-h-[100vh] w-full p-4">
51
- <h1 className="text-2xl font-bold mb-4">Text Classification</h1>
52
 
53
- <div className="flex flex-col lg:flex-row gap-4 h-full">
54
  {/* Input Section */}
55
- <div className="flex flex-col w-full lg:w-1/2">
56
- <label className="text-lg font-medium mb-2">Input Text:</label>
57
- <textarea
58
- className="border border-gray-300 rounded p-3 flex-grow resize-none"
59
- value={text}
60
- onChange={(e) => setText(e.target.value)}
61
- placeholder="Enter text to classify (one per line)..."
62
- />
 
 
63
 
64
- <div className="flex gap-2 mt-4">
65
- <button
66
- className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
67
- disabled={busy}
68
- onClick={classify}
69
- >
70
- {hasBeenLoaded ? !busy
71
- ? 'Classify Text'
72
- : 'Processing...'
73
- : 'Load model first'}
74
- </button>
75
- <button
76
- className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
77
- onClick={handleClear}
78
- >
79
- Clear Results
80
- </button>
 
81
  </div>
82
  </div>
83
 
84
  {/* Results Section */}
85
- <div className="flex flex-col w-full lg:w-1/2">
86
- <label className="text-lg font-medium mb-2">
87
  Classification Results ({results.length}):
88
  </label>
89
 
90
- <div className="border border-gray-300 rounded p-3 flex-grow overflow-y-auto">
91
  {results.length === 0 ? (
92
  <div className="text-gray-500 text-center py-8">
93
  No results yet. Click "Classify Text" to analyze your input.
 
24
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
25
  const { activeWorker, status, setStatus, modelInfo, results, setResults, hasBeenLoaded} = useModel()
26
 
 
 
27
  const classify = useCallback(() => {
28
  if (!modelInfo || !activeWorker) {
29
  console.error('Model info or worker is not available')
 
46
 
47
  return (
48
  <div className="flex flex-col h-[60vh] max-h-[100vh] w-full p-4">
49
+ <h1 className="text-2xl font-bold mb-4 flex-shrink-0">Text Classification</h1>
50
 
51
+ <div className="flex flex-col lg:flex-row gap-4 flex-1 min-h-0">
52
  {/* Input Section */}
53
+ <div className="flex flex-col w-full lg:w-1/2 min-h-0">
54
+ <label className="text-lg font-medium mb-2 flex-shrink-0">Input Text:</label>
55
+
56
+ <div className="flex flex-col flex-1 min-h-0">
57
+ <textarea
58
+ className="border border-gray-300 rounded p-3 flex-1 resize-none min-h-[200px]"
59
+ value={text}
60
+ onChange={(e) => setText(e.target.value)}
61
+ placeholder="Enter text to classify (one per line)..."
62
+ />
63
 
64
+ <div className="flex gap-2 mt-4 flex-shrink-0">
65
+ <button
66
+ className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
67
+ disabled={busy}
68
+ onClick={classify}
69
+ >
70
+ {hasBeenLoaded ? !busy
71
+ ? 'Classify Text'
72
+ : 'Processing...'
73
+ : 'Load model first'}
74
+ </button>
75
+ <button
76
+ className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
77
+ onClick={handleClear}
78
+ >
79
+ Clear Results
80
+ </button>
81
+ </div>
82
  </div>
83
  </div>
84
 
85
  {/* Results Section */}
86
+ <div className="flex flex-col w-full lg:w-1/2 min-h-0">
87
+ <label className="text-lg font-medium mb-2 flex-shrink-0">
88
  Classification Results ({results.length}):
89
  </label>
90
 
91
+ <div className="border border-gray-300 rounded p-3 flex-1 overflow-y-auto min-h-[200px]">
92
  {results.length === 0 ? (
93
  <div className="text-gray-500 text-center py-8">
94
  No results yet. Click "Classify Text" to analyze your input.
src/types.ts CHANGED
@@ -9,7 +9,13 @@ export interface ClassificationOutput {
9
  scores: number[]
10
  }
11
 
12
- export type WorkerStatus = 'initiate' | 'ready' | 'output' | 'loading' | 'error'
 
 
 
 
 
 
13
 
14
  export interface WorkerMessage {
15
  status: WorkerStatus
 
9
  scores: number[]
10
  }
11
 
12
+ export type WorkerStatus =
13
+ | 'initiate'
14
+ | 'ready'
15
+ | 'output'
16
+ | 'loading'
17
+ | 'error'
18
+ | 'disposed'
19
 
20
  export interface WorkerMessage {
21
  status: WorkerStatus