wip: refactor model loading and classification components
Browse files- public/workers/text-classification.js +3 -3
- src/App.tsx +23 -68
- src/components/ModelInfo.tsx +55 -148
- src/components/ModelLoader.tsx +144 -0
- src/components/ModelSelector.tsx +93 -64
- src/components/PipelineSelector.tsx +2 -2
- src/components/TextClassification.tsx +22 -55
- src/components/ZeroShotClassification.tsx +4 -18
- src/contexts/ModelContext.tsx +21 -14
- src/lib/huggingface.ts +8 -1
- src/types.ts +3 -2
public/workers/text-classification.js
CHANGED
@@ -29,7 +29,7 @@ self.addEventListener('message', async (event) => {
|
|
29 |
const classifier = await MyTextClassificationPipeline.getInstance(
|
30 |
model,
|
31 |
(x) => {
|
32 |
-
self.postMessage({ status: '
|
33 |
}
|
34 |
)
|
35 |
|
@@ -40,7 +40,7 @@ self.addEventListener('message', async (event) => {
|
|
40 |
|
41 |
if (type === 'classify') {
|
42 |
if (!text) {
|
43 |
-
self.postMessage({ status: '
|
44 |
return
|
45 |
}
|
46 |
const split = text.split('\n')
|
@@ -57,6 +57,6 @@ self.addEventListener('message', async (event) => {
|
|
57 |
})
|
58 |
}
|
59 |
}
|
60 |
-
self.postMessage({ status: '
|
61 |
}
|
62 |
})
|
|
|
29 |
const classifier = await MyTextClassificationPipeline.getInstance(
|
30 |
model,
|
31 |
(x) => {
|
32 |
+
self.postMessage({ status: 'loading', output: x })
|
33 |
}
|
34 |
)
|
35 |
|
|
|
40 |
|
41 |
if (type === 'classify') {
|
42 |
if (!text) {
|
43 |
+
self.postMessage({ status: 'ready' }) // Nothing to process
|
44 |
return
|
45 |
}
|
46 |
const split = text.split('\n')
|
|
|
57 |
})
|
58 |
}
|
59 |
}
|
60 |
+
self.postMessage({ status: 'ready' })
|
61 |
}
|
62 |
})
|
src/App.tsx
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import { useEffect
|
2 |
import PipelineSelector from './components/PipelineSelector'
|
3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
4 |
import TextClassification from './components/TextClassification'
|
@@ -10,14 +10,12 @@ import ModelSelector from './components/ModelSelector'
|
|
10 |
import ModelInfo from './components/ModelInfo'
|
11 |
|
12 |
function App() {
|
13 |
-
const { pipeline, setPipeline,
|
14 |
-
useModel()
|
15 |
|
16 |
useEffect(() => {
|
17 |
const fetchModels = async () => {
|
18 |
const fetchedModels = await getModelsByPipeline(pipeline)
|
19 |
setModels(fetchedModels)
|
20 |
-
console.log(fetchedModels)
|
21 |
}
|
22 |
fetchModels()
|
23 |
}, [setModels, pipeline])
|
@@ -26,76 +24,34 @@ function App() {
|
|
26 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
27 |
<Header />
|
28 |
|
29 |
-
<main className="max-w-
|
30 |
-
{/* Pipeline Selection */}
|
31 |
<div className="mb-8">
|
32 |
-
<div className="bg-white rounded-lg
|
33 |
-
<div className="flex items-
|
34 |
-
<
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
pipeline={pipeline}
|
44 |
-
setPipeline={setPipeline}
|
45 |
-
/>
|
46 |
</div>
|
47 |
|
48 |
-
<div className="
|
49 |
-
<
|
50 |
-
Select Model
|
51 |
-
</span>
|
52 |
-
<ModelSelector />
|
53 |
</div>
|
54 |
</div>
|
55 |
|
56 |
-
{/* Model Loading Progress */}
|
57 |
-
{status === 'progress' && (
|
58 |
-
<div className="mt-4 p-4 bg-blue-50 rounded-lg">
|
59 |
-
<div className="flex items-center space-x-3">
|
60 |
-
<div className="flex-shrink-0">
|
61 |
-
<svg
|
62 |
-
className="animate-spin h-5 w-5 text-blue-500"
|
63 |
-
fill="none"
|
64 |
-
viewBox="0 0 24 24"
|
65 |
-
>
|
66 |
-
<circle
|
67 |
-
className="opacity-25"
|
68 |
-
cx="12"
|
69 |
-
cy="12"
|
70 |
-
r="10"
|
71 |
-
stroke="currentColor"
|
72 |
-
strokeWidth="4"
|
73 |
-
></circle>
|
74 |
-
<path
|
75 |
-
className="opacity-75"
|
76 |
-
fill="currentColor"
|
77 |
-
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
78 |
-
></path>
|
79 |
-
</svg>
|
80 |
-
</div>
|
81 |
-
<div className="flex-1">
|
82 |
-
<p className="text-sm font-medium text-blue-900">
|
83 |
-
Loading Model...
|
84 |
-
</p>
|
85 |
-
<div className="mt-2 bg-blue-200 rounded-full h-2">
|
86 |
-
<div
|
87 |
-
className="bg-blue-500 h-2 rounded-full transition-all duration-300"
|
88 |
-
style={{ width: `${progress.toFixed(2)}%` }}
|
89 |
-
></div>
|
90 |
-
</div>
|
91 |
-
<p className="text-xs text-blue-700 mt-1">
|
92 |
-
{progress.toFixed(2)}%
|
93 |
-
</p>
|
94 |
-
</div>
|
95 |
-
</div>
|
96 |
-
</div>
|
97 |
-
)}
|
98 |
-
|
99 |
{/* Pipeline Description */}
|
100 |
<div className="mt-4 p-4 bg-gray-50 rounded-lg">
|
101 |
<div className="flex items-start space-x-3">
|
@@ -129,7 +85,6 @@ function App() {
|
|
129 |
</div>
|
130 |
</div>
|
131 |
|
132 |
-
{/* Pipeline Component */}
|
133 |
<div className="bg-white rounded-lg shadow-sm border overflow-hidden">
|
134 |
{pipeline === 'zero-shot-classification' && (
|
135 |
<ZeroShotClassification />
|
|
|
1 |
+
import { useEffect } from 'react'
|
2 |
import PipelineSelector from './components/PipelineSelector'
|
3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
4 |
import TextClassification from './components/TextClassification'
|
|
|
10 |
import ModelInfo from './components/ModelInfo'
|
11 |
|
12 |
function App() {
|
13 |
+
const { pipeline, setPipeline, setModels } = useModel()
|
|
|
14 |
|
15 |
useEffect(() => {
|
16 |
const fetchModels = async () => {
|
17 |
const fetchedModels = await getModelsByPipeline(pipeline)
|
18 |
setModels(fetchedModels)
|
|
|
19 |
}
|
20 |
fetchModels()
|
21 |
}, [setModels, pipeline])
|
|
|
24 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
25 |
<Header />
|
26 |
|
27 |
+
<main className="max-w-6xl mx-auto px-4 sm:px-6 lg:px-8 py-8">
|
|
|
28 |
<div className="mb-8">
|
29 |
+
<div className="bg-white rounded-lg border p-6">
|
30 |
+
<div className="flex items-start justify-between max-w-6xl mx-auto">
|
31 |
+
<div className="space-y-2 flex-1">
|
32 |
+
<div className="space-y-2">
|
33 |
+
<span className="text-lg font-semibold text-gray-900 block">
|
34 |
+
Choose a Pipeline
|
35 |
+
</span>
|
36 |
+
<PipelineSelector
|
37 |
+
pipeline={pipeline}
|
38 |
+
setPipeline={setPipeline}
|
39 |
+
/>
|
40 |
+
</div>
|
41 |
|
42 |
+
<div className="space-y-2">
|
43 |
+
<span className="text-lg font-semibold text-gray-900 block">
|
44 |
+
Select Model
|
45 |
+
</span>
|
46 |
+
<ModelSelector />
|
47 |
+
</div>
|
|
|
|
|
|
|
48 |
</div>
|
49 |
|
50 |
+
<div className="ml-6">
|
51 |
+
<ModelInfo />
|
|
|
|
|
|
|
52 |
</div>
|
53 |
</div>
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
{/* Pipeline Description */}
|
56 |
<div className="mt-4 p-4 bg-gray-50 rounded-lg">
|
57 |
<div className="flex items-start space-x-3">
|
|
|
85 |
</div>
|
86 |
</div>
|
87 |
|
|
|
88 |
<div className="bg-white rounded-lg shadow-sm border overflow-hidden">
|
89 |
{pipeline === 'zero-shot-classification' && (
|
90 |
<ZeroShotClassification />
|
src/components/ModelInfo.tsx
CHANGED
@@ -6,14 +6,11 @@ import {
|
|
6 |
DatabaseIcon,
|
7 |
CheckCircle,
|
8 |
XCircle,
|
9 |
-
ExternalLink
|
10 |
-
ChevronDown
|
11 |
} from 'lucide-react'
|
12 |
import { getModelSize } from '../lib/huggingface'
|
13 |
import { useModel } from '../contexts/ModelContext'
|
14 |
-
import
|
15 |
-
import { QuantizationType, WorkerMessage } from '../types'
|
16 |
-
import { getWorker } from '../lib/workerManager'
|
17 |
|
18 |
const ModelInfo = () => {
|
19 |
const formatNumber = (num: number) => {
|
@@ -29,96 +26,50 @@ const ModelInfo = () => {
|
|
29 |
|
30 |
const {
|
31 |
modelInfo,
|
32 |
-
selectedQuantization
|
33 |
-
setSelectedQuantization,
|
34 |
-
status,
|
35 |
-
setStatus,
|
36 |
-
setProgress,
|
37 |
-
activeWorker,
|
38 |
-
setActiveWorker,
|
39 |
-
pipeline,
|
40 |
-
workerLoaded,
|
41 |
-
setWorkerLoaded
|
42 |
} = useModel()
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
} else if (quantizations.includes('q8')) {
|
52 |
-
defaultQuant = 'q8'
|
53 |
-
} else if (quantizations.includes('q4')) {
|
54 |
-
defaultQuant = 'q4'
|
55 |
-
}
|
56 |
-
|
57 |
-
setSelectedQuantization(defaultQuant)
|
58 |
-
}
|
59 |
-
}, [
|
60 |
-
modelInfo.supportedQuantizations,
|
61 |
-
modelInfo.isCompatible,
|
62 |
-
setSelectedQuantization
|
63 |
-
])
|
64 |
-
|
65 |
-
useEffect(() => {
|
66 |
-
const newWorker = getWorker(pipeline)
|
67 |
-
if (!newWorker) {
|
68 |
-
return
|
69 |
-
}
|
70 |
-
|
71 |
-
setStatus('idle')
|
72 |
-
setWorkerLoaded(false)
|
73 |
-
setActiveWorker(newWorker)
|
74 |
-
|
75 |
-
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
76 |
-
const { status, output } = e.data
|
77 |
-
if (status === 'initiate') {
|
78 |
-
setStatus('loading')
|
79 |
-
} else if (status === 'ready') {
|
80 |
-
setStatus('ready')
|
81 |
-
setWorkerLoaded(true)
|
82 |
-
} else if (status === 'progress' && output) {
|
83 |
-
setStatus('progress')
|
84 |
-
if (
|
85 |
-
output.progress &&
|
86 |
-
typeof output.file === 'string' &&
|
87 |
-
output.file.startsWith('onnx')
|
88 |
-
) {
|
89 |
-
setProgress(output.progress)
|
90 |
-
}
|
91 |
-
}
|
92 |
-
}
|
93 |
-
|
94 |
-
newWorker.addEventListener('message', onMessageReceived)
|
95 |
-
|
96 |
-
return () => {
|
97 |
-
newWorker.removeEventListener('message', onMessageReceived)
|
98 |
-
// terminateWorker(pipeline);
|
99 |
-
}
|
100 |
-
}, [pipeline, selectedQuantization, setActiveWorker, setStatus, setProgress, setWorkerLoaded])
|
101 |
-
|
102 |
-
const loadModel = useCallback(() => {
|
103 |
-
if (!modelInfo.name || !selectedQuantization) return
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
model: modelInfo.name,
|
109 |
-
quantization: selectedQuantization
|
110 |
-
}
|
111 |
-
activeWorker?.postMessage(message)
|
112 |
-
}, [modelInfo.name, selectedQuantization, setStatus, activeWorker])
|
113 |
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
if (!modelInfo.name) {
|
117 |
-
return
|
118 |
}
|
119 |
|
120 |
return (
|
121 |
-
<div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-3">
|
122 |
{/* Model Name Row */}
|
123 |
<div className="flex items-center space-x-2">
|
124 |
<Bot className="w-4 h-4 text-blue-600" />
|
@@ -150,16 +101,16 @@ const ModelInfo = () => {
|
|
150 |
|
151 |
{/* Base Model Link */}
|
152 |
{modelInfo.baseId && (
|
153 |
-
<div className="flex items-center space-x-2 ml-6">
|
154 |
<a
|
155 |
href={`https://huggingface.co/${modelInfo.baseId}`}
|
156 |
target="_blank"
|
157 |
rel="noopener noreferrer"
|
158 |
-
className="
|
159 |
title={`Base model: ${modelInfo.baseId}`}
|
160 |
>
|
161 |
<ExternalLink className="w-3 h-3 inline-block mr-1" />
|
162 |
-
{modelInfo.baseId}
|
163 |
</a>
|
164 |
</div>
|
165 |
)}
|
@@ -180,75 +131,31 @@ const ModelInfo = () => {
|
|
180 |
</div>
|
181 |
)}
|
182 |
|
183 |
-
|
184 |
-
<
|
185 |
-
|
186 |
<span>{formatNumber(modelInfo.parameters)}</span>
|
187 |
-
|
188 |
-
|
|
|
|
|
189 |
|
190 |
-
|
191 |
-
<
|
192 |
-
|
193 |
<span>
|
194 |
{`~${getModelSize(
|
195 |
modelInfo.parameters,
|
196 |
selectedQuantization
|
197 |
).toFixed(1)}MB`}
|
198 |
</span>
|
199 |
-
|
200 |
-
|
|
|
|
|
201 |
</div>
|
202 |
|
203 |
-
|
204 |
-
{modelInfo.isCompatible &&
|
205 |
-
modelInfo.supportedQuantizations.length > 0 && (
|
206 |
-
<hr className="border-gray-200" />
|
207 |
-
)}
|
208 |
-
|
209 |
-
{/* Quantization Dropdown */}
|
210 |
-
{modelInfo.isCompatible &&
|
211 |
-
modelInfo.supportedQuantizations.length > 0 && (
|
212 |
-
<div className="flex items-center space-x-2">
|
213 |
-
<span className="text-xs text-gray-600 font-medium">
|
214 |
-
Quantization:
|
215 |
-
</span>
|
216 |
-
<div className="relative">
|
217 |
-
<select
|
218 |
-
value={selectedQuantization || ''}
|
219 |
-
onChange={(e) =>
|
220 |
-
setSelectedQuantization(e.target.value as QuantizationType)
|
221 |
-
}
|
222 |
-
className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
|
223 |
-
>
|
224 |
-
<option value="">Select quantization</option>
|
225 |
-
{modelInfo.supportedQuantizations.map((quant) => (
|
226 |
-
<option key={quant} value={quant}>
|
227 |
-
{quant}
|
228 |
-
</option>
|
229 |
-
))}
|
230 |
-
</select>
|
231 |
-
<ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
|
232 |
-
</div>
|
233 |
-
</div>
|
234 |
-
)}
|
235 |
-
|
236 |
-
{/* Load Model Button */}
|
237 |
-
{modelInfo.isCompatible && selectedQuantization && (
|
238 |
-
<div className="flex justify-center">
|
239 |
-
<button
|
240 |
-
className="py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm"
|
241 |
-
disabled={busy || !selectedQuantization || workerLoaded}
|
242 |
-
onClick={loadModel}
|
243 |
-
>
|
244 |
-
{status === 'loading'
|
245 |
-
? 'Loading Model...'
|
246 |
-
: workerLoaded
|
247 |
-
? 'Model Ready'
|
248 |
-
: 'Load Model'}
|
249 |
-
</button>
|
250 |
-
</div>
|
251 |
-
)}
|
252 |
|
253 |
{/* Incompatibility Message */}
|
254 |
{modelInfo.isCompatible === false && modelInfo.incompatibilityReason && (
|
|
|
6 |
DatabaseIcon,
|
7 |
CheckCircle,
|
8 |
XCircle,
|
9 |
+
ExternalLink
|
|
|
10 |
} from 'lucide-react'
|
11 |
import { getModelSize } from '../lib/huggingface'
|
12 |
import { useModel } from '../contexts/ModelContext'
|
13 |
+
import ModelLoader from './ModelLoader'
|
|
|
|
|
14 |
|
15 |
const ModelInfo = () => {
|
16 |
const formatNumber = (num: number) => {
|
|
|
26 |
|
27 |
const {
|
28 |
modelInfo,
|
29 |
+
selectedQuantization
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
} = useModel()
|
31 |
|
32 |
+
const ModelInfoSkeleton = () => (
|
33 |
+
<div className="mt-5 bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-4 h-full min-h-[160px] animate-pulse w-[400px]">
|
34 |
+
<div className="flex items-center space-x-2">
|
35 |
+
<Bot className="w-4 h-4 text-blue-300" />
|
36 |
+
<div className="h-4 bg-gray-300 rounded w-48"></div>
|
37 |
+
<div className="w-4 h-4 bg-gray-300 rounded-full"></div>
|
38 |
+
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
+
<div className="flex items-center space-x-2 ml-6">
|
41 |
+
<div className="h-3 bg-gray-200 rounded w-32"></div>
|
42 |
+
</div>
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
+
<div className="flex items-center justify-self-end space-x-4">
|
45 |
+
<div className="flex items-center space-x-1">
|
46 |
+
<Heart className="w-3 h-3 text-red-300" />
|
47 |
+
<div className="h-3 bg-gray-200 rounded w-8"></div>
|
48 |
+
</div>
|
49 |
+
<div className="flex items-center space-x-1">
|
50 |
+
<Download className="w-3 h-3 text-green-300" />
|
51 |
+
<div className="h-3 bg-gray-200 rounded w-8"></div>
|
52 |
+
</div>
|
53 |
+
<div className="flex items-center space-x-1">
|
54 |
+
<Cpu className="w-3 h-3 text-purple-300" />
|
55 |
+
<div className="h-3 bg-gray-200 rounded w-8"></div>
|
56 |
+
</div>
|
57 |
+
<div className="flex items-center space-x-1">
|
58 |
+
<DatabaseIcon className="w-3 h-3 text-purple-300" />
|
59 |
+
<div className="h-3 bg-gray-200 rounded w-12"></div>
|
60 |
+
</div>
|
61 |
+
</div>
|
62 |
+
<hr className="border-gray-200" />
|
63 |
+
<div className="h-8 bg-gray-200 rounded w-full"></div>
|
64 |
+
</div>
|
65 |
+
)
|
66 |
|
67 |
if (!modelInfo.name) {
|
68 |
+
return <ModelInfoSkeleton />
|
69 |
}
|
70 |
|
71 |
return (
|
72 |
+
<div className="mt-5 bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-3 h-full min-h-[150px]">
|
73 |
{/* Model Name Row */}
|
74 |
<div className="flex items-center space-x-2">
|
75 |
<Bot className="w-4 h-4 text-blue-600" />
|
|
|
101 |
|
102 |
{/* Base Model Link */}
|
103 |
{modelInfo.baseId && (
|
104 |
+
<div className="flex items-center space-x-2 ml-6 text-xs text-gray-600 truncate max-w-100">
|
105 |
<a
|
106 |
href={`https://huggingface.co/${modelInfo.baseId}`}
|
107 |
target="_blank"
|
108 |
rel="noopener noreferrer"
|
109 |
+
className=" hover:underline"
|
110 |
title={`Base model: ${modelInfo.baseId}`}
|
111 |
>
|
112 |
<ExternalLink className="w-3 h-3 inline-block mr-1" />
|
113 |
+
({modelInfo.baseId})
|
114 |
</a>
|
115 |
</div>
|
116 |
)}
|
|
|
131 |
</div>
|
132 |
)}
|
133 |
|
134 |
+
<div className="flex items-center space-x-1">
|
135 |
+
<Cpu className="w-3 h-3 text-purple-500" />
|
136 |
+
{modelInfo.parameters ? (
|
137 |
<span>{formatNumber(modelInfo.parameters)}</span>
|
138 |
+
) : (
|
139 |
+
<span>?</span>
|
140 |
+
)}
|
141 |
+
</div>
|
142 |
|
143 |
+
<div className="flex items-center space-x-1">
|
144 |
+
<DatabaseIcon className="w-3 h-3 text-purple-500" />
|
145 |
+
{modelInfo.parameters ? (
|
146 |
<span>
|
147 |
{`~${getModelSize(
|
148 |
modelInfo.parameters,
|
149 |
selectedQuantization
|
150 |
).toFixed(1)}MB`}
|
151 |
</span>
|
152 |
+
) : (
|
153 |
+
<span>?</span>
|
154 |
+
)}
|
155 |
+
</div>
|
156 |
</div>
|
157 |
|
158 |
+
<ModelLoader />
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
{/* Incompatibility Message */}
|
161 |
{modelInfo.isCompatible === false && modelInfo.incompatibilityReason && (
|
src/components/ModelLoader.tsx
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { useEffect, useCallback } from 'react'
|
2 |
+
import { ChevronDown, Loader } from 'lucide-react'
|
3 |
+
import { QuantizationType, WorkerMessage } from '../types'
|
4 |
+
import { useModel } from '../contexts/ModelContext'
|
5 |
+
import { getWorker } from '../lib/workerManager'
|
6 |
+
|
7 |
+
const ModelLoader = () => {
|
8 |
+
const {
|
9 |
+
modelInfo,
|
10 |
+
selectedQuantization,
|
11 |
+
setSelectedQuantization,
|
12 |
+
status,
|
13 |
+
progress,
|
14 |
+
setStatus,
|
15 |
+
setProgress,
|
16 |
+
activeWorker,
|
17 |
+
setActiveWorker,
|
18 |
+
pipeline
|
19 |
+
} = useModel()
|
20 |
+
|
21 |
+
useEffect(() => {
|
22 |
+
if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
|
23 |
+
const quantizations = modelInfo.supportedQuantizations
|
24 |
+
let defaultQuant: QuantizationType = 'fp32'
|
25 |
+
|
26 |
+
if (quantizations.includes('int8')) {
|
27 |
+
defaultQuant = 'int8'
|
28 |
+
} else if (quantizations.includes('q8')) {
|
29 |
+
defaultQuant = 'q8'
|
30 |
+
} else if (quantizations.includes('q4')) {
|
31 |
+
defaultQuant = 'q4'
|
32 |
+
}
|
33 |
+
|
34 |
+
setSelectedQuantization(defaultQuant)
|
35 |
+
}
|
36 |
+
}, [
|
37 |
+
modelInfo.supportedQuantizations,
|
38 |
+
modelInfo.isCompatible,
|
39 |
+
setSelectedQuantization
|
40 |
+
])
|
41 |
+
|
42 |
+
useEffect(() => {
|
43 |
+
const newWorker = getWorker(pipeline)
|
44 |
+
if (!newWorker) {
|
45 |
+
return
|
46 |
+
}
|
47 |
+
|
48 |
+
setStatus('initiate')
|
49 |
+
setActiveWorker(newWorker)
|
50 |
+
|
51 |
+
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
52 |
+
const { status, output } = e.data
|
53 |
+
if (status === 'ready') {
|
54 |
+
setStatus('ready')
|
55 |
+
} else if (status === 'loading' && output) {
|
56 |
+
setStatus('loading')
|
57 |
+
if (
|
58 |
+
output.progress &&
|
59 |
+
typeof output.file === 'string' &&
|
60 |
+
output.file.startsWith('onnx')
|
61 |
+
) {
|
62 |
+
setProgress(output.progress)
|
63 |
+
}
|
64 |
+
}
|
65 |
+
}
|
66 |
+
|
67 |
+
newWorker.addEventListener('message', onMessageReceived)
|
68 |
+
|
69 |
+
return () => {
|
70 |
+
newWorker.removeEventListener('message', onMessageReceived)
|
71 |
+
// terminateWorker(pipeline);
|
72 |
+
}
|
73 |
+
}, [pipeline, modelInfo.name, selectedQuantization, setActiveWorker, setStatus, setProgress])
|
74 |
+
|
75 |
+
const loadModel = useCallback(() => {
|
76 |
+
if (!modelInfo.name || !selectedQuantization) return
|
77 |
+
|
78 |
+
setStatus('loading')
|
79 |
+
const message = {
|
80 |
+
type: 'load',
|
81 |
+
model: modelInfo.name,
|
82 |
+
quantization: selectedQuantization
|
83 |
+
}
|
84 |
+
activeWorker?.postMessage(message)
|
85 |
+
}, [modelInfo.name, selectedQuantization, setStatus, activeWorker])
|
86 |
+
|
87 |
+
const ready: boolean = status === 'ready'
|
88 |
+
const busy: boolean = status === 'loading'
|
89 |
+
|
90 |
+
if (!modelInfo.isCompatible || modelInfo.supportedQuantizations.length === 0) {
|
91 |
+
return null
|
92 |
+
}
|
93 |
+
|
94 |
+
return (
|
95 |
+
<div className="space-y-3">
|
96 |
+
<hr className="border-gray-200" />
|
97 |
+
|
98 |
+
<div className="flex items-center justify-between space-x-4">
|
99 |
+
<div className="flex items-center space-x-2">
|
100 |
+
<span className="text-xs text-gray-600 font-medium">
|
101 |
+
Quantization:
|
102 |
+
</span>
|
103 |
+
<div className="relative">
|
104 |
+
<select
|
105 |
+
value={selectedQuantization || ''}
|
106 |
+
onChange={(e) =>
|
107 |
+
setSelectedQuantization(e.target.value as QuantizationType)
|
108 |
+
}
|
109 |
+
className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
|
110 |
+
>
|
111 |
+
<option value="">Select quantization</option>
|
112 |
+
{modelInfo.supportedQuantizations.map((quant) => (
|
113 |
+
<option key={quant} value={quant}>
|
114 |
+
{quant}
|
115 |
+
</option>
|
116 |
+
))}
|
117 |
+
</select>
|
118 |
+
<ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
|
119 |
+
</div>
|
120 |
+
</div>
|
121 |
+
|
122 |
+
{selectedQuantization && (
|
123 |
+
<div className="flex justify-center">
|
124 |
+
<button
|
125 |
+
className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
|
126 |
+
disabled={(busy && !ready) || !selectedQuantization || ready}
|
127 |
+
onClick={loadModel}
|
128 |
+
>
|
129 |
+
{status === 'loading' && (
|
130 |
+
<>
|
131 |
+
<Loader className="animate-spin h-4 w-4" />
|
132 |
+
<span>{progress.toFixed(0)}%</span>
|
133 |
+
</>
|
134 |
+
)}
|
135 |
+
{!ready && !busy ? <span>Load Model</span> : !ready ? null : <span>Model Ready</span>}
|
136 |
+
</button>
|
137 |
+
</div>
|
138 |
+
)}
|
139 |
+
</div>
|
140 |
+
</div>
|
141 |
+
)
|
142 |
+
}
|
143 |
+
|
144 |
+
export default ModelLoader
|
src/components/ModelSelector.tsx
CHANGED
@@ -1,5 +1,11 @@
|
|
1 |
-
import React, { useEffect, useState } from 'react'
|
2 |
-
import {
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import { useModel } from '../contexts/ModelContext'
|
4 |
import { getModelInfo } from '../lib/huggingface'
|
5 |
import { Heart, Download, ChevronDown, Check, ArrowUpDown } from 'lucide-react'
|
@@ -50,42 +56,46 @@ const ModelSelector: React.FC = () => {
|
|
50 |
}, [models, sortBy, sortOrder])
|
51 |
|
52 |
// Function to fetch detailed model info and set as selected
|
53 |
-
const fetchAndSetModelInfo =
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
let parameters = 0
|
58 |
-
if (modelInfoResponse.safetensors) {
|
59 |
-
const safetensors = modelInfoResponse.safetensors
|
60 |
-
parameters =
|
61 |
-
safetensors.parameters.BF16 ||
|
62 |
-
safetensors.parameters.F16 ||
|
63 |
-
safetensors.parameters.F32 ||
|
64 |
-
safetensors.parameters.total ||
|
65 |
-
0
|
66 |
-
}
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
supportedQuantizations: modelInfoResponse.supportedQuantizations,
|
79 |
-
baseId: modelInfoResponse.baseId
|
80 |
-
}
|
81 |
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
89 |
|
90 |
// Update modelInfo to first model when pipeline changes
|
91 |
useEffect(() => {
|
@@ -93,7 +103,7 @@ const ModelSelector: React.FC = () => {
|
|
93 |
const firstModel = models[0]
|
94 |
fetchAndSetModelInfo(firstModel.id)
|
95 |
}
|
96 |
-
}, [pipeline, models])
|
97 |
|
98 |
const handleModelSelect = (modelId: string) => {
|
99 |
fetchAndSetModelInfo(modelId)
|
@@ -108,35 +118,42 @@ const ModelSelector: React.FC = () => {
|
|
108 |
}
|
109 |
}
|
110 |
|
111 |
-
const selectedModel =
|
|
|
112 |
|
113 |
return (
|
114 |
<div className="relative">
|
115 |
-
<Listbox
|
|
|
|
|
|
|
116 |
<div className="relative">
|
117 |
<ListboxButton className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent bg-white text-left flex items-center justify-between">
|
118 |
<div className="flex items-center justify-between w-full">
|
119 |
<div className="flex flex-col flex-1 min-w-0">
|
120 |
-
<span className="truncate font-medium">
|
|
|
|
|
121 |
</div>
|
122 |
-
|
123 |
<div className="flex items-center space-x-3">
|
124 |
-
{selectedModel &&
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
<
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
<
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
140 |
<ChevronDown className="w-4 h-4 ui-open:rotate-180 transition-transform flex-shrink-0" />
|
141 |
</div>
|
142 |
</div>
|
@@ -158,7 +175,9 @@ const ModelSelector: React.FC = () => {
|
|
158 |
<button
|
159 |
onClick={() => handleSortChange('name')}
|
160 |
className={`px-2 py-1 rounded flex items-center space-x-1 ${
|
161 |
-
sortBy === 'name'
|
|
|
|
|
162 |
}`}
|
163 |
>
|
164 |
<span>Name</span>
|
@@ -167,7 +186,9 @@ const ModelSelector: React.FC = () => {
|
|
167 |
<button
|
168 |
onClick={() => handleSortChange('likes')}
|
169 |
className={`px-2 py-1 rounded flex items-center space-x-1 ${
|
170 |
-
sortBy === 'likes'
|
|
|
|
|
171 |
}`}
|
172 |
>
|
173 |
<Heart className="w-3 h-3" />
|
@@ -177,21 +198,29 @@ const ModelSelector: React.FC = () => {
|
|
177 |
<button
|
178 |
onClick={() => handleSortChange('downloads')}
|
179 |
className={`px-2 py-1 rounded flex items-center space-x-1 ${
|
180 |
-
sortBy === 'downloads'
|
|
|
|
|
181 |
}`}
|
182 |
>
|
183 |
<Download className="w-3 h-3" />
|
184 |
<span>Downloads</span>
|
185 |
-
{sortBy === 'downloads' &&
|
|
|
|
|
186 |
</button>
|
187 |
<button
|
188 |
onClick={() => handleSortChange('createdAt')}
|
189 |
className={`px-2 py-1 rounded flex items-center space-x-1 ${
|
190 |
-
sortBy === 'createdAt'
|
|
|
|
|
191 |
}`}
|
192 |
>
|
193 |
<span>Date</span>
|
194 |
-
{sortBy === 'createdAt' &&
|
|
|
|
|
195 |
</button>
|
196 |
</div>
|
197 |
</div>
|
@@ -200,7 +229,7 @@ const ModelSelector: React.FC = () => {
|
|
200 |
<div className="overflow-auto max-h-48">
|
201 |
{sortedModels.map((model) => {
|
202 |
const hasStats = model.likes > 0 || model.downloads > 0
|
203 |
-
|
204 |
return (
|
205 |
<ListboxOption
|
206 |
key={model.id}
|
@@ -238,7 +267,7 @@ const ModelSelector: React.FC = () => {
|
|
238 |
<span>{formatNumber(model.downloads)}</span>
|
239 |
</div>
|
240 |
)}
|
241 |
-
|
242 |
{model.createdAt && (
|
243 |
<span className="text-xs text-gray-400">
|
244 |
{model.createdAt.split('T')[0]}
|
|
|
1 |
+
import React, { useCallback, useEffect, useState } from 'react'
|
2 |
+
import {
|
3 |
+
Listbox,
|
4 |
+
ListboxButton,
|
5 |
+
ListboxOption,
|
6 |
+
ListboxOptions,
|
7 |
+
Transition
|
8 |
+
} from '@headlessui/react'
|
9 |
import { useModel } from '../contexts/ModelContext'
|
10 |
import { getModelInfo } from '../lib/huggingface'
|
11 |
import { Heart, Download, ChevronDown, Check, ArrowUpDown } from 'lucide-react'
|
|
|
56 |
}, [models, sortBy, sortOrder])
|
57 |
|
58 |
// Function to fetch detailed model info and set as selected
|
59 |
+
const fetchAndSetModelInfo = useCallback(
|
60 |
+
async (modelId: string) => {
|
61 |
+
try {
|
62 |
+
const modelInfoResponse = await getModelInfo(modelId)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
+
let parameters = 0
|
65 |
+
if (modelInfoResponse.safetensors) {
|
66 |
+
const safetensors = modelInfoResponse.safetensors
|
67 |
+
parameters =
|
68 |
+
safetensors.parameters.BF16 ||
|
69 |
+
safetensors.parameters.F16 ||
|
70 |
+
safetensors.parameters.F32 ||
|
71 |
+
safetensors.parameters.total ||
|
72 |
+
0
|
73 |
+
}
|
|
|
|
|
|
|
74 |
|
75 |
+
const modelInfo = {
|
76 |
+
id: modelId,
|
77 |
+
name: modelInfoResponse.id || modelId,
|
78 |
+
architecture:
|
79 |
+
modelInfoResponse.config?.architectures?.[0] || 'Unknown',
|
80 |
+
parameters,
|
81 |
+
likes: modelInfoResponse.likes || 0,
|
82 |
+
downloads: modelInfoResponse.downloads || 0,
|
83 |
+
createdAt: modelInfoResponse.createdAt || '',
|
84 |
+
isCompatible: modelInfoResponse.isCompatible,
|
85 |
+
incompatibilityReason: modelInfoResponse.incompatibilityReason,
|
86 |
+
supportedQuantizations: modelInfoResponse.supportedQuantizations,
|
87 |
+
baseId: modelInfoResponse.baseId
|
88 |
+
}
|
89 |
|
90 |
+
console.log('Fetched model info:', modelInfoResponse)
|
91 |
+
|
92 |
+
setModelInfo(modelInfo)
|
93 |
+
} catch (error) {
|
94 |
+
console.error('Error fetching model info:', error)
|
95 |
+
}
|
96 |
+
},
|
97 |
+
[setModelInfo]
|
98 |
+
)
|
99 |
|
100 |
// Update modelInfo to first model when pipeline changes
|
101 |
useEffect(() => {
|
|
|
103 |
const firstModel = models[0]
|
104 |
fetchAndSetModelInfo(firstModel.id)
|
105 |
}
|
106 |
+
}, [pipeline, models, fetchAndSetModelInfo])
|
107 |
|
108 |
const handleModelSelect = (modelId: string) => {
|
109 |
fetchAndSetModelInfo(modelId)
|
|
|
118 |
}
|
119 |
}
|
120 |
|
121 |
+
const selectedModel =
|
122 |
+
models.find((model) => model.id === modelInfo.id) || models[0]
|
123 |
|
124 |
return (
|
125 |
<div className="relative">
|
126 |
+
<Listbox
|
127 |
+
value={selectedModel}
|
128 |
+
onChange={(model) => handleModelSelect(model.id)}
|
129 |
+
>
|
130 |
<div className="relative">
|
131 |
<ListboxButton className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent bg-white text-left flex items-center justify-between">
|
132 |
<div className="flex items-center justify-between w-full">
|
133 |
<div className="flex flex-col flex-1 min-w-0">
|
134 |
+
<span className="truncate font-medium">
|
135 |
+
{modelInfo.id || 'Select a model'}
|
136 |
+
</span>
|
137 |
</div>
|
138 |
+
|
139 |
<div className="flex items-center space-x-3">
|
140 |
+
{selectedModel &&
|
141 |
+
(selectedModel.likes > 0 || selectedModel.downloads > 0) && (
|
142 |
+
<div className="flex items-center space-x-3 text-xs text-gray-500">
|
143 |
+
{selectedModel.likes > 0 && (
|
144 |
+
<div className="flex items-center space-x-1">
|
145 |
+
<Heart className="w-3 h-3 text-red-500" />
|
146 |
+
<span>{formatNumber(selectedModel.likes)}</span>
|
147 |
+
</div>
|
148 |
+
)}
|
149 |
+
{selectedModel.downloads > 0 && (
|
150 |
+
<div className="flex items-center space-x-1">
|
151 |
+
<Download className="w-3 h-3 text-green-500" />
|
152 |
+
<span>{formatNumber(selectedModel.downloads)}</span>
|
153 |
+
</div>
|
154 |
+
)}
|
155 |
+
</div>
|
156 |
+
)}
|
157 |
<ChevronDown className="w-4 h-4 ui-open:rotate-180 transition-transform flex-shrink-0" />
|
158 |
</div>
|
159 |
</div>
|
|
|
175 |
<button
|
176 |
onClick={() => handleSortChange('name')}
|
177 |
className={`px-2 py-1 rounded flex items-center space-x-1 ${
|
178 |
+
sortBy === 'name'
|
179 |
+
? 'bg-blue-100 text-blue-700'
|
180 |
+
: 'text-gray-600 hover:bg-gray-100'
|
181 |
}`}
|
182 |
>
|
183 |
<span>Name</span>
|
|
|
186 |
<button
|
187 |
onClick={() => handleSortChange('likes')}
|
188 |
className={`px-2 py-1 rounded flex items-center space-x-1 ${
|
189 |
+
sortBy === 'likes'
|
190 |
+
? 'bg-blue-100 text-blue-700'
|
191 |
+
: 'text-gray-600 hover:bg-gray-100'
|
192 |
}`}
|
193 |
>
|
194 |
<Heart className="w-3 h-3" />
|
|
|
198 |
<button
|
199 |
onClick={() => handleSortChange('downloads')}
|
200 |
className={`px-2 py-1 rounded flex items-center space-x-1 ${
|
201 |
+
sortBy === 'downloads'
|
202 |
+
? 'bg-blue-100 text-blue-700'
|
203 |
+
: 'text-gray-600 hover:bg-gray-100'
|
204 |
}`}
|
205 |
>
|
206 |
<Download className="w-3 h-3" />
|
207 |
<span>Downloads</span>
|
208 |
+
{sortBy === 'downloads' && (
|
209 |
+
<ArrowUpDown className="w-3 h-3" />
|
210 |
+
)}
|
211 |
</button>
|
212 |
<button
|
213 |
onClick={() => handleSortChange('createdAt')}
|
214 |
className={`px-2 py-1 rounded flex items-center space-x-1 ${
|
215 |
+
sortBy === 'createdAt'
|
216 |
+
? 'bg-blue-100 text-blue-700'
|
217 |
+
: 'text-gray-600 hover:bg-gray-100'
|
218 |
}`}
|
219 |
>
|
220 |
<span>Date</span>
|
221 |
+
{sortBy === 'createdAt' && (
|
222 |
+
<ArrowUpDown className="w-3 h-3" />
|
223 |
+
)}
|
224 |
</button>
|
225 |
</div>
|
226 |
</div>
|
|
|
229 |
<div className="overflow-auto max-h-48">
|
230 |
{sortedModels.map((model) => {
|
231 |
const hasStats = model.likes > 0 || model.downloads > 0
|
232 |
+
|
233 |
return (
|
234 |
<ListboxOption
|
235 |
key={model.id}
|
|
|
267 |
<span>{formatNumber(model.downloads)}</span>
|
268 |
</div>
|
269 |
)}
|
270 |
+
|
271 |
{model.createdAt && (
|
272 |
<span className="text-xs text-gray-400">
|
273 |
{model.createdAt.split('T')[0]}
|
src/components/PipelineSelector.tsx
CHANGED
@@ -42,7 +42,7 @@ const PipelineSelector: React.FC<PipelineSelectorProps> = ({
|
|
42 |
<div className="relative">
|
43 |
<Listbox value={selectedPipeline} onChange={setPipeline}>
|
44 |
<div className="relative">
|
45 |
-
<ListboxButton className="relative w-full cursor-default rounded-lg bg-white py-2 pl-3 pr-10 text-left
|
46 |
<span className="block truncate font-medium">
|
47 |
{formatPipelineName(selectedPipeline)}
|
48 |
</span>
|
@@ -62,7 +62,7 @@ const PipelineSelector: React.FC<PipelineSelectorProps> = ({
|
|
62 |
leaveFrom="transform scale-100 opacity-100"
|
63 |
leaveTo="transform scale-95 opacity-0"
|
64 |
>
|
65 |
-
<ListboxOptions className="absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-md bg-white py-1 text-base
|
66 |
{pipelines.map((p) => (
|
67 |
<ListboxOption
|
68 |
key={p}
|
|
|
42 |
<div className="relative">
|
43 |
<Listbox value={selectedPipeline} onChange={setPipeline}>
|
44 |
<div className="relative">
|
45 |
+
<ListboxButton className="relative w-full cursor-default rounded-lg bg-white py-2 pl-3 pr-10 text-left focus:outline-none focus-visible:border-indigo-500 focus-visible:ring-2 focus-visible:ring-white focus-visible:ring-opacity-75 focus-visible:ring-offset-2 focus-visible:ring-offset-orange-300 sm:text-sm border border-gray-300">
|
46 |
<span className="block truncate font-medium">
|
47 |
{formatPipelineName(selectedPipeline)}
|
48 |
</span>
|
|
|
62 |
leaveFrom="transform scale-100 opacity-100"
|
63 |
leaveTo="transform scale-95 opacity-0"
|
64 |
>
|
65 |
+
<ListboxOptions className="absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-md bg-white py-1 text-base ring-1 ring-black ring-opacity-5 focus:outline-none sm:text-sm">
|
66 |
{pipelines.map((p) => (
|
67 |
<ListboxOption
|
68 |
key={p}
|
src/components/TextClassification.tsx
CHANGED
@@ -1,13 +1,11 @@
|
|
1 |
-
import { useState, useRef, useEffect, useCallback } from 'react'
|
2 |
import {
|
3 |
ClassificationOutput,
|
4 |
TextClassificationWorkerInput,
|
5 |
-
WorkerMessage
|
6 |
-
} from '../types'
|
7 |
-
import { useModel } from '../contexts/ModelContext'
|
8 |
-
import {
|
9 |
-
import { getWorker } from '../lib/workerManager';
|
10 |
-
|
11 |
|
12 |
const PLACEHOLDER_TEXTS: string[] = [
|
13 |
'I absolutely love this product! It exceeded all my expectations.',
|
@@ -20,61 +18,31 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
20 |
'The product arrived damaged and the return process was a nightmare.',
|
21 |
'Pretty good overall. A few minor issues but mostly positive experience.',
|
22 |
'Outstanding! This company really knows how to treat their customers.'
|
23 |
-
].sort(() => Math.random() - 0.5)
|
24 |
|
25 |
function TextClassification() {
|
26 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
27 |
const [results, setResults] = useState<ClassificationOutput[]>([])
|
28 |
-
const {
|
29 |
const workerRef = useRef<Worker | null>(null)
|
30 |
|
31 |
|
32 |
-
useEffect(() => {
|
33 |
-
if (!modelInfo.id) return;
|
34 |
-
const fetchModelInfo = async () => {
|
35 |
-
try {
|
36 |
-
const modelInfoResponse = await getModelInfo(modelInfo.id)
|
37 |
-
let parameters = 0
|
38 |
-
if (modelInfoResponse.safetensors) {
|
39 |
-
const safetensors = modelInfoResponse.safetensors
|
40 |
-
parameters =
|
41 |
-
(safetensors.parameters.F16 ||
|
42 |
-
safetensors.parameters.F32 ||
|
43 |
-
safetensors.parameters.total ||
|
44 |
-
0)
|
45 |
-
}
|
46 |
-
setModelInfo({
|
47 |
-
...modelInfo,
|
48 |
-
architecture: modelInfoResponse.config?.architectures[0] ?? '',
|
49 |
-
parameters,
|
50 |
-
likes: modelInfoResponse.likes,
|
51 |
-
downloads: modelInfoResponse.downloads
|
52 |
-
})
|
53 |
-
} catch (error) {
|
54 |
-
console.error('Error fetching model info:', error)
|
55 |
-
}
|
56 |
-
}
|
57 |
-
|
58 |
-
fetchModelInfo()
|
59 |
-
}, [modelInfo.id, setModelInfo])
|
60 |
-
|
61 |
// We use the `useEffect` hook to setup the worker as soon as the component is mounted.
|
62 |
useEffect(() => {
|
63 |
-
if(!workerRef.current) {
|
64 |
workerRef.current = getWorker('text-classification')
|
65 |
}
|
66 |
|
67 |
-
|
68 |
// Create a callback function for messages from the worker thread.
|
69 |
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
70 |
const status = e.data.status
|
71 |
-
if (status === '
|
|
|
|
|
72 |
setStatus('output')
|
73 |
const result = e.data.output!
|
74 |
setResults((prevResults) => [...prevResults, result])
|
75 |
console.log(result)
|
76 |
-
} else if (status === 'complete') {
|
77 |
-
setStatus('idle')
|
78 |
} else if (status === 'error') {
|
79 |
setStatus('error')
|
80 |
console.error(e.data.output)
|
@@ -87,10 +55,10 @@ function TextClassification() {
|
|
87 |
// Define a cleanup function for when the component is unmounted.
|
88 |
return () =>
|
89 |
workerRef.current?.removeEventListener('message', onMessageReceived)
|
90 |
-
}, [])
|
91 |
|
92 |
const classify = useCallback(() => {
|
93 |
-
setStatus('
|
94 |
setResults([]) // Clear previous results
|
95 |
const message: TextClassificationWorkerInput = {
|
96 |
type: 'classify',
|
@@ -98,17 +66,16 @@ function TextClassification() {
|
|
98 |
model: modelInfo.id
|
99 |
}
|
100 |
workerRef.current?.postMessage(message)
|
101 |
-
}, [text, modelInfo.id])
|
102 |
|
103 |
const busy: boolean = status !== 'ready'
|
104 |
|
105 |
-
|
106 |
const handleClear = (): void => {
|
107 |
setResults([])
|
108 |
}
|
109 |
|
110 |
return (
|
111 |
-
<div className="flex flex-col h-[
|
112 |
<h1 className="text-2xl font-bold mb-4">Text Classification</h1>
|
113 |
|
114 |
<div className="flex flex-col lg:flex-row gap-4 h-full">
|
@@ -125,14 +92,14 @@ function TextClassification() {
|
|
125 |
<div className="flex gap-2 mt-4">
|
126 |
<button
|
127 |
className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
128 |
-
disabled={busy
|
129 |
onClick={classify}
|
130 |
>
|
131 |
-
{
|
132 |
-
?
|
133 |
-
|
134 |
-
|
135 |
-
: '
|
136 |
</button>
|
137 |
<button
|
138 |
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
@@ -180,4 +147,4 @@ function TextClassification() {
|
|
180 |
)
|
181 |
}
|
182 |
|
183 |
-
export default TextClassification
|
|
|
1 |
+
import { useState, useRef, useEffect, useCallback } from 'react'
|
2 |
import {
|
3 |
ClassificationOutput,
|
4 |
TextClassificationWorkerInput,
|
5 |
+
WorkerMessage
|
6 |
+
} from '../types'
|
7 |
+
import { useModel } from '../contexts/ModelContext'
|
8 |
+
import { getWorker } from '../lib/workerManager'
|
|
|
|
|
9 |
|
10 |
const PLACEHOLDER_TEXTS: string[] = [
|
11 |
'I absolutely love this product! It exceeded all my expectations.',
|
|
|
18 |
'The product arrived damaged and the return process was a nightmare.',
|
19 |
'Pretty good overall. A few minor issues but mostly positive experience.',
|
20 |
'Outstanding! This company really knows how to treat their customers.'
|
21 |
+
].sort(() => Math.random() - 0.5)
|
22 |
|
23 |
function TextClassification() {
|
24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
25 |
const [results, setResults] = useState<ClassificationOutput[]>([])
|
26 |
+
const { status, setStatus, modelInfo } = useModel()
|
27 |
const workerRef = useRef<Worker | null>(null)
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
// We use the `useEffect` hook to setup the worker as soon as the component is mounted.
|
31 |
useEffect(() => {
|
32 |
+
if (!workerRef.current) {
|
33 |
workerRef.current = getWorker('text-classification')
|
34 |
}
|
35 |
|
|
|
36 |
// Create a callback function for messages from the worker thread.
|
37 |
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
38 |
const status = e.data.status
|
39 |
+
if (status === 'ready') {
|
40 |
+
setStatus('ready')
|
41 |
+
} else if (status === 'output') {
|
42 |
setStatus('output')
|
43 |
const result = e.data.output!
|
44 |
setResults((prevResults) => [...prevResults, result])
|
45 |
console.log(result)
|
|
|
|
|
46 |
} else if (status === 'error') {
|
47 |
setStatus('error')
|
48 |
console.error(e.data.output)
|
|
|
55 |
// Define a cleanup function for when the component is unmounted.
|
56 |
return () =>
|
57 |
workerRef.current?.removeEventListener('message', onMessageReceived)
|
58 |
+
}, [setStatus])
|
59 |
|
60 |
const classify = useCallback(() => {
|
61 |
+
setStatus('loading')
|
62 |
setResults([]) // Clear previous results
|
63 |
const message: TextClassificationWorkerInput = {
|
64 |
type: 'classify',
|
|
|
66 |
model: modelInfo.id
|
67 |
}
|
68 |
workerRef.current?.postMessage(message)
|
69 |
+
}, [text, modelInfo.id, setStatus])
|
70 |
|
71 |
const busy: boolean = status !== 'ready'
|
72 |
|
|
|
73 |
const handleClear = (): void => {
|
74 |
setResults([])
|
75 |
}
|
76 |
|
77 |
return (
|
78 |
+
<div className="flex flex-col h-[60vh] max-h-[100vh] w-full p-4">
|
79 |
<h1 className="text-2xl font-bold mb-4">Text Classification</h1>
|
80 |
|
81 |
<div className="flex flex-col lg:flex-row gap-4 h-full">
|
|
|
92 |
<div className="flex gap-2 mt-4">
|
93 |
<button
|
94 |
className="flex-1 py-2 px-4 bg-blue-500 hover:bg-blue-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
95 |
+
disabled={busy}
|
96 |
onClick={classify}
|
97 |
>
|
98 |
+
{status === 'ready'
|
99 |
+
? !busy
|
100 |
+
? 'Classify Text'
|
101 |
+
: 'Processing...'
|
102 |
+
: 'Load model first'}
|
103 |
</button>
|
104 |
<button
|
105 |
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
|
|
147 |
)
|
148 |
}
|
149 |
|
150 |
+
export default TextClassification
|
src/components/ZeroShotClassification.tsx
CHANGED
@@ -4,10 +4,8 @@ import {
|
|
4 |
Section,
|
5 |
WorkerMessage,
|
6 |
ZeroShotWorkerInput,
|
7 |
-
ModelInfo
|
8 |
} from '../types'
|
9 |
import { useModel } from '../contexts/ModelContext'
|
10 |
-
import { getModelInfo } from '../lib/huggingface'
|
11 |
|
12 |
const PLACEHOLDER_REVIEWS: string[] = [
|
13 |
// battery/charging problems
|
@@ -51,7 +49,7 @@ function ZeroShotClassification() {
|
|
51 |
PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
|
52 |
)
|
53 |
|
54 |
-
const {
|
55 |
|
56 |
// Create a reference to the worker object.
|
57 |
const worker = useRef<Worker | null>(null)
|
@@ -72,17 +70,8 @@ function ZeroShotClassification() {
|
|
72 |
// Create a callback function for messages from the worker thread.
|
73 |
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
74 |
const status = e.data.status
|
75 |
-
if (status === '
|
76 |
-
setStatus('loading')
|
77 |
-
} else if (status === 'ready') {
|
78 |
setStatus('ready')
|
79 |
-
} else if (status === 'progress') {
|
80 |
-
setStatus('progress')
|
81 |
-
if (
|
82 |
-
e.data.output.progress &&
|
83 |
-
(e.data.output.file as string).startsWith('onnx')
|
84 |
-
)
|
85 |
-
setProgress(e.data.output.progress)
|
86 |
} else if (status === 'output') {
|
87 |
setStatus('output')
|
88 |
const { sequence, labels, scores } = e.data.output!
|
@@ -100,9 +89,6 @@ function ZeroShotClassification() {
|
|
100 |
}
|
101 |
return newSections
|
102 |
})
|
103 |
-
} else if (status === 'complete') {
|
104 |
-
setStatus('idle')
|
105 |
-
setProgress(100)
|
106 |
} else if (status === 'error') {
|
107 |
setStatus('error')
|
108 |
console.error(e.data.output)
|
@@ -118,7 +104,7 @@ function ZeroShotClassification() {
|
|
118 |
}, [sections])
|
119 |
|
120 |
const classify = useCallback(() => {
|
121 |
-
setStatus('
|
122 |
const message: ZeroShotWorkerInput = {
|
123 |
text,
|
124 |
labels: sections
|
@@ -129,7 +115,7 @@ function ZeroShotClassification() {
|
|
129 |
worker.current?.postMessage(message)
|
130 |
}, [text, sections, modelInfo.name])
|
131 |
|
132 |
-
const busy: boolean = status !== '
|
133 |
|
134 |
const handleAddCategory = (): void => {
|
135 |
setSections((sections) => {
|
|
|
4 |
Section,
|
5 |
WorkerMessage,
|
6 |
ZeroShotWorkerInput,
|
|
|
7 |
} from '../types'
|
8 |
import { useModel } from '../contexts/ModelContext'
|
|
|
9 |
|
10 |
const PLACEHOLDER_REVIEWS: string[] = [
|
11 |
// battery/charging problems
|
|
|
49 |
PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
|
50 |
)
|
51 |
|
52 |
+
const { status, setStatus, modelInfo } = useModel()
|
53 |
|
54 |
// Create a reference to the worker object.
|
55 |
const worker = useRef<Worker | null>(null)
|
|
|
70 |
// Create a callback function for messages from the worker thread.
|
71 |
const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
|
72 |
const status = e.data.status
|
73 |
+
if (status === 'ready') {
|
|
|
|
|
74 |
setStatus('ready')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
} else if (status === 'output') {
|
76 |
setStatus('output')
|
77 |
const { sequence, labels, scores } = e.data.output!
|
|
|
89 |
}
|
90 |
return newSections
|
91 |
})
|
|
|
|
|
|
|
92 |
} else if (status === 'error') {
|
93 |
setStatus('error')
|
94 |
console.error(e.data.output)
|
|
|
104 |
}, [sections])
|
105 |
|
106 |
const classify = useCallback(() => {
|
107 |
+
setStatus('loading')
|
108 |
const message: ZeroShotWorkerInput = {
|
109 |
text,
|
110 |
labels: sections
|
|
|
115 |
worker.current?.postMessage(message)
|
116 |
}, [text, sections, modelInfo.name])
|
117 |
|
118 |
+
const busy: boolean = status !== 'ready'
|
119 |
|
120 |
const handleAddCategory = (): void => {
|
121 |
setSections((sections) => {
|
src/contexts/ModelContext.tsx
CHANGED
@@ -1,11 +1,21 @@
|
|
1 |
-
import React, {
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
interface ModelContextType {
|
|
|
|
|
5 |
progress: number
|
6 |
-
status: string
|
7 |
setProgress: (progress: number) => void
|
8 |
-
setStatus: (status: string) => void
|
9 |
modelInfo: ModelInfo
|
10 |
setModelInfo: (model: ModelInfo) => void
|
11 |
pipeline: string
|
@@ -16,22 +26,21 @@ interface ModelContextType {
|
|
16 |
setSelectedQuantization: (quantization: QuantizationType) => void
|
17 |
activeWorker: Worker | null
|
18 |
setActiveWorker: (worker: Worker | null) => void
|
19 |
-
workerLoaded: boolean
|
20 |
-
setWorkerLoaded: (workerLoaded: boolean) => void
|
21 |
}
|
22 |
|
23 |
const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
24 |
|
25 |
export function ModelProvider({ children }: { children: React.ReactNode }) {
|
26 |
const [progress, setProgress] = useState<number>(0)
|
27 |
-
const [status, setStatus] = useState<
|
28 |
const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
|
29 |
-
const [models, setModels] = useState<ModelInfoResponse[]>(
|
|
|
|
|
30 |
const [pipeline, setPipeline] = useState<string>('text-classification')
|
31 |
-
const [selectedQuantization, setSelectedQuantization] =
|
|
|
32 |
const [activeWorker, setActiveWorker] = useState<Worker | null>(null)
|
33 |
-
const [workerLoaded, setWorkerLoaded] = useState<boolean>(false)
|
34 |
-
|
35 |
|
36 |
// set progress to 0 when model is changed
|
37 |
useEffect(() => {
|
@@ -54,9 +63,7 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
54 |
selectedQuantization,
|
55 |
setSelectedQuantization,
|
56 |
activeWorker,
|
57 |
-
setActiveWorker
|
58 |
-
workerLoaded,
|
59 |
-
setWorkerLoaded
|
60 |
}}
|
61 |
>
|
62 |
{children}
|
|
|
1 |
+
import React, {
|
2 |
+
createContext,
|
3 |
+
useContext,
|
4 |
+
useEffect,
|
5 |
+
useState
|
6 |
+
} from 'react'
|
7 |
+
import {
|
8 |
+
ModelInfo,
|
9 |
+
ModelInfoResponse,
|
10 |
+
QuantizationType,
|
11 |
+
WorkerStatus
|
12 |
+
} from '../types'
|
13 |
|
14 |
interface ModelContextType {
|
15 |
+
status: WorkerStatus
|
16 |
+
setStatus: (status: WorkerStatus) => void
|
17 |
progress: number
|
|
|
18 |
setProgress: (progress: number) => void
|
|
|
19 |
modelInfo: ModelInfo
|
20 |
setModelInfo: (model: ModelInfo) => void
|
21 |
pipeline: string
|
|
|
26 |
setSelectedQuantization: (quantization: QuantizationType) => void
|
27 |
activeWorker: Worker | null
|
28 |
setActiveWorker: (worker: Worker | null) => void
|
|
|
|
|
29 |
}
|
30 |
|
31 |
const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
32 |
|
33 |
export function ModelProvider({ children }: { children: React.ReactNode }) {
|
34 |
const [progress, setProgress] = useState<number>(0)
|
35 |
+
const [status, setStatus] = useState<WorkerStatus>('initiate')
|
36 |
const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
|
37 |
+
const [models, setModels] = useState<ModelInfoResponse[]>(
|
38 |
+
[] as ModelInfoResponse[]
|
39 |
+
)
|
40 |
const [pipeline, setPipeline] = useState<string>('text-classification')
|
41 |
+
const [selectedQuantization, setSelectedQuantization] =
|
42 |
+
useState<QuantizationType>('int8')
|
43 |
const [activeWorker, setActiveWorker] = useState<Worker | null>(null)
|
|
|
|
|
44 |
|
45 |
// set progress to 0 when model is changed
|
46 |
useEffect(() => {
|
|
|
63 |
selectedQuantization,
|
64 |
setSelectedQuantization,
|
65 |
activeWorker,
|
66 |
+
setActiveWorker
|
|
|
|
|
67 |
}}
|
68 |
>
|
69 |
{children}
|
src/lib/huggingface.ts
CHANGED
@@ -114,7 +114,14 @@ const getModelsByPipeline = async (
|
|
114 |
}
|
115 |
const models = await response.json()
|
116 |
if (pipeline_tag === 'text-classification') {
|
117 |
-
return models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
}
|
119 |
return models.slice(0, 10)
|
120 |
}
|
|
|
114 |
}
|
115 |
const models = await response.json()
|
116 |
if (pipeline_tag === 'text-classification') {
|
117 |
+
return models
|
118 |
+
.filter(
|
119 |
+
(model: ModelInfoResponse) =>
|
120 |
+
!model.tags.includes('reranker') &&
|
121 |
+
!model.id.includes('reranker') &&
|
122 |
+
!model.tags.includes('sentence-transformers')
|
123 |
+
)
|
124 |
+
.slice(0, 10)
|
125 |
}
|
126 |
return models.slice(0, 10)
|
127 |
}
|
src/types.ts
CHANGED
@@ -9,8 +9,10 @@ export interface ClassificationOutput {
|
|
9 |
scores: number[]
|
10 |
}
|
11 |
|
|
|
|
|
12 |
export interface WorkerMessage {
|
13 |
-
status:
|
14 |
progress?: number
|
15 |
error?: string
|
16 |
output?: any
|
@@ -28,7 +30,6 @@ export interface TextClassificationWorkerInput {
|
|
28 |
model: string
|
29 |
}
|
30 |
|
31 |
-
export type AppStatus = 'idle' | 'loading' | 'processing'
|
32 |
|
33 |
type q8 = 'q8' | 'int8' | 'bnb8' | 'uint8'
|
34 |
type q4 = 'q4' | 'bnb4'
|
|
|
9 |
scores: number[]
|
10 |
}
|
11 |
|
12 |
+
export type WorkerStatus = 'initiate' | 'ready' | 'output' | 'loading' | 'error'
|
13 |
+
|
14 |
export interface WorkerMessage {
|
15 |
+
status: WorkerStatus
|
16 |
progress?: number
|
17 |
error?: string
|
18 |
output?: any
|
|
|
30 |
model: string
|
31 |
}
|
32 |
|
|
|
33 |
|
34 |
type q8 = 'q8' | 'int8' | 'bnb8' | 'uint8'
|
35 |
type q4 = 'q4' | 'bnb4'
|