Enhance model handling and loading: add dtype support, improve fetching logic, and refine component interactions
- public/workers/text-classification.js +9 -6
- src/App.tsx +15 -10
- src/components/ModelInfo.tsx +5 -3
- src/components/ModelLoader.tsx +71 -40
- src/components/ModelSelector.tsx +29 -11
- src/components/TextClassification.tsx +8 -38
- src/contexts/ModelContext.tsx +17 -1
- src/lib/huggingface.ts +80 -43
- src/types.ts +1 -1
public/workers/text-classification.js
CHANGED
@@ -1,21 +1,23 @@
 /* eslint-disable no-restricted-globals */
 import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3'

 class MyTextClassificationPipeline {
   static task = 'text-classification'
   static instance = null

-  static async getInstance(model, progress_callback = null) {
-    this.instance = pipeline(
+  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
+    this.instance = pipeline(
+      this.task,
+      model,
+      { dtype, progress_callback },
+    )
     return this.instance
   }
 }

 // Listen for messages from the main thread
 self.addEventListener('message', async (event) => {
-  const { type, model, text } = event.data
+  const { type, model, dtype, text } = event.data

   if (!model) {
     self.postMessage({
@@ -28,6 +30,7 @@ self.addEventListener('message', async (event) => {
   // Retrieve the pipeline. This will download the model if not already cached.
   const classifier = await MyTextClassificationPipeline.getInstance(
     model,
+    dtype,
     (x) => {
       self.postMessage({ status: 'loading', output: x })
     }
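Taken together, the worker now carries the selected dtype end to end: the 'load' message forwards it into the transformers.js pipeline options. A minimal sketch of driving this protocol from the main thread, assuming a module worker served from the public path; the model id and worker URL here are illustrative, and the message/status names are taken from the diffs in this commit:

// Sketch only — not part of the commit.
const worker = new Worker('/workers/text-classification.js', { type: 'module' })

worker.addEventListener('message', (e: MessageEvent) => {
  const { status, output } = e.data
  if (status === 'loading' && output?.progress != null) {
    console.log(`download progress: ${output.progress.toFixed(0)}%`)
  } else if (status === 'ready') {
    // Pipeline is initialized; classification requests can be posted now.
    worker.postMessage({ type: 'classify', model: 'some-org/some-model', text: 'Great movie!' })
  } else if (status === 'output') {
    console.log('classification result:', output)
  } else if (status === 'error') {
    console.error(output)
  }
})

// Load with an explicit quantization; the worker defaults dtype to 'fp32'.
worker.postMessage({ type: 'load', model: 'some-org/some-model', dtype: 'q8' })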
src/App.tsx
CHANGED
@@ -1,4 +1,4 @@
-import { useEffect, useState } from 'react'
+import { useEffect } from 'react'
 import PipelineSelector from './components/PipelineSelector'
 import ZeroShotClassification from './components/ZeroShotClassification'
 import TextClassification from './components/TextClassification'
@@ -10,19 +10,24 @@ import ModelInfo from './components/ModelInfo'
 import ModelReadme from './components/ModelReadme'

 function App() {
-  const { pipeline, setPipeline, setModels, setModelInfo, modelInfo } = useModel()
-  const [isFetching, setIsFetching] = useState(false)
+  const { pipeline, setPipeline, setModels, setModelInfo, modelInfo, setIsFetching } = useModel()

   useEffect(() => {
     setModelInfo(null)
+    setModels([])
+    setIsFetching(true)
+
     const fetchModels = async () => {
+      try {
+        const fetchedModels = await getModelsByPipeline(pipeline)
+        setModels(fetchedModels)
+      } catch (error) {
+        console.error('Error fetching models:', error)
+        setIsFetching(false)
+      }
     }
     fetchModels()
-  }, [setModels, setModelInfo, pipeline])
+  }, [setModels, setModelInfo, setIsFetching, pipeline])

   return (
     <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
@@ -47,12 +52,12 @@ function App() {
           <span className="text-lg font-semibold text-gray-900 block">
             Select Model
           </span>
-          <ModelSelector isFetching={isFetching} />
+          <ModelSelector />
         </div>
       </div>

       <div className="ml-6">
-        <ModelInfo isFetching={isFetching} />
+        <ModelInfo />
       </div>
     </div>
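Note the division of labor this hunk introduces: App sets isFetching to true before the fetch and clears it only on failure; on the success path the flag is cleared later, in ModelSelector's fetchAndSetModelInfo, once the first model's info has arrived. A small sketch of calling the list fetcher on its own, assuming the exports added in src/lib/huggingface.ts further down; the function name and behavior are taken from that diff:

import { getModelsByPipeline } from './lib/huggingface'

// Lists at most 20 models for a pipeline tag, deduplicated by id and
// filtered to models created after 2022/02/03 (per the filter in the diff).
async function listModels(pipelineTag: string): Promise<void> {
  const models = await getModelsByPipeline(pipelineTag)
  for (const m of models) console.log(m.id)
}

listModels('text-classification').catch(console.error)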
src/components/ModelInfo.tsx
CHANGED
@@ -12,7 +12,7 @@ import { getModelSize } from '../lib/huggingface'
 import { useModel } from '../contexts/ModelContext'
 import ModelLoader from './ModelLoader'

-const ModelInfo = ({ isFetching }: { isFetching: boolean }) => {
+const ModelInfo = () => {
   const formatNumber = (num: number) => {
     if (num >= 1000000000) {
       return (num / 1000000000).toFixed(1) + 'B'
@@ -25,8 +25,10 @@ const ModelInfo = ({ isFetching }: { isFetching: boolean }) => {
   }

   const {
+    models,
     modelInfo,
-    selectedQuantization
+    selectedQuantization,
+    isFetching
   } = useModel()

   const ModelInfoSkeleton = () => (
@@ -64,7 +66,7 @@ const ModelInfo = ({ isFetching }: { isFetching: boolean }) => {
     </div>
   )

-  if (!modelInfo || isFetching) {
+  if (!modelInfo || isFetching || models.length === 0) {
     return <ModelInfoSkeleton />
   }
src/components/ModelLoader.tsx
CHANGED
@@ -15,13 +15,17 @@ const ModelLoader = () => {
     setProgress,
     activeWorker,
     setActiveWorker,
-    pipeline
+    pipeline,
+    setResults,
+    hasBeenLoaded,
+    setHasBeenLoaded
   } = useModel()

   useEffect(() => {
     if (!modelInfo) return

-    if (modelInfo.isCompatible
+    if (modelInfo.isCompatible) {
       const quantizations = modelInfo.supportedQuantizations
       let defaultQuant: QuantizationType = 'fp32'

@@ -35,10 +39,9 @@ const ModelLoader = () => {

       setSelectedQuantization(defaultQuant)
     }
-  ])
+
+    setHasBeenLoaded(false)
+  }, [modelInfo, setSelectedQuantization, setHasBeenLoaded])

   useEffect(() => {
     if (!modelInfo) return
@@ -48,14 +51,18 @@ const ModelLoader = () => {
       return
     }

+    if (!hasBeenLoaded) {
+      setStatus('initiate')
+      setActiveWorker(newWorker)
+    }
+
     const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
       const { status, output } = e.data
       if (status === 'ready') {
         setStatus('ready')
+        setHasBeenLoaded(true)
+      } else if (status === 'loading' && output && !hasBeenLoaded) {
         setStatus('loading')
         if (
           output.progress &&
@@ -64,6 +71,14 @@ const ModelLoader = () => {
         ) {
           setProgress(output.progress)
         }
+      } else if (status === 'output') {
+        setStatus('output')
+        const result = e.data.output!
+        setResults((prev: any[]) => [...prev, result])
+        // console.log(result)
+      } else if (status === 'error') {
+        setStatus('error')
+        console.error(e.data.output)
       }
     }

@@ -73,24 +88,30 @@ const ModelLoader = () => {
       newWorker.removeEventListener('message', onMessageReceived)
       // terminateWorker(pipeline);
     }
-  }, [
+  }, [
+    pipeline,
+    modelInfo,
+    selectedQuantization,
+    setActiveWorker,
+    setStatus,
+    setProgress,
+    setResults,
+    hasBeenLoaded,
+    setHasBeenLoaded
+  ])

   const loadModel = useCallback(() => {
     if (!modelInfo || !selectedQuantization) return

-    setStatus('loading')
     const message = {
       type: 'load',
       model: modelInfo.name,
+      dtype: selectedQuantization ?? 'fp32'
     }
     activeWorker?.postMessage(message)
-  }, [modelInfo, selectedQuantization,
-
-  const ready: boolean = status === 'ready'
-  const busy: boolean = status === 'loading'
+  }, [modelInfo, selectedQuantization, activeWorker])

-  if (!modelInfo?.isCompatible
+  if (!modelInfo?.isCompatible) {
     return null
   }

@@ -100,42 +121,52 @@ const ModelLoader = () => {

     <div className="flex items-center justify-between space-x-4">
       <div className="flex items-center space-x-2">
+        {modelInfo.supportedQuantizations.length > 1 ? (
+          <>
+            <span className="text-xs text-gray-600 font-medium">
+              Quantization:
+            </span>
+
+            <div className="relative">
+              <select
+                value={selectedQuantization || ''}
+                onChange={(e) =>
+                  setSelectedQuantization(e.target.value as QuantizationType)
+                }
+                className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
+              >
+                <option value="">Select quantization</option>
+                {modelInfo.supportedQuantizations.map((quant) => (
+                  <option key={quant} value={quant}>
+                    {quant}
+                  </option>
+                ))}
+              </select>
+              <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
+            </div>
+          </>
+        ) : (
+          <span className="text-xs text-gray-600 font-medium white-space-break-spaces">
+            No quantization available. Using fp32
+          </span>
+        )}
       </div>

       {selectedQuantization && (
         <div className="flex justify-center">
           <button
             className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
-            disabled={
+            disabled={hasBeenLoaded}
             onClick={loadModel}
           >
-            {status === 'loading' && (
+            {status === 'loading' && !hasBeenLoaded ? (
               <>
                 <Loader className="animate-spin h-4 w-4" />
                 <span>{progress.toFixed(0)}%</span>
               </>
+            ) : (
+              <span>{!hasBeenLoaded ? 'Load Model' : 'Model Ready'}</span>
             )}
-            {!ready && !busy ? <span>Load Model</span> : !ready ? null : <span>Model Ready</span>}
           </button>
         </div>
       )}
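The loader's message handler now distinguishes five statuses, and hasBeenLoaded gates both sides of the UI: progress updates are ignored after the first successful load, and the button flips from 'Load Model' to a disabled 'Model Ready'. A sketch of the message shape the handler expects, reconstructed from the branches above; the real WorkerMessage type lives in src/types.ts and may differ in detail:

type WorkerStatus = 'initiate' | 'loading' | 'ready' | 'output' | 'error'

interface WorkerMessage {
  status: WorkerStatus
  // While 'loading', output carries a progress value (0-100) that feeds the
  // percentage shown on the Load Model button; for 'output' it carries the
  // classification result; for 'error', a description of the failure.
  output?: { progress?: number } & Record<string, unknown>
}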
src/components/ModelSelector.tsx
CHANGED
@@ -22,8 +22,15 @@ import {

 type SortOption = 'likes' | 'downloads' | 'createdAt' | 'name'

-function ModelSelector({ isFetching }: { isFetching: boolean }) {
-  const {
+function ModelSelector() {
+  const {
+    models,
+    setModelInfo,
+    modelInfo,
+    pipeline,
+    isFetching,
+    setIsFetching
+  } = useModel()
   const [sortBy, setSortBy] = useState<SortOption>('downloads')
   const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc')
   const [showCustomInput, setShowCustomInput] = useState(false)
@@ -102,25 +109,36 @@ function ModelSelector({ isFetching }: { isFetching: boolean }) {
         baseId: modelInfoResponse.baseId,
         readme: modelInfoResponse.readme
       }
+
+      console.log('Fetched model info:', modelInfoResponse)
+
       setModelInfo(modelInfo)
       setIsCustomModel(isCustom)
+      setIsFetching(false)
     } catch (error) {
       console.error('Error fetching model info:', error)
+      setIsFetching(false)
       throw error
     }
   },
-  [setModelInfo, pipeline]
+  [setModelInfo, pipeline, setIsFetching]
 )

-  //
+  // Reset custom model state when pipeline changes
   useEffect(() => {
+    setIsCustomModel(false)
+    setShowCustomInput(false)
+    setCustomModelName('')
+    setCustomModelError('')
+  }, [pipeline])

+  // Update modelInfo to first model when models are loaded and no custom model is selected
+  useEffect(() => {
+    if (models.length > 0 && !isCustomModel && !modelInfo) {
+      const firstModel = sortedModels[0]
       fetchAndSetModelInfo(firstModel.id, false)
     }
-  }, [
+  }, [models, sortedModels, fetchAndSetModelInfo, isCustomModel, modelInfo])

   const handleModelSelect = (modelId: string) => {
     fetchAndSetModelInfo(modelId, false)
@@ -160,8 +178,8 @@ function ModelSelector({ isFetching }: { isFetching: boolean }) {
   const handleRemoveCustomModel = () => {
     setIsCustomModel(false)
     // Load the first model from the list
-    if (
-      fetchAndSetModelInfo(
+    if (sortedModels.length > 0) {
+      fetchAndSetModelInfo(sortedModels[0].id, false)
     }
   }

@@ -226,7 +244,7 @@ function ModelSelector({ isFetching }: { isFetching: boolean }) {
     )
   }

-  if (isFetching) {
+  if (isFetching || models.length === 0) {
     return (
       <div className="relative">
         <div className="w-full px-3 py-2 border border-gray-300 rounded-md bg-white flex items-center justify-between animate-pulse h-10">
src/components/TextClassification.tsx
CHANGED
@@ -22,52 +22,23 @@ const PLACEHOLDER_TEXTS: string[] = [

 function TextClassification() {
   const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
-  const
-  const { status, setStatus, modelInfo } = useModel()
-  const workerRef = useRef<Worker | null>(null)
+  const { activeWorker, status, setStatus, modelInfo, results, setResults, hasBeenLoaded } = useModel()


-  // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
-  useEffect(() => {
-    if (!workerRef.current) {
-      workerRef.current = getWorker('text-classification')
-    }
-
-    // Create a callback function for messages from the worker thread.
-    const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
-      const status = e.data.status
-      if (status === 'ready') {
-        setStatus('ready')
-      } else if (status === 'output') {
-        setStatus('output')
-        const result = e.data.output!
-        setResults((prevResults) => [...prevResults, result])
-        console.log(result)
-      } else if (status === 'error') {
-        setStatus('error')
-        console.error(e.data.output)
-      }
-    }
-
-    // Attach the callback function as an event listener.
-    workerRef.current?.addEventListener('message', onMessageReceived)
-
-    // Define a cleanup function for when the component is unmounted.
-    return () =>
-      workerRef.current?.removeEventListener('message', onMessageReceived)
-  }, [setStatus])

   const classify = useCallback(() => {
-    if (!modelInfo)
+    if (!modelInfo || !activeWorker) {
+      console.error('Model info or worker is not available')
+      return
+    }
     setResults([]) // Clear previous results
     const message: TextClassificationWorkerInput = {
       type: 'classify',
       text,
       model: modelInfo.id
     }
-
-  }, [text, modelInfo, setStatus])
+    activeWorker.postMessage(message)
+  }, [text, modelInfo, setStatus, activeWorker])

   const busy: boolean = status !== 'ready'

@@ -96,8 +67,7 @@ function TextClassification() {
             disabled={busy}
             onClick={classify}
           >
-            {
-              ? !busy
+            {hasBeenLoaded ? !busy
               ? 'Classify Text'
               : 'Processing...'
               : 'Load model first'}
src/contexts/ModelContext.tsx
CHANGED
@@ -26,6 +26,12 @@ interface ModelContextType {
   setSelectedQuantization: (quantization: QuantizationType) => void
   activeWorker: Worker | null
   setActiveWorker: (worker: Worker | null) => void
+  isFetching: boolean
+  setIsFetching: (isFetching: boolean) => void
+  results: any[]
+  setResults: React.Dispatch<React.SetStateAction<any[]>>
+  hasBeenLoaded: boolean
+  setHasBeenLoaded: (hasBeenLoaded: boolean) => void
 }

 const ModelContext = createContext<ModelContextType | undefined>(undefined)
@@ -41,6 +47,10 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
   const [selectedQuantization, setSelectedQuantization] =
     useState<QuantizationType>('int8')
   const [activeWorker, setActiveWorker] = useState<Worker | null>(null)
+  const [isFetching, setIsFetching] = useState(false)
+  const [results, setResults] = useState<any[]>([])
+  const [hasBeenLoaded, setHasBeenLoaded] = useState(false)
+

   // set progress to 0 when model is changed
   useEffect(() => {
@@ -63,7 +73,13 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
         selectedQuantization,
         setSelectedQuantization,
         activeWorker,
-        setActiveWorker
+        setActiveWorker,
+        isFetching,
+        setIsFetching,
+        results,
+        setResults,
+        hasBeenLoaded,
+        setHasBeenLoaded
       }}
     >
       {children}
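With results, isFetching, and hasBeenLoaded lifted into the context, any descendant of ModelProvider can read them through the existing useModel hook instead of threading props (which is what removed the isFetching prop from ModelSelector and ModelInfo above). A hypothetical consumer, not part of this commit:

import { useModel } from '../contexts/ModelContext'

// Illustrative component showing the new context fields in use.
function ResultsSummary() {
  const { results, isFetching, hasBeenLoaded } = useModel()

  if (isFetching) return <p>Fetching models…</p>
  if (!hasBeenLoaded) return <p>Load a model to start classifying.</p>
  return <p>{results.length} result(s) so far</p>
}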
src/lib/huggingface.ts
CHANGED
@@ -1,7 +1,10 @@
-import { supportedPipelines } from
-import { ModelInfoResponse, QuantizationType } from
+import { supportedPipelines } from '../components/PipelineSelector'
+import { ModelInfoResponse, QuantizationType } from '../types'

-const getModelInfo = async (
+const getModelInfo = async (
+  modelName: string,
+  pipeline: string
+): Promise<ModelInfoResponse> => {
   const token = process.env.REACT_APP_HUGGINGFACE_TOKEN

   if (!token) {
@@ -23,36 +26,53 @@ const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelI
   if (!response.ok) {
     throw new Error(`Failed to fetch model info: ${response.statusText}`)
   }
+
   const modelData: ModelInfoResponse = await response.json()
+
   const requiredFiles = [
     'config.json',
     'tokenizer.json',
-    'tokenizer_config.json'
+    'tokenizer_config.json'
   ]

-  const siblingFiles = modelData.siblings?.map(s => s.rfilename) || []
-  const missingFiles = requiredFiles.filter(file => !siblingFiles.includes(file))
-  const hasOnnxFolder = siblingFiles.some((file) => file.endsWith('.onnx') && file.startsWith('onnx/'))
-
-  const
+  const siblingFiles = modelData.siblings?.map((s) => s.rfilename) || []
+  const missingFiles = requiredFiles.filter(
+    (file) => !siblingFiles.includes(file)
+  )
+  const hasOnnxFolder = siblingFiles.some(
+    (file) => file.endsWith('.onnx') && file.startsWith('onnx/')
+  )
+
+  const isCompatible =
+    missingFiles.length === 0 &&
+    hasOnnxFolder &&
+    modelData.tags.includes(pipeline)

   let incompatibilityReason = ''
   if (!modelData.tags.includes(pipeline)) {
-    const expectedPipelines = modelData.tags
-    incompatibilityReason
+    const expectedPipelines = modelData.tags
+      .filter((tag) => supportedPipelines.includes(tag))
+      .join(', ')
+    incompatibilityReason = expectedPipelines
+      ? `- Model can be used with ${expectedPipelines} pipelines only\n`
+      : `- Pipeline ${pipeline} not supported by the model\n`
+  }
+  if (missingFiles.length > 0) {
+    incompatibilityReason += `- Missing required files: ${missingFiles.join(
+      ', '
+    )}\n`
+  } else if (!hasOnnxFolder) {
     incompatibilityReason += '- Folder onnx/ is missing\n'
   }
-  const supportedQuantizations =
+  const supportedQuantizations = hasOnnxFolder
+    ? siblingFiles
+        .filter((file) => file.endsWith('.onnx') && file.includes('_'))
+        .map((file) => file.split('/')[1].split('_')[1].split('.')[0])
+        .filter((q) => q !== 'quantized')
+    : []
+  const uniqueSupportedQuantizations = Array.from(
+    new Set(supportedQuantizations)
+  )
   uniqueSupportedQuantizations.sort((a, b) => {
     const getNumericValue = (str: string) => {
       const match = str.match(/(\d+)/)
@@ -64,7 +84,9 @@ const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelI
   // Fetch README content
   const fetchReadme = async (modelId: string): Promise<string> => {
     try {
-      const readmeResponse = await fetch(
+      const readmeResponse = await fetch(
+        `https://huggingface.co/${modelId}/raw/main/README.md`
+      )
       if (readmeResponse.ok) {
         return await readmeResponse.text()
       }
@@ -74,7 +96,7 @@ const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelI
     return ''
   }

-  const baseModel = modelData.cardData?.base_model ?? modelData.modelId
+  const baseModel = modelData.cardData?.base_model ?? modelData.modelId
   if (baseModel && !modelData.safetensors) {
     const baseModelResponse = await fetch(
       `https://huggingface.co/api/models/${baseModel}`,
@@ -89,21 +111,22 @@ const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelI
     if (baseModelResponse.ok) {
       const baseModelData: ModelInfoResponse = await baseModelResponse.json()
       const readme = await fetchReadme(baseModel)
+
       return {
         ...baseModelData,
         id: modelData.id,
         baseId: baseModel,
         isCompatible,
         incompatibilityReason,
-        supportedQuantizations:
+        supportedQuantizations:
+          uniqueSupportedQuantizations as QuantizationType[],
         readme
       }
     }
   }
+
   const readme = await fetchReadme(modelData.id)
+
   return {
     ...modelData,
     isCompatible,
@@ -135,7 +158,9 @@ const getModelsByPipeline = async (
     }
   )
   if (!response1.ok) {
-    throw new Error(
+    throw new Error(
+      `Failed to fetch models for pipeline: ${response1.statusText}`
+    )
   }
   const models1 = await response1.json()

@@ -150,14 +175,18 @@ const getModelsByPipeline = async (
     }
   )
   if (!response2.ok) {
-    throw new Error(
+    throw new Error(
+      `Failed to fetch models for pipeline: ${response2.statusText}`
+    )
   }
   const models2 = await response2.json()

   // Combine and deduplicate models based on id
-  const combinedModels = [...models1, ...models2].filter(
+  const combinedModels = [...models1, ...models2].filter(
+    (m: ModelInfoResponse) => m.createdAt > '2022/02/03'
+  )
+  const uniqueModels = combinedModels.filter(
+    (model, index, self) => index === self.findIndex((m) => m.id === model.id)
   )

   if (pipelineTag === 'text-classification') {
@@ -171,11 +200,10 @@ const getModelsByPipeline = async (
       )
       .slice(0, 20)
   }

   return uniqueModels.slice(0, 20)
 }

-
 const getModelsByPipelineCustom = async (
   searchString: string,
   pipelineTag: string
@@ -197,12 +225,16 @@ const getModelsByPipelineCustom = async (
     }
   )

-    throw new Error(
+  if (!response.ok) {
+    throw new Error(
+      `Failed to fetch models for pipeline: ${response.statusText}`
+    )
   }
   const models = await response.json()

-  const uniqueModels = models.filter(
+  const uniqueModels = models.filter(
+    (m: ModelInfoResponse) => m.createdAt > '2022/02/03'
+  )
   if (pipelineTag === 'text-classification') {
     return uniqueModels
       .filter(
@@ -214,7 +246,7 @@ const getModelsByPipelineCustom = async (
       )
       .slice(0, 20)
   }

   return uniqueModels.slice(0, 20)
 }

@@ -239,9 +271,10 @@ function getModelSize(
       bytesPerParameter = 1
       break
     case 'bnb4':
-    case 'q4':
+    case 'q4':
+    case 'q4f16':
       bytesPerParameter = 0.5
+      break
   }

   const sizeInBytes = parameters * bytesPerParameter
@@ -250,5 +283,9 @@ function getModelSize(
   return sizeInMB
 }

+export {
+  getModelInfo,
+  getModelSize,
+  getModelsByPipeline,
+  getModelsByPipelineCustom
+}
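Two parts of this file are worth tracing by hand. First, the quantization extraction: the split('/')[1].split('_')[1].split('.')[0] chain keeps whatever sits between the first underscore and the extension of files under onnx/, then drops the generic 'quantized' suffix. With a hypothetical sibling list:

const siblingFiles = [
  'onnx/model.onnx',           // no '_' in the name → filtered out
  'onnx/model_quantized.onnx', // yields 'quantized' → filtered out explicitly
  'onnx/model_q4.onnx',        // → 'q4'
  'onnx/model_fp16.onnx'       // → 'fp16'
]

const supportedQuantizations = siblingFiles
  .filter((file) => file.endsWith('.onnx') && file.includes('_'))
  .map((file) => file.split('/')[1].split('_')[1].split('.')[0])
  .filter((q) => q !== 'quantized')

console.log(supportedQuantizations) // ['q4', 'fp16']

Second, the size estimate: getModelSize multiplies the parameter count by bytes per parameter, so the new q4f16 case (0.5 bytes per parameter) sizes a 100M-parameter model at roughly 100,000,000 × 0.5 / 1024² ≈ 47.7 MB, assuming the usual bytes-to-MB conversion elided from the hunk. The added break only makes the final case explicit; previously bnb4/q4 simply fell out of the end of the switch.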
src/types.ts
CHANGED
@@ -32,7 +32,7 @@ export interface TextClassificationWorkerInput {


 type q8 = 'q8' | 'int8' | 'bnb8' | 'uint8'
-type q4 = 'q4' | 'bnb4'
+type q4 = 'q4' | 'bnb4' | 'q4f16'
 type fp16 = 'fp16'
 type fp32 = 'fp32'