Vokturz commited on
Commit
ad5cef3
·
1 Parent(s): e7ba29d

wip: refactor model loading and classification components

Browse files
public/workers/text-classification.js CHANGED
@@ -29,7 +29,7 @@ self.addEventListener('message', async (event) => {
29
  const classifier = await MyTextClassificationPipeline.getInstance(
30
  model,
31
  (x) => {
32
- self.postMessage({ status: 'progress', output: x })
33
  }
34
  )
35
 
@@ -40,7 +40,7 @@ self.addEventListener('message', async (event) => {
40
 
41
  if (type === 'classify') {
42
  if (!text) {
43
- self.postMessage({ status: 'complete' }) // Nothing to process
44
  return
45
  }
46
  const split = text.split('\n')
@@ -57,6 +57,6 @@ self.addEventListener('message', async (event) => {
57
  })
58
  }
59
  }
60
- self.postMessage({ status: 'complete' })
61
  }
62
  })
 
29
  const classifier = await MyTextClassificationPipeline.getInstance(
30
  model,
31
  (x) => {
32
+ self.postMessage({ status: 'loading', output: x })
33
  }
34
  )
35
 
 
40
 
41
  if (type === 'classify') {
42
  if (!text) {
43
+ self.postMessage({ status: 'ready' }) // Nothing to process
44
  return
45
  }
46
  const split = text.split('\n')
 
57
  })
58
  }
59
  }
60
+ self.postMessage({ status: 'ready' })
61
  }
62
  })
src/App.tsx CHANGED
@@ -1,4 +1,4 @@
1
- import { useEffect, useState } from 'react'
2
  import PipelineSelector from './components/PipelineSelector'
3
  import ZeroShotClassification from './components/ZeroShotClassification'
4
  import TextClassification from './components/TextClassification'
@@ -10,14 +10,12 @@ import ModelSelector from './components/ModelSelector'
10
  import ModelInfo from './components/ModelInfo'
11
 
12
  function App() {
13
- const { pipeline, setPipeline, progress, status, modelInfo, setModels } =
14
- useModel()
15
 
16
  useEffect(() => {
17
  const fetchModels = async () => {
18
  const fetchedModels = await getModelsByPipeline(pipeline)
19
  setModels(fetchedModels)
20
- console.log(fetchedModels)
21
  }
22
  fetchModels()
23
  }, [setModels, pipeline])
@@ -26,76 +24,34 @@ function App() {
26
  <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
27
  <Header />
28
 
29
- <main className="max-w-8xl mx-auto px-4 sm:px-6 lg:px-8 py-8">
30
- {/* Pipeline Selection */}
31
  <div className="mb-8">
32
- <div className="bg-white rounded-lg shadow-sm border p-6">
33
- <div className="flex items-center justify-between mb-4">
34
- <ModelInfo />
35
- </div>
 
 
 
 
 
 
 
 
36
 
37
- <div className="grid grid-cols-1 xl:grid-cols-2 gap-4 items-start">
38
- <div className="space-y-2">
39
- <span className="text-lg font-semibold text-gray-900 block">
40
- Choose a Pipeline
41
- </span>
42
- <PipelineSelector
43
- pipeline={pipeline}
44
- setPipeline={setPipeline}
45
- />
46
  </div>
47
 
48
- <div className="space-y-2">
49
- <span className="text-lg font-semibold text-gray-900 block">
50
- Select Model
51
- </span>
52
- <ModelSelector />
53
  </div>
54
  </div>
55
 
56
- {/* Model Loading Progress */}
57
- {status === 'progress' && (
58
- <div className="mt-4 p-4 bg-blue-50 rounded-lg">
59
- <div className="flex items-center space-x-3">
60
- <div className="flex-shrink-0">
61
- <svg
62
- className="animate-spin h-5 w-5 text-blue-500"
63
- fill="none"
64
- viewBox="0 0 24 24"
65
- >
66
- <circle
67
- className="opacity-25"
68
- cx="12"
69
- cy="12"
70
- r="10"
71
- stroke="currentColor"
72
- strokeWidth="4"
73
- ></circle>
74
- <path
75
- className="opacity-75"
76
- fill="currentColor"
77
- d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
78
- ></path>
79
- </svg>
80
- </div>
81
- <div className="flex-1">
82
- <p className="text-sm font-medium text-blue-900">
83
- Loading Model...
84
- </p>
85
- <div className="mt-2 bg-blue-200 rounded-full h-2">
86
- <div
87
- className="bg-blue-500 h-2 rounded-full transition-all duration-300"
88
- style={{ width: `${progress.toFixed(2)}%` }}
89
- ></div>
90
- </div>
91
- <p className="text-xs text-blue-700 mt-1">
92
- {progress.toFixed(2)}%
93
- </p>
94
- </div>
95
- </div>
96
- </div>
97
- )}
98
-
99
  {/* Pipeline Description */}
100
  <div className="mt-4 p-4 bg-gray-50 rounded-lg">
101
  <div className="flex items-start space-x-3">
@@ -129,7 +85,6 @@ function App() {
129
  </div>
130
  </div>
131
 
132
- {/* Pipeline Component */}
133
  <div className="bg-white rounded-lg shadow-sm border overflow-hidden">
134
  {pipeline === 'zero-shot-classification' && (
135
  <ZeroShotClassification />
 
1
+ import { useEffect } from 'react'
2
  import PipelineSelector from './components/PipelineSelector'
3
  import ZeroShotClassification from './components/ZeroShotClassification'
4
  import TextClassification from './components/TextClassification'
 
10
  import ModelInfo from './components/ModelInfo'
11
 
12
  function App() {
13
+ const { pipeline, setPipeline, setModels } = useModel()
 
14
 
15
  useEffect(() => {
16
  const fetchModels = async () => {
17
  const fetchedModels = await getModelsByPipeline(pipeline)
18
  setModels(fetchedModels)
 
19
  }
20
  fetchModels()
21
  }, [setModels, pipeline])
 
24
  <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
25
  <Header />
26
 
27
+ <main className="max-w-6xl mx-auto px-4 sm:px-6 lg:px-8 py-8">
 
28
  <div className="mb-8">
29
+ <div className="bg-white rounded-lg border p-6">
30
+ <div className="flex items-start justify-between max-w-6xl mx-auto">
31
+ <div className="space-y-2 flex-1">
32
+ <div className="space-y-2">
33
+ <span className="text-lg font-semibold text-gray-900 block">
34
+ Choose a Pipeline
35
+ </span>
36
+ <PipelineSelector
37
+ pipeline={pipeline}
38
+ setPipeline={setPipeline}
39
+ />
40
+ </div>
41
 
42
+ <div className="space-y-2">
43
+ <span className="text-lg font-semibold text-gray-900 block">
44
+ Select Model
45
+ </span>
46
+ <ModelSelector />
47
+ </div>
 
 
 
48
  </div>
49
 
50
+ <div className="ml-6">
51
+ <ModelInfo />
 
 
 
52
  </div>
53
  </div>
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  {/* Pipeline Description */}
56
  <div className="mt-4 p-4 bg-gray-50 rounded-lg">
57
  <div className="flex items-start space-x-3">
 
85
  </div>
86
  </div>
87
 
 
88
  <div className="bg-white rounded-lg shadow-sm border overflow-hidden">
89
  {pipeline === 'zero-shot-classification' && (
90
  <ZeroShotClassification />
src/components/ModelInfo.tsx CHANGED
@@ -6,14 +6,11 @@ import {
6
  DatabaseIcon,
7
  CheckCircle,
8
  XCircle,
9
- ExternalLink,
10
- ChevronDown
11
  } from 'lucide-react'
12
  import { getModelSize } from '../lib/huggingface'
13
  import { useModel } from '../contexts/ModelContext'
14
- import { useEffect, useCallback } from 'react'
15
- import { QuantizationType, WorkerMessage } from '../types'
16
- import { getWorker } from '../lib/workerManager'
17
 
18
  const ModelInfo = () => {
19
  const formatNumber = (num: number) => {
@@ -29,96 +26,50 @@ const ModelInfo = () => {
29
 
30
  const {
31
  modelInfo,
32
- selectedQuantization,
33
- setSelectedQuantization,
34
- status,
35
- setStatus,
36
- setProgress,
37
- activeWorker,
38
- setActiveWorker,
39
- pipeline,
40
- workerLoaded,
41
- setWorkerLoaded
42
  } = useModel()
43
 
44
- useEffect(() => {
45
- if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
46
- const quantizations = modelInfo.supportedQuantizations
47
- let defaultQuant: QuantizationType = 'fp32'
48
-
49
- if (quantizations.includes('int8')) {
50
- defaultQuant = 'int8'
51
- } else if (quantizations.includes('q8')) {
52
- defaultQuant = 'q8'
53
- } else if (quantizations.includes('q4')) {
54
- defaultQuant = 'q4'
55
- }
56
-
57
- setSelectedQuantization(defaultQuant)
58
- }
59
- }, [
60
- modelInfo.supportedQuantizations,
61
- modelInfo.isCompatible,
62
- setSelectedQuantization
63
- ])
64
-
65
- useEffect(() => {
66
- const newWorker = getWorker(pipeline)
67
- if (!newWorker) {
68
- return
69
- }
70
-
71
- setStatus('idle')
72
- setWorkerLoaded(false)
73
- setActiveWorker(newWorker)
74
-
75
- const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
76
- const { status, output } = e.data
77
- if (status === 'initiate') {
78
- setStatus('loading')
79
- } else if (status === 'ready') {
80
- setStatus('ready')
81
- setWorkerLoaded(true)
82
- } else if (status === 'progress' && output) {
83
- setStatus('progress')
84
- if (
85
- output.progress &&
86
- typeof output.file === 'string' &&
87
- output.file.startsWith('onnx')
88
- ) {
89
- setProgress(output.progress)
90
- }
91
- }
92
- }
93
-
94
- newWorker.addEventListener('message', onMessageReceived)
95
-
96
- return () => {
97
- newWorker.removeEventListener('message', onMessageReceived)
98
- // terminateWorker(pipeline);
99
- }
100
- }, [pipeline, selectedQuantization, setActiveWorker, setStatus, setProgress, setWorkerLoaded])
101
-
102
- const loadModel = useCallback(() => {
103
- if (!modelInfo.name || !selectedQuantization) return
104
 
105
- setStatus('loading')
106
- const message = {
107
- type: 'load',
108
- model: modelInfo.name,
109
- quantization: selectedQuantization
110
- }
111
- activeWorker?.postMessage(message)
112
- }, [modelInfo.name, selectedQuantization, setStatus, activeWorker])
113
 
114
- const busy: boolean = status !== 'idle'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  if (!modelInfo.name) {
117
- return null
118
  }
119
 
120
  return (
121
- <div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-3">
122
  {/* Model Name Row */}
123
  <div className="flex items-center space-x-2">
124
  <Bot className="w-4 h-4 text-blue-600" />
@@ -150,16 +101,16 @@ const ModelInfo = () => {
150
 
151
  {/* Base Model Link */}
152
  {modelInfo.baseId && (
153
- <div className="flex items-center space-x-2 ml-6">
154
  <a
155
  href={`https://huggingface.co/${modelInfo.baseId}`}
156
  target="_blank"
157
  rel="noopener noreferrer"
158
- className="text-xs text-gray-600 truncate max-w-100 hover:underline"
159
  title={`Base model: ${modelInfo.baseId}`}
160
  >
161
  <ExternalLink className="w-3 h-3 inline-block mr-1" />
162
- {modelInfo.baseId}
163
  </a>
164
  </div>
165
  )}
@@ -180,75 +131,31 @@ const ModelInfo = () => {
180
  </div>
181
  )}
182
 
183
- {modelInfo.parameters > 0 && (
184
- <div className="flex items-center space-x-1">
185
- <Cpu className="w-3 h-3 text-purple-500" />
186
  <span>{formatNumber(modelInfo.parameters)}</span>
187
- </div>
188
- )}
 
 
189
 
190
- {modelInfo.parameters > 0 && (
191
- <div className="flex items-center space-x-1">
192
- <DatabaseIcon className="w-3 h-3 text-purple-500" />
193
  <span>
194
  {`~${getModelSize(
195
  modelInfo.parameters,
196
  selectedQuantization
197
  ).toFixed(1)}MB`}
198
  </span>
199
- </div>
200
- )}
 
 
201
  </div>
202
 
203
- {/* Separator */}
204
- {modelInfo.isCompatible &&
205
- modelInfo.supportedQuantizations.length > 0 && (
206
- <hr className="border-gray-200" />
207
- )}
208
-
209
- {/* Quantization Dropdown */}
210
- {modelInfo.isCompatible &&
211
- modelInfo.supportedQuantizations.length > 0 && (
212
- <div className="flex items-center space-x-2">
213
- <span className="text-xs text-gray-600 font-medium">
214
- Quantization:
215
- </span>
216
- <div className="relative">
217
- <select
218
- value={selectedQuantization || ''}
219
- onChange={(e) =>
220
- setSelectedQuantization(e.target.value as QuantizationType)
221
- }
222
- className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
223
- >
224
- <option value="">Select quantization</option>
225
- {modelInfo.supportedQuantizations.map((quant) => (
226
- <option key={quant} value={quant}>
227
- {quant}
228
- </option>
229
- ))}
230
- </select>
231
- <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
232
- </div>
233
- </div>
234
- )}
235
-
236
- {/* Load Model Button */}
237
- {modelInfo.isCompatible && selectedQuantization && (
238
- <div className="flex justify-center">
239
- <button
240
- className="py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm"
241
- disabled={busy || !selectedQuantization || workerLoaded}
242
- onClick={loadModel}
243
- >
244
- {status === 'loading'
245
- ? 'Loading Model...'
246
- : workerLoaded
247
- ? 'Model Ready'
248
- : 'Load Model'}
249
- </button>
250
- </div>
251
- )}
252
 
253
  {/* Incompatibility Message */}
254
  {modelInfo.isCompatible === false && modelInfo.incompatibilityReason && (
 
6
  DatabaseIcon,
7
  CheckCircle,
8
  XCircle,
9
+ ExternalLink
 
10
  } from 'lucide-react'
11
  import { getModelSize } from '../lib/huggingface'
12
  import { useModel } from '../contexts/ModelContext'
13
+ import ModelLoader from './ModelLoader'
 
 
14
 
15
  const ModelInfo = () => {
16
  const formatNumber = (num: number) => {
 
26
 
27
  const {
28
  modelInfo,
29
+ selectedQuantization
 
 
 
 
 
 
 
 
 
30
  } = useModel()
31
 
32
+ const ModelInfoSkeleton = () => (
33
+ <div className="mt-5 bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-4 h-full min-h-[160px] animate-pulse w-[400px]">
34
+ <div className="flex items-center space-x-2">
35
+ <Bot className="w-4 h-4 text-blue-300" />
36
+ <div className="h-4 bg-gray-300 rounded w-48"></div>
37
+ <div className="w-4 h-4 bg-gray-300 rounded-full"></div>
38
+ </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ <div className="flex items-center space-x-2 ml-6">
41
+ <div className="h-3 bg-gray-200 rounded w-32"></div>
42
+ </div>
 
 
 
 
 
43
 
44
+ <div className="flex items-center justify-self-end space-x-4">
45
+ <div className="flex items-center space-x-1">
46
+ <Heart className="w-3 h-3 text-red-300" />
47
+ <div className="h-3 bg-gray-200 rounded w-8"></div>
48
+ </div>
49
+ <div className="flex items-center space-x-1">
50
+ <Download className="w-3 h-3 text-green-300" />
51
+ <div className="h-3 bg-gray-200 rounded w-8"></div>
52
+ </div>
53
+ <div className="flex items-center space-x-1">
54
+ <Cpu className="w-3 h-3 text-purple-300" />
55
+ <div className="h-3 bg-gray-200 rounded w-8"></div>
56
+ </div>
57
+ <div className="flex items-center space-x-1">
58
+ <DatabaseIcon className="w-3 h-3 text-purple-300" />
59
+ <div className="h-3 bg-gray-200 rounded w-12"></div>
60
+ </div>
61
+ </div>
62
+ <hr className="border-gray-200" />
63
+ <div className="h-8 bg-gray-200 rounded w-full"></div>
64
+ </div>
65
+ )
66
 
67
  if (!modelInfo.name) {
68
+ return <ModelInfoSkeleton />
69
  }
70
 
71
  return (
72
+ <div className="mt-5 bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-3 h-full min-h-[150px]">
73
  {/* Model Name Row */}
74
  <div className="flex items-center space-x-2">
75
  <Bot className="w-4 h-4 text-blue-600" />
 
101
 
102
  {/* Base Model Link */}
103
  {modelInfo.baseId && (
104
+ <div className="flex items-center space-x-2 ml-6 text-xs text-gray-600 truncate max-w-100">
105
  <a
106
  href={`https://huggingface.co/${modelInfo.baseId}`}
107
  target="_blank"
108
  rel="noopener noreferrer"
109
+ className=" hover:underline"
110
  title={`Base model: ${modelInfo.baseId}`}
111
  >
112
  <ExternalLink className="w-3 h-3 inline-block mr-1" />
113
+ ({modelInfo.baseId})
114
  </a>
115
  </div>
116
  )}
 
131
  </div>
132
  )}
133
 
134
+ <div className="flex items-center space-x-1">
135
+ <Cpu className="w-3 h-3 text-purple-500" />
136
+ {modelInfo.parameters ? (
137
  <span>{formatNumber(modelInfo.parameters)}</span>
138
+ ) : (
139
+ <span>?</span>
140
+ )}
141
+ </div>
142
 
143
+ <div className="flex items-center space-x-1">
144
+ <DatabaseIcon className="w-3 h-3 text-purple-500" />
145
+ {modelInfo.parameters ? (
146
  <span>
147
  {`~${getModelSize(
148
  modelInfo.parameters,
149
  selectedQuantization
150
  ).toFixed(1)}MB`}
151
  </span>
152
+ ) : (
153
+ <span>?</span>
154
+ )}
155
+ </div>
156
  </div>
157
 
158
+ <ModelLoader />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  {/* Incompatibility Message */}
161
  {modelInfo.isCompatible === false && modelInfo.incompatibilityReason && (
src/components/ModelLoader.tsx ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useCallback } from 'react'
2
+ import { ChevronDown, Loader } from 'lucide-react'
3
+ import { QuantizationType, WorkerMessage } from '../types'
4
+ import { useModel } from '../contexts/ModelContext'
5
+ import { getWorker } from '../lib/workerManager'
6
+
7
+ const ModelLoader = () => {
8
+ const {
9
+ modelInfo,
10
+ selectedQuantization,
11
+ setSelectedQuantization,
12
+ status,
13
+ progress,
14
+ setStatus,
15
+ setProgress,
16
+ activeWorker,
17
+ setActiveWorker,
18
+ pipeline
19
+ } = useModel()
20
+
21
+ useEffect(() => {
22
+ if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
23
+ const quantizations = modelInfo.supportedQuantizations
24
+ let defaultQuant: QuantizationType = 'fp32'
25
+
26
+ if (quantizations.includes('int8')) {
27
+ defaultQuant = 'int8'
28
+ } else if (quantizations.includes('q8')) {
29
+ defaultQuant = 'q8'
30
+ } else if (quantizations.includes('q4')) {
31
+ defaultQuant = 'q4'
32
+ }
33
+
34
+ setSelectedQuantization(defaultQuant)
35
+ }
36
+ }, [
37
+ modelInfo.supportedQuantizations,
38
+ modelInfo.isCompatible,
39
+ setSelectedQuantization
40
+ ])
41
+
42
+ useEffect(() => {
43
+ const newWorker = getWorker(pipeline)
44
+ if (!newWorker) {
45
+ return
46
+ }
47
+
48
+ setStatus('initiate')
49
+ setActiveWorker(newWorker)
50
+
51
+ const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
52
+ const { status, output } = e.data
53
+ if (status === 'ready') {
54
+ setStatus('ready')
55
+ } else if (status === 'loading' && output) {
56
+ setStatus('loading')
57
+ if (
58
+ output.progress &&
59
+ typeof output.file === 'string' &&
60
+ output.file.startsWith('onnx')
61
+ ) {
62
+ setProgress(output.progress)
63
+ }
64
+ }
65
+ }
66
+
67
+ newWorker.addEventListener('message', onMessageReceived)
68
+
69
+ return () => {
70
+ newWorker.removeEventListener('message', onMessageReceived)
71
+ // terminateWorker(pipeline);
72
+ }
73
+ }, [pipeline, modelInfo.name, selectedQuantization, setActiveWorker, setStatus, setProgress])
74
+
75
+ const loadModel = useCallback(() => {
76
+ if (!modelInfo.name || !selectedQuantization) return
77
+
78
+ setStatus('loading')
79
+ const message = {
80
+ type: 'load',
81
+ model: modelInfo.name,
82
+ quantization: selectedQuantization
83
+ }
84
+ activeWorker?.postMessage(message)
85
+ }, [modelInfo.name, selectedQuantization, setStatus, activeWorker])
86
+
87
+ const ready: boolean = status === 'ready'
88
+ const busy: boolean = status === 'loading'
89
+
90
+ if (!modelInfo.isCompatible || modelInfo.supportedQuantizations.length === 0) {
91
+ return null
92
+ }
93
+
94
+ return (
95
+ <div className="space-y-3">
96
+ <hr className="border-gray-200" />
97
+
98
+ <div className="flex items-center justify-between space-x-4">
99
+ <div className="flex items-center space-x-2">
100
+ <span className="text-xs text-gray-600 font-medium">
101
+ Quantization:
102
+ </span>
103
+ <div className="relative">
104
+ <select
105
+ value={selectedQuantization || ''}
106
+ onChange={(e) =>
107
+ setSelectedQuantization(e.target.value as QuantizationType)
108
+ }
109
+ className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
110
+ >
111
+ <option value="">Select quantization</option>
112
+ {modelInfo.supportedQuantizations.map((quant) => (
113
+ <option key={quant} value={quant}>
114
+ {quant}
115
+ </option>
116
+ ))}
117
+ </select>
118
+ <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
119
+ </div>
120
+ </div>
121
+
122
+ {selectedQuantization && (
123
+ <div className="flex justify-center">
124
+ <button
125
+ className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
126
+ disabled={(busy && !ready) || !selectedQuantization || ready}
127
+ onClick={loadModel}
128
+ >
129
+ {status === 'loading' && (
130
+ <>
131
+ <Loader className="animate-spin h-4 w-4" />
132
+ <span>{progress.toFixed(0)}%</span>
133
+ </>
134
+ )}
135
+ {!ready && !busy ? <span>Load Model</span> : !ready ? null : <span>Model Ready</span>}
136
+ </button>
137
+ </div>
138
+ )}
139
+ </div>
140
+ </div>
141
+ )
142
+ }
143
+
144
+ export default ModelLoader
src/components/ModelSelector.tsx CHANGED
@@ -1,5 +1,11 @@
1
- import React, { useEffect, useState } from 'react'
2
- import { Listbox, ListboxButton, ListboxOption, ListboxOptions, Transition } from '@headlessui/react'
 
 
 
 
 
 
3
  import { useModel } from '../contexts/ModelContext'
4
  import { getModelInfo } from '../lib/huggingface'
5
  import { Heart, Download, ChevronDown, Check, ArrowUpDown } from 'lucide-react'
@@ -50,42 +56,46 @@ const ModelSelector: React.FC = () => {
50
  }, [models, sortBy, sortOrder])
51
 
52
  // Function to fetch detailed model info and set as selected
53
- const fetchAndSetModelInfo = async (modelId: string) => {
54
- try {
55
- const modelInfoResponse = await getModelInfo(modelId)
56
-
57
- let parameters = 0
58
- if (modelInfoResponse.safetensors) {
59
- const safetensors = modelInfoResponse.safetensors
60
- parameters =
61
- safetensors.parameters.BF16 ||
62
- safetensors.parameters.F16 ||
63
- safetensors.parameters.F32 ||
64
- safetensors.parameters.total ||
65
- 0
66
- }
67
 
68
- const modelInfo = {
69
- id: modelId,
70
- name: modelInfoResponse.id || modelId,
71
- architecture: modelInfoResponse.config?.architectures?.[0] || 'Unknown',
72
- parameters,
73
- likes: modelInfoResponse.likes || 0,
74
- downloads: modelInfoResponse.downloads || 0,
75
- createdAt: modelInfoResponse.createdAt || '',
76
- isCompatible: modelInfoResponse.isCompatible,
77
- incompatibilityReason: modelInfoResponse.incompatibilityReason,
78
- supportedQuantizations: modelInfoResponse.supportedQuantizations,
79
- baseId: modelInfoResponse.baseId
80
- }
81
 
82
- console.log('Fetched model info:', modelInfoResponse)
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- setModelInfo(modelInfo)
85
- } catch (error) {
86
- console.error('Error fetching model info:', error)
87
- }
88
- }
 
 
 
 
89
 
90
  // Update modelInfo to first model when pipeline changes
91
  useEffect(() => {
@@ -93,7 +103,7 @@ const ModelSelector: React.FC = () => {
93
  const firstModel = models[0]
94
  fetchAndSetModelInfo(firstModel.id)
95
  }
96
- }, [pipeline, models])
97
 
98
  const handleModelSelect = (modelId: string) => {
99
  fetchAndSetModelInfo(modelId)
@@ -108,35 +118,42 @@ const ModelSelector: React.FC = () => {
108
  }
109
  }
110
 
111
- const selectedModel = models.find(model => model.id === modelInfo.id) || models[0]
 
112
 
113
  return (
114
  <div className="relative">
115
- <Listbox value={selectedModel} onChange={(model) => handleModelSelect(model.id)}>
 
 
 
116
  <div className="relative">
117
  <ListboxButton className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent bg-white text-left flex items-center justify-between">
118
  <div className="flex items-center justify-between w-full">
119
  <div className="flex flex-col flex-1 min-w-0">
120
- <span className="truncate font-medium">{modelInfo.id || 'Select a model'}</span>
 
 
121
  </div>
122
-
123
  <div className="flex items-center space-x-3">
124
- {selectedModel && (selectedModel.likes > 0 || selectedModel.downloads > 0) && (
125
- <div className="flex items-center space-x-3 text-xs text-gray-500">
126
- {selectedModel.likes > 0 && (
127
- <div className="flex items-center space-x-1">
128
- <Heart className="w-3 h-3 text-red-500" />
129
- <span>{formatNumber(selectedModel.likes)}</span>
130
- </div>
131
- )}
132
- {selectedModel.downloads > 0 && (
133
- <div className="flex items-center space-x-1">
134
- <Download className="w-3 h-3 text-green-500" />
135
- <span>{formatNumber(selectedModel.downloads)}</span>
136
- </div>
137
- )}
138
- </div>
139
- )}
 
140
  <ChevronDown className="w-4 h-4 ui-open:rotate-180 transition-transform flex-shrink-0" />
141
  </div>
142
  </div>
@@ -158,7 +175,9 @@ const ModelSelector: React.FC = () => {
158
  <button
159
  onClick={() => handleSortChange('name')}
160
  className={`px-2 py-1 rounded flex items-center space-x-1 ${
161
- sortBy === 'name' ? 'bg-blue-100 text-blue-700' : 'text-gray-600 hover:bg-gray-100'
 
 
162
  }`}
163
  >
164
  <span>Name</span>
@@ -167,7 +186,9 @@ const ModelSelector: React.FC = () => {
167
  <button
168
  onClick={() => handleSortChange('likes')}
169
  className={`px-2 py-1 rounded flex items-center space-x-1 ${
170
- sortBy === 'likes' ? 'bg-blue-100 text-blue-700' : 'text-gray-600 hover:bg-gray-100'
 
 
171
  }`}
172
  >
173
  <Heart className="w-3 h-3" />
@@ -177,21 +198,29 @@ const ModelSelector: React.FC = () => {
177
  <button
178
  onClick={() => handleSortChange('downloads')}
179
  className={`px-2 py-1 rounded flex items-center space-x-1 ${
180
- sortBy === 'downloads' ? 'bg-blue-100 text-blue-700' : 'text-gray-600 hover:bg-gray-100'
 
 
181
  }`}
182
  >
183
  <Download className="w-3 h-3" />
184
  <span>Downloads</span>
185
- {sortBy === 'downloads' && <ArrowUpDown className="w-3 h-3" />}
 
 
186
  </button>
187
  <button
188
  onClick={() => handleSortChange('createdAt')}
189
  className={`px-2 py-1 rounded flex items-center space-x-1 ${
190
- sortBy === 'createdAt' ? 'bg-blue-100 text-blue-700' : 'text-gray-600 hover:bg-gray-100'
 
 
191
  }`}
192
  >
193
  <span>Date</span>
194
- {sortBy === 'createdAt' && <ArrowUpDown className="w-3 h-3" />}
 
 
195
  </button>
196
  </div>
197
  </div>
@@ -200,7 +229,7 @@ const ModelSelector: React.FC = () => {
200
  <div className="overflow-auto max-h-48">
201
  {sortedModels.map((model) => {
202
  const hasStats = model.likes > 0 || model.downloads > 0
203
-
204
  return (
205
  <ListboxOption
206
  key={model.id}
@@ -238,7 +267,7 @@ const ModelSelector: React.FC = () => {
238
  <span>{formatNumber(model.downloads)}</span>
239
  </div>
240
  )}
241
-
242
  {model.createdAt && (
243
  <span className="text-xs text-gray-400">
244
  {model.createdAt.split('T')[0]}
 
1
+ import React, { useCallback, useEffect, useState } from 'react'
2
+ import {
3
+ Listbox,
4
+ ListboxButton,
5
+ ListboxOption,
6
+ ListboxOptions,
7
+ Transition
8
+ } from '@headlessui/react'
9
  import { useModel } from '../contexts/ModelContext'
10
  import { getModelInfo } from '../lib/huggingface'
11
  import { Heart, Download, ChevronDown, Check, ArrowUpDown } from 'lucide-react'
 
56
  }, [models, sortBy, sortOrder])
57
 
58
  // Function to fetch detailed model info and set as selected
59
+ const fetchAndSetModelInfo = useCallback(
60
+ async (modelId: string) => {
61
+ try {
62
+ const modelInfoResponse = await getModelInfo(modelId)
 
 
 
 
 
 
 
 
 
 
63
 
64
+ let parameters = 0
65
+ if (modelInfoResponse.safetensors) {
66
+ const safetensors = modelInfoResponse.safetensors
67
+ parameters =
68
+ safetensors.parameters.BF16 ||
69
+ safetensors.parameters.F16 ||
70
+ safetensors.parameters.F32 ||
71
+ safetensors.parameters.total ||
72
+ 0
73
+ }
 
 
 
74
 
75
+ const modelInfo = {
76
+ id: modelId,
77
+ name: modelInfoResponse.id || modelId,
78
+ architecture:
79
+ modelInfoResponse.config?.architectures?.[0] || 'Unknown',
80
+ parameters,
81
+ likes: modelInfoResponse.likes || 0,
82
+ downloads: modelInfoResponse.downloads || 0,
83
+ createdAt: modelInfoResponse.createdAt || '',
84
+ isCompatible: modelInfoResponse.isCompatible,
85
+ incompatibilityReason: modelInfoResponse.incompatibilityReason,
86
+ supportedQuantizations: modelInfoResponse.supportedQuantizations,
87
+ baseId: modelInfoResponse.baseId
88
+ }
89
 
90
+ console.log('Fetched model info:', modelInfoResponse)
91
+
92
+ setModelInfo(modelInfo)
93
+ } catch (error) {
94
+ console.error('Error fetching model info:', error)
95
+ }
96
+ },
97
+ [setModelInfo]
98
+ )
99
 
100
  // Update modelInfo to first model when pipeline changes
101
  useEffect(() => {
 
103
  const firstModel = models[0]
104
  fetchAndSetModelInfo(firstModel.id)
105
  }
106
+ }, [pipeline, models, fetchAndSetModelInfo])
107
 
108
  const handleModelSelect = (modelId: string) => {
109
  fetchAndSetModelInfo(modelId)
 
118
  }
119
  }
120
 
121
+ const selectedModel =
122
+ models.find((model) => model.id === modelInfo.id) || models[0]
123
 
124
  return (
125
  <div className="relative">
126
+ <Listbox
127
+ value={selectedModel}
128
+ onChange={(model) => handleModelSelect(model.id)}
129
+ >
130
  <div className="relative">
131
  <ListboxButton className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent bg-white text-left flex items-center justify-between">
132
  <div className="flex items-center justify-between w-full">
133
  <div className="flex flex-col flex-1 min-w-0">
134
+ <span className="truncate font-medium">
135
+ {modelInfo.id || 'Select a model'}
136
+ </span>
137
  </div>
138
+
139
  <div className="flex items-center space-x-3">
140
+ {selectedModel &&
141
+ (selectedModel.likes > 0 || selectedModel.downloads > 0) && (
142
+ <div className="flex items-center space-x-3 text-xs text-gray-500">
143
+ {selectedModel.likes > 0 && (
144
+ <div className="flex items-center space-x-1">
145
+ <Heart className="w-3 h-3 text-red-500" />
146
+ <span>{formatNumber(selectedModel.likes)}</span>
147
+ </div>
148
+ )}
149
+ {selectedModel.downloads > 0 && (
150
+ <div className="flex items-center space-x-1">
151
+ <Download className="w-3 h-3 text-green-500" />
152
+ <span>{formatNumber(selectedModel.downloads)}</span>
153
+ </div>
154
+ )}
155
+ </div>
156
+ )}
157
  <ChevronDown className="w-4 h-4 ui-open:rotate-180 transition-transform flex-shrink-0" />
158
  </div>
159
  </div>
 
175
  <button
176
  onClick={() => handleSortChange('name')}
177
  className={`px-2 py-1 rounded flex items-center space-x-1 ${
178
+ sortBy === 'name'
179
+ ? 'bg-blue-100 text-blue-700'
180
+ : 'text-gray-600 hover:bg-gray-100'
181
  }`}
182
  >
183
  <span>Name</span>
 
186
  <button
187
  onClick={() => handleSortChange('likes')}
188
  className={`px-2 py-1 rounded flex items-center space-x-1 ${
189
+ sortBy === 'likes'
190
+ ? 'bg-blue-100 text-blue-700'
191
+ : 'text-gray-600 hover:bg-gray-100'
192
  }`}
193
  >
194
  <Heart className="w-3 h-3" />
 
198
  <button
199
  onClick={() => handleSortChange('downloads')}
200
  className={`px-2 py-1 rounded flex items-center space-x-1 ${
201
+ sortBy === 'downloads'
202
+ ? 'bg-blue-100 text-blue-700'
203
+ : 'text-gray-600 hover:bg-gray-100'
204
  }`}
205
  >
206
  <Download className="w-3 h-3" />
207
  <span>Downloads</span>
208
+ {sortBy === 'downloads' && (
209
+ <ArrowUpDown className="w-3 h-3" />
210
+ )}
211
  </button>
212
  <button
213
  onClick={() => handleSortChange('createdAt')}
214
  className={`px-2 py-1 rounded flex items-center space-x-1 ${
215
+ sortBy === 'createdAt'
216
+ ? 'bg-blue-100 text-blue-700'
217
+ : 'text-gray-600 hover:bg-gray-100'
218
  }`}
219
  >
220
  <span>Date</span>
221
+ {sortBy === 'createdAt' && (
222
+ <ArrowUpDown className="w-3 h-3" />
223
+ )}
224
  </button>
225
  </div>
226
  </div>
 
229
  <div className="overflow-auto max-h-48">
230
  {sortedModels.map((model) => {
231
  const hasStats = model.likes > 0 || model.downloads > 0
232
+
233
  return (
234
  <ListboxOption
235
  key={model.id}
 
267
  <span>{formatNumber(model.downloads)}</span>
268
  </div>
269
  )}
270
+
271
  {model.createdAt && (
272
  <span className="text-xs text-gray-400">
273
  {model.createdAt.split('T')[0]}
src/components/PipelineSelector.tsx CHANGED
@@ -42,7 +42,7 @@ const PipelineSelector: React.FC<PipelineSelectorProps> = ({
42
  <div className="relative">
43
  <Listbox value={selectedPipeline} onChange={setPipeline}>
44
  <div className="relative">
45
- <ListboxButton className="relative w-full cursor-default rounded-lg bg-white py-2 pl-3 pr-10 text-left shadow-md focus:outline-none focus-visible:border-indigo-500 focus-visible:ring-2 focus-visible:ring-white focus-visible:ring-opacity-75 focus-visible:ring-offset-2 focus-visible:ring-offset-orange-300 sm:text-sm border border-gray-300">
46
  <span className="block truncate font-medium">
47
  {formatPipelineName(selectedPipeline)}
48
  </span>
@@ -62,7 +62,7 @@ const PipelineSelector: React.FC<PipelineSelectorProps> = ({
62
  leaveFrom="transform scale-100 opacity-100"
63
  leaveTo="transform scale-95 opacity-0"
64
  >
65
- <ListboxOptions className="absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-md bg-white py-1 text-base shadow-lg ring-1 ring-black ring-opacity-5 focus:outline-none sm:text-sm">
66
  {pipelines.map((p) => (
67
  <ListboxOption
68
  key={p}
 
42
  <div className="relative">
43
  <Listbox value={selectedPipeline} onChange={setPipeline}>
44
  <div className="relative">
45
+ <ListboxButton className="relative w-full cursor-default rounded-lg bg-white py-2 pl-3 pr-10 text-left focus:outline-none focus-visible:border-indigo-500 focus-visible:ring-2 focus-visible:ring-white focus-visible:ring-opacity-75 focus-visible:ring-offset-2 focus-visible:ring-offset-orange-300 sm:text-sm border border-gray-300">
46
  <span className="block truncate font-medium">
47
  {formatPipelineName(selectedPipeline)}
48
  </span>
 
62
  leaveFrom="transform scale-100 opacity-100"
63
  leaveTo="transform scale-95 opacity-0"
64
  >
65
+ <ListboxOptions className="absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-md bg-white py-1 text-base ring-1 ring-black ring-opacity-5 focus:outline-none sm:text-sm">
66
  {pipelines.map((p) => (
67
  <ListboxOption
68
  key={p}
src/components/TextClassification.tsx CHANGED
@@ -1,13 +1,11 @@
1
- import { useState, useRef, useEffect, useCallback } from 'react';
2
  import {
3
  ClassificationOutput,
4
  TextClassificationWorkerInput,
5
- WorkerMessage,
6
- } from '../types';
7
- import { useModel } from '../contexts/ModelContext';
8
- import { getModelInfo } from '../lib/huggingface';
9
- import { getWorker } from '../lib/workerManager';
10
-
11
 
12
  const PLACEHOLDER_TEXTS: string[] = [
13
  'I absolutely love this product! It exceeded all my expectations.',
@@ -20,61 +18,31 @@ const PLACEHOLDER_TEXTS: string[] = [
20
  'The product arrived damaged and the return process was a nightmare.',
21
  'Pretty good overall. A few minor issues but mostly positive experience.',
22
  'Outstanding! This company really knows how to treat their customers.'
23
- ].sort(() => Math.random() - 0.5);
24
 
25
  function TextClassification() {
26
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
27
  const [results, setResults] = useState<ClassificationOutput[]>([])
28
- const { setProgress, status, setStatus, modelInfo, setModelInfo, workerLoaded} = useModel()
29
  const workerRef = useRef<Worker | null>(null)
30
 
31
 
32
- useEffect(() => {
33
- if (!modelInfo.id) return;
34
- const fetchModelInfo = async () => {
35
- try {
36
- const modelInfoResponse = await getModelInfo(modelInfo.id)
37
- let parameters = 0
38
- if (modelInfoResponse.safetensors) {
39
- const safetensors = modelInfoResponse.safetensors
40
- parameters =
41
- (safetensors.parameters.F16 ||
42
- safetensors.parameters.F32 ||
43
- safetensors.parameters.total ||
44
- 0)
45
- }
46
- setModelInfo({
47
- ...modelInfo,
48
- architecture: modelInfoResponse.config?.architectures[0] ?? '',
49
- parameters,
50
- likes: modelInfoResponse.likes,
51
- downloads: modelInfoResponse.downloads
52
- })
53
- } catch (error) {
54
- console.error('Error fetching model info:', error)
55
- }
56
- }
57
-
58
- fetchModelInfo()
59
- }, [modelInfo.id, setModelInfo])
60
-
61
  // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
62
  useEffect(() => {
63
- if(!workerRef.current) {
64
  workerRef.current = getWorker('text-classification')
65
  }
66
 
67
-
68
  // Create a callback function for messages from the worker thread.
69
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
70
  const status = e.data.status
71
- if (status === 'output') {
 
 
72
  setStatus('output')
73
  const result = e.data.output!
74
  setResults((prevResults) => [...prevResults, result])
75
  console.log(result)
76
- } else if (status === 'complete') {
77
- setStatus('idle')
78
  } else if (status === 'error') {
79
  setStatus('error')
80
  console.error(e.data.output)
@@ -87,10 +55,10 @@ function TextClassification() {
87
  // Define a cleanup function for when the component is unmounted.
88
  return () =>
89
  workerRef.current?.removeEventListener('message', onMessageReceived)
90
- }, [])
91
 
92
  const classify = useCallback(() => {
93
- setStatus('processing')
94
  setResults([]) // Clear previous results
95
  const message: TextClassificationWorkerInput = {
96
  type: 'classify',
@@ -98,17 +66,16 @@ function TextClassification() {
98
  model: modelInfo.id
99
  }
100
  workerRef.current?.postMessage(message)
101
- }, [text, modelInfo.id])
102
 
103
  const busy: boolean = status !== 'ready'
104
 
105
-
106
  const handleClear = (): void => {
107
  setResults([])
108
  }
109
 
110
  return (
111
- <div className="flex flex-col h-[40vh] max-h-[80vh] w-full p-4">
112
  <h1 className="text-2xl font-bold mb-4">Text Classification</h1>
113
 
114
  <div className="flex flex-col lg:flex-row gap-4 h-full">
@@ -125,14 +92,14 @@ function TextClassification() {
125
  <div className="flex gap-2 mt-4">
126
  <button
127
  className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
128
- disabled={busy || !workerLoaded}
129
  onClick={classify}
130
  >
131
- {workerLoaded ? (!busy
132
- ? 'Classify Text'
133
- : status === 'loading'
134
- ? 'Model loading...'
135
- : 'Processing...') : 'Load model first'}
136
  </button>
137
  <button
138
  className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
@@ -180,4 +147,4 @@ function TextClassification() {
180
  )
181
  }
182
 
183
- export default TextClassification;
 
1
+ import { useState, useRef, useEffect, useCallback } from 'react'
2
  import {
3
  ClassificationOutput,
4
  TextClassificationWorkerInput,
5
+ WorkerMessage
6
+ } from '../types'
7
+ import { useModel } from '../contexts/ModelContext'
8
+ import { getWorker } from '../lib/workerManager'
 
 
9
 
10
  const PLACEHOLDER_TEXTS: string[] = [
11
  'I absolutely love this product! It exceeded all my expectations.',
 
18
  'The product arrived damaged and the return process was a nightmare.',
19
  'Pretty good overall. A few minor issues but mostly positive experience.',
20
  'Outstanding! This company really knows how to treat their customers.'
21
+ ].sort(() => Math.random() - 0.5)
22
 
23
  function TextClassification() {
24
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
25
  const [results, setResults] = useState<ClassificationOutput[]>([])
26
+ const { status, setStatus, modelInfo } = useModel()
27
  const workerRef = useRef<Worker | null>(null)
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
31
  useEffect(() => {
32
+ if (!workerRef.current) {
33
  workerRef.current = getWorker('text-classification')
34
  }
35
 
 
36
  // Create a callback function for messages from the worker thread.
37
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
38
  const status = e.data.status
39
+ if (status === 'ready') {
40
+ setStatus('ready')
41
+ } else if (status === 'output') {
42
  setStatus('output')
43
  const result = e.data.output!
44
  setResults((prevResults) => [...prevResults, result])
45
  console.log(result)
 
 
46
  } else if (status === 'error') {
47
  setStatus('error')
48
  console.error(e.data.output)
 
55
  // Define a cleanup function for when the component is unmounted.
56
  return () =>
57
  workerRef.current?.removeEventListener('message', onMessageReceived)
58
+ }, [setStatus])
59
 
60
  const classify = useCallback(() => {
61
+ setStatus('loading')
62
  setResults([]) // Clear previous results
63
  const message: TextClassificationWorkerInput = {
64
  type: 'classify',
 
66
  model: modelInfo.id
67
  }
68
  workerRef.current?.postMessage(message)
69
+ }, [text, modelInfo.id, setStatus])
70
 
71
  const busy: boolean = status !== 'ready'
72
 
 
73
  const handleClear = (): void => {
74
  setResults([])
75
  }
76
 
77
  return (
78
+ <div className="flex flex-col h-[60vh] max-h-[100vh] w-full p-4">
79
  <h1 className="text-2xl font-bold mb-4">Text Classification</h1>
80
 
81
  <div className="flex flex-col lg:flex-row gap-4 h-full">
 
92
  <div className="flex gap-2 mt-4">
93
  <button
94
  className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
95
+ disabled={busy}
96
  onClick={classify}
97
  >
98
+ {status === 'ready'
99
+ ? !busy
100
+ ? 'Classify Text'
101
+ : 'Processing...'
102
+ : 'Load model first'}
103
  </button>
104
  <button
105
  className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
 
147
  )
148
  }
149
 
150
+ export default TextClassification
src/components/ZeroShotClassification.tsx CHANGED
@@ -4,10 +4,8 @@ import {
4
  Section,
5
  WorkerMessage,
6
  ZeroShotWorkerInput,
7
- ModelInfo
8
  } from '../types'
9
  import { useModel } from '../contexts/ModelContext'
10
- import { getModelInfo } from '../lib/huggingface'
11
 
12
  const PLACEHOLDER_REVIEWS: string[] = [
13
  // battery/charging problems
@@ -51,7 +49,7 @@ function ZeroShotClassification() {
51
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
52
  )
53
 
54
- const { setProgress, status, setStatus, modelInfo, setModelInfo } = useModel()
55
 
56
  // Create a reference to the worker object.
57
  const worker = useRef<Worker | null>(null)
@@ -72,17 +70,8 @@ function ZeroShotClassification() {
72
  // Create a callback function for messages from the worker thread.
73
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
74
  const status = e.data.status
75
- if (status === 'initiate') {
76
- setStatus('loading')
77
- } else if (status === 'ready') {
78
  setStatus('ready')
79
- } else if (status === 'progress') {
80
- setStatus('progress')
81
- if (
82
- e.data.output.progress &&
83
- (e.data.output.file as string).startsWith('onnx')
84
- )
85
- setProgress(e.data.output.progress)
86
  } else if (status === 'output') {
87
  setStatus('output')
88
  const { sequence, labels, scores } = e.data.output!
@@ -100,9 +89,6 @@ function ZeroShotClassification() {
100
  }
101
  return newSections
102
  })
103
- } else if (status === 'complete') {
104
- setStatus('idle')
105
- setProgress(100)
106
  } else if (status === 'error') {
107
  setStatus('error')
108
  console.error(e.data.output)
@@ -118,7 +104,7 @@ function ZeroShotClassification() {
118
  }, [sections])
119
 
120
  const classify = useCallback(() => {
121
- setStatus('processing')
122
  const message: ZeroShotWorkerInput = {
123
  text,
124
  labels: sections
@@ -129,7 +115,7 @@ function ZeroShotClassification() {
129
  worker.current?.postMessage(message)
130
  }, [text, sections, modelInfo.name])
131
 
132
- const busy: boolean = status !== 'idle'
133
 
134
  const handleAddCategory = (): void => {
135
  setSections((sections) => {
 
4
  Section,
5
  WorkerMessage,
6
  ZeroShotWorkerInput,
 
7
  } from '../types'
8
  import { useModel } from '../contexts/ModelContext'
 
9
 
10
  const PLACEHOLDER_REVIEWS: string[] = [
11
  // battery/charging problems
 
49
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
50
  )
51
 
52
+ const { status, setStatus, modelInfo } = useModel()
53
 
54
  // Create a reference to the worker object.
55
  const worker = useRef<Worker | null>(null)
 
70
  // Create a callback function for messages from the worker thread.
71
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
72
  const status = e.data.status
73
+ if (status === 'ready') {
 
 
74
  setStatus('ready')
 
 
 
 
 
 
 
75
  } else if (status === 'output') {
76
  setStatus('output')
77
  const { sequence, labels, scores } = e.data.output!
 
89
  }
90
  return newSections
91
  })
 
 
 
92
  } else if (status === 'error') {
93
  setStatus('error')
94
  console.error(e.data.output)
 
104
  }, [sections])
105
 
106
  const classify = useCallback(() => {
107
+ setStatus('loading')
108
  const message: ZeroShotWorkerInput = {
109
  text,
110
  labels: sections
 
115
  worker.current?.postMessage(message)
116
  }, [text, sections, modelInfo.name])
117
 
118
+ const busy: boolean = status !== 'ready'
119
 
120
  const handleAddCategory = (): void => {
121
  setSections((sections) => {
src/contexts/ModelContext.tsx CHANGED
@@ -1,11 +1,21 @@
1
- import React, { createContext, RefObject, useContext, useEffect, useRef, useState } from 'react'
2
- import { ModelInfo, ModelInfoResponse, QuantizationType } from '../types'
 
 
 
 
 
 
 
 
 
 
3
 
4
  interface ModelContextType {
 
 
5
  progress: number
6
- status: string
7
  setProgress: (progress: number) => void
8
- setStatus: (status: string) => void
9
  modelInfo: ModelInfo
10
  setModelInfo: (model: ModelInfo) => void
11
  pipeline: string
@@ -16,22 +26,21 @@ interface ModelContextType {
16
  setSelectedQuantization: (quantization: QuantizationType) => void
17
  activeWorker: Worker | null
18
  setActiveWorker: (worker: Worker | null) => void
19
- workerLoaded: boolean
20
- setWorkerLoaded: (workerLoaded: boolean) => void
21
  }
22
 
23
  const ModelContext = createContext<ModelContextType | undefined>(undefined)
24
 
25
  export function ModelProvider({ children }: { children: React.ReactNode }) {
26
  const [progress, setProgress] = useState<number>(0)
27
- const [status, setStatus] = useState<string>('idle')
28
  const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
29
- const [models, setModels] = useState<ModelInfoResponse[]>([] as ModelInfoResponse[])
 
 
30
  const [pipeline, setPipeline] = useState<string>('text-classification')
31
- const [selectedQuantization, setSelectedQuantization] = useState<QuantizationType>('int8')
 
32
  const [activeWorker, setActiveWorker] = useState<Worker | null>(null)
33
- const [workerLoaded, setWorkerLoaded] = useState<boolean>(false)
34
-
35
 
36
  // set progress to 0 when model is changed
37
  useEffect(() => {
@@ -54,9 +63,7 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
54
  selectedQuantization,
55
  setSelectedQuantization,
56
  activeWorker,
57
- setActiveWorker,
58
- workerLoaded,
59
- setWorkerLoaded
60
  }}
61
  >
62
  {children}
 
1
+ import React, {
2
+ createContext,
3
+ useContext,
4
+ useEffect,
5
+ useState
6
+ } from 'react'
7
+ import {
8
+ ModelInfo,
9
+ ModelInfoResponse,
10
+ QuantizationType,
11
+ WorkerStatus
12
+ } from '../types'
13
 
14
  interface ModelContextType {
15
+ status: WorkerStatus
16
+ setStatus: (status: WorkerStatus) => void
17
  progress: number
 
18
  setProgress: (progress: number) => void
 
19
  modelInfo: ModelInfo
20
  setModelInfo: (model: ModelInfo) => void
21
  pipeline: string
 
26
  setSelectedQuantization: (quantization: QuantizationType) => void
27
  activeWorker: Worker | null
28
  setActiveWorker: (worker: Worker | null) => void
 
 
29
  }
30
 
31
  const ModelContext = createContext<ModelContextType | undefined>(undefined)
32
 
33
  export function ModelProvider({ children }: { children: React.ReactNode }) {
34
  const [progress, setProgress] = useState<number>(0)
35
+ const [status, setStatus] = useState<WorkerStatus>('initiate')
36
  const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
37
+ const [models, setModels] = useState<ModelInfoResponse[]>(
38
+ [] as ModelInfoResponse[]
39
+ )
40
  const [pipeline, setPipeline] = useState<string>('text-classification')
41
+ const [selectedQuantization, setSelectedQuantization] =
42
+ useState<QuantizationType>('int8')
43
  const [activeWorker, setActiveWorker] = useState<Worker | null>(null)
 
 
44
 
45
  // set progress to 0 when model is changed
46
  useEffect(() => {
 
63
  selectedQuantization,
64
  setSelectedQuantization,
65
  activeWorker,
66
+ setActiveWorker
 
 
67
  }}
68
  >
69
  {children}
src/lib/huggingface.ts CHANGED
@@ -114,7 +114,14 @@ const getModelsByPipeline = async (
114
  }
115
  const models = await response.json()
116
  if (pipeline_tag === 'text-classification') {
117
- return models.filter((model: ModelInfoResponse) => !model.tags.includes('reranker') && !model.id.includes('reranker')).slice(0, 10)
 
 
 
 
 
 
 
118
  }
119
  return models.slice(0, 10)
120
  }
 
114
  }
115
  const models = await response.json()
116
  if (pipeline_tag === 'text-classification') {
117
+ return models
118
+ .filter(
119
+ (model: ModelInfoResponse) =>
120
+ !model.tags.includes('reranker') &&
121
+ !model.id.includes('reranker') &&
122
+ !model.tags.includes('sentence-transformers')
123
+ )
124
+ .slice(0, 10)
125
  }
126
  return models.slice(0, 10)
127
  }
src/types.ts CHANGED
@@ -9,8 +9,10 @@ export interface ClassificationOutput {
9
  scores: number[]
10
  }
11
 
 
 
12
  export interface WorkerMessage {
13
- status: 'initiate' | 'ready' | 'output' | 'complete' | 'progress' | 'error'
14
  progress?: number
15
  error?: string
16
  output?: any
@@ -28,7 +30,6 @@ export interface TextClassificationWorkerInput {
28
  model: string
29
  }
30
 
31
- export type AppStatus = 'idle' | 'loading' | 'processing'
32
 
33
  type q8 = 'q8' | 'int8' | 'bnb8' | 'uint8'
34
  type q4 = 'q4' | 'bnb4'
 
9
  scores: number[]
10
  }
11
 
12
+ export type WorkerStatus = 'initiate' | 'ready' | 'output' | 'loading' | 'error'
13
+
14
  export interface WorkerMessage {
15
+ status: WorkerStatus
16
  progress?: number
17
  error?: string
18
  output?: any
 
30
  model: string
31
  }
32
 
 
33
 
34
  type q8 = 'q8' | 'int8' | 'bnb8' | 'uint8'
35
  type q4 = 'q4' | 'bnb4'