Commit 4d810fa by Vokturz (parent: 322c234)

Enhance model handling and loading: add dtype support, improve fetching logic, and refine component interactions

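In short, the commit threads a dtype (quantization) choice from the UI through the shared ModelContext down to the worker's pipeline() call, and moves fetching/loading state into that context. A minimal sketch of the resulting worker message protocol, with shapes taken from the diffs below (the model id is a hypothetical placeholder):

    // main thread -> worker: load a model at the selected quantization
    worker.postMessage({ type: 'load', model: 'some-org/some-model', dtype: 'q4' })
    // main thread -> worker: classify with the loaded pipeline
    worker.postMessage({ type: 'classify', model: 'some-org/some-model', text: 'Great!' })
    // worker -> main thread, roughly in order:
    //   { status: 'loading', output: { progress } } ... { status: 'ready' }
    //   then { status: 'output', output: result } or { status: 'error', output: err }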
public/workers/text-classification.js CHANGED

@@ -1,21 +1,23 @@
 /* eslint-disable no-restricted-globals */
-import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3';
+import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3'
 
 class MyTextClassificationPipeline {
   static task = 'text-classification'
   static instance = null
 
-  static async getInstance(model, progress_callback = null) {
-    this.instance = pipeline(this.task, model, {
-      progress_callback
-    })
+  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
+    this.instance = pipeline(
+      this.task,
+      model,
+      { dtype, progress_callback },
+    )
     return this.instance
   }
 }
 
 // Listen for messages from the main thread
 self.addEventListener('message', async (event) => {
-  const { type, model, text } = event.data // Destructure 'type'
+  const { type, model, dtype, text } = event.data
 
   if (!model) {
     self.postMessage({
@@ -28,6 +30,7 @@ self.addEventListener('message', async (event) => {
   // Retrieve the pipeline. This will download the model if not already cached.
   const classifier = await MyTextClassificationPipeline.getInstance(
     model,
+    dtype,
     (x) => {
       self.postMessage({ status: 'loading', output: x })
     }
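Note: a short sketch of driving this worker from the main thread (the worker URL and module type are assumptions; the app itself obtains the worker via a getWorker() helper):

    const model = 'hypothetical-org/sentiment-model' // placeholder id
    const worker = new Worker('/workers/text-classification.js', { type: 'module' })
    worker.addEventListener('message', (e) => {
      const { status, output } = e.data
      if (status === 'loading' && output?.progress) {
        console.log(`downloading: ${output.progress.toFixed(0)}%`)
      } else if (status === 'ready') {
        worker.postMessage({ type: 'classify', model, text: 'Great movie!' })
      }
    })
    // dtype selects the quantized ONNX weight variant (e.g. onnx/model_q8.onnx)
    worker.postMessage({ type: 'load', model, dtype: 'q8' })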
src/App.tsx CHANGED

@@ -1,4 +1,4 @@
-import { useEffect, useState } from 'react'
+import { useEffect } from 'react'
 import PipelineSelector from './components/PipelineSelector'
 import ZeroShotClassification from './components/ZeroShotClassification'
 import TextClassification from './components/TextClassification'
@@ -10,19 +10,24 @@ import ModelInfo from './components/ModelInfo'
 import ModelReadme from './components/ModelReadme'
 
 function App() {
-  const { pipeline, setPipeline, setModels, setModelInfo, modelInfo } = useModel()
-  const [isFetching, setIsFetching] = useState(false)
+  const { pipeline, setPipeline, setModels, setModelInfo, modelInfo, setIsFetching} = useModel()
 
   useEffect(() => {
     setModelInfo(null)
+    setModels([])
+    setIsFetching(true)
+
     const fetchModels = async () => {
-      setIsFetching(true)
-      const fetchedModels = await getModelsByPipeline(pipeline)
-      setModels(fetchedModels)
-      setIsFetching(false)
+      try {
+        const fetchedModels = await getModelsByPipeline(pipeline)
+        setModels(fetchedModels)
+      } catch (error) {
+        console.error('Error fetching models:', error)
+        setIsFetching(false)
+      }
     }
     fetchModels()
-  }, [setModels, setModelInfo, pipeline])
+  }, [setModels, setModelInfo, setIsFetching, pipeline])
 
   return (
     <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
@@ -47,12 +52,12 @@ function App() {
           <span className="text-lg font-semibold text-gray-900 block">
             Select Model
           </span>
-          <ModelSelector isFetching={isFetching} />
+          <ModelSelector />
         </div>
       </div>
 
       <div className="ml-6">
-        <ModelInfo isFetching={isFetching} />
+        <ModelInfo />
      </div>
    </div>
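Note: the fetching flag now lives in ModelContext and is deliberately split across components — App raises it and clears stale state when the pipeline changes, and ModelSelector lowers it once the first model's details arrive (see its diff below). A condensed sketch of the lifecycle, assuming the context setters shown above:

    // App.tsx, on pipeline change:
    setModelInfo(null)   // ModelInfo falls back to its skeleton
    setModels([])        // ModelSelector stays in its loading state
    setIsFetching(true)  // raised here...
    setModels(await getModelsByPipeline(pipeline))
    // ...and lowered inside ModelSelector's fetchAndSetModelInfo(), so the
    // skeletons persist until real model info is ready (or the catch above fires).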
src/components/ModelInfo.tsx CHANGED

@@ -12,7 +12,7 @@ import { getModelSize } from '../lib/huggingface'
 import { useModel } from '../contexts/ModelContext'
 import ModelLoader from './ModelLoader'
 
-const ModelInfo = ({ isFetching }: { isFetching: boolean }) => {
+const ModelInfo = () => {
   const formatNumber = (num: number) => {
     if (num >= 1000000000) {
       return (num / 1000000000).toFixed(1) + 'B'
@@ -25,8 +25,10 @@ const ModelInfo = ({ isFetching }: { isFetching: boolean }) => {
   }
 
   const {
+    models,
     modelInfo,
-    selectedQuantization
+    selectedQuantization,
+    isFetching
   } = useModel()
 
   const ModelInfoSkeleton = () => (
@@ -64,7 +66,7 @@ const ModelInfo = ({ isFetching }: { isFetching: boolean }) => {
     </div>
   )
 
-  if (!modelInfo || isFetching) {
+  if (!modelInfo || isFetching || models.length === 0) {
    return <ModelInfoSkeleton />
  }
src/components/ModelLoader.tsx CHANGED

@@ -15,13 +15,17 @@ const ModelLoader = () => {
     setProgress,
     activeWorker,
     setActiveWorker,
-    pipeline
+    pipeline,
+    setResults,
+    hasBeenLoaded,
+    setHasBeenLoaded
   } = useModel()
 
+
   useEffect(() => {
     if (!modelInfo) return
 
-    if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
+    if (modelInfo.isCompatible) {
       const quantizations = modelInfo.supportedQuantizations
       let defaultQuant: QuantizationType = 'fp32'
@@ -35,10 +39,9 @@ const ModelLoader = () => {
 
       setSelectedQuantization(defaultQuant)
     }
-  }, [
-    modelInfo,
-    setSelectedQuantization
-  ])
+
+    setHasBeenLoaded(false)
+  }, [modelInfo, setSelectedQuantization, setHasBeenLoaded])
 
   useEffect(() => {
     if (!modelInfo) return
@@ -48,14 +51,18 @@ const ModelLoader = () => {
       return
     }
 
-    setStatus('initiate')
-    setActiveWorker(newWorker)
+    if (!hasBeenLoaded) {
+      setStatus('initiate')
+      setActiveWorker(newWorker)
+    }
+
 
     const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
       const { status, output } = e.data
       if (status === 'ready') {
         setStatus('ready')
-      } else if (status === 'loading' && output) {
+        setHasBeenLoaded(true)
+      } else if (status === 'loading' && output && !hasBeenLoaded) {
         setStatus('loading')
         if (
           output.progress &&
@@ -64,6 +71,14 @@ const ModelLoader = () => {
         ) {
           setProgress(output.progress)
         }
+      } else if (status === 'output') {
+        setStatus('output')
+        const result = e.data.output!
+        setResults((prev: any[]) => [...prev, result])
+        // console.log(result)
+      } else if (status === 'error') {
+        setStatus('error')
+        console.error(e.data.output)
       }
     }
 
@@ -73,24 +88,30 @@ const ModelLoader = () => {
       newWorker.removeEventListener('message', onMessageReceived)
       // terminateWorker(pipeline);
     }
-  }, [pipeline, modelInfo, selectedQuantization, setActiveWorker, setStatus, setProgress])
+  }, [
+    pipeline,
+    modelInfo,
+    selectedQuantization,
+    setActiveWorker,
+    setStatus,
+    setProgress,
+    setResults,
+    hasBeenLoaded,
+    setHasBeenLoaded
+  ])
 
   const loadModel = useCallback(() => {
     if (!modelInfo || !selectedQuantization) return
 
-    setStatus('loading')
     const message = {
       type: 'load',
       model: modelInfo.name,
-      quantization: selectedQuantization
+      dtype: selectedQuantization ?? 'fp32'
     }
     activeWorker?.postMessage(message)
-  }, [modelInfo, selectedQuantization, setStatus, activeWorker])
-
-  const ready: boolean = status === 'ready'
-  const busy: boolean = status === 'loading'
+  }, [modelInfo, selectedQuantization, activeWorker])
 
-  if (!modelInfo?.isCompatible || modelInfo.supportedQuantizations.length === 0) {
+  if (!modelInfo?.isCompatible) {
     return null
   }
 
@@ -100,42 +121,52 @@ const ModelLoader = () => {
 
     <div className="flex items-center justify-between space-x-4">
      <div className="flex items-center space-x-2">
-        <span className="text-xs text-gray-600 font-medium">
-          Quantization:
-        </span>
-        <div className="relative">
-          <select
-            value={selectedQuantization || ''}
-            onChange={(e) =>
-              setSelectedQuantization(e.target.value as QuantizationType)
-            }
-            className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
-          >
-            <option value="">Select quantization</option>
-            {modelInfo.supportedQuantizations.map((quant) => (
-              <option key={quant} value={quant}>
-                {quant}
-              </option>
-            ))}
-          </select>
-          <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
-        </div>
+        {modelInfo.supportedQuantizations.length > 1 ? (
+          <>
+            <span className="text-xs text-gray-600 font-medium">
+              Quantization:
+            </span>
+
+            <div className="relative">
+              <select
+                value={selectedQuantization || ''}
+                onChange={(e) =>
+                  setSelectedQuantization(e.target.value as QuantizationType)
+                }
+                className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
+              >
+                <option value="">Select quantization</option>
+                {modelInfo.supportedQuantizations.map((quant) => (
+                  <option key={quant} value={quant}>
+                    {quant}
+                  </option>
+                ))}
+              </select>
+              <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
+            </div>
+          </>
+        ) : (
+          <span className="text-xs text-gray-600 font-medium white-space-break-spaces">
+            No quantization available. Using fp32
+          </span>
+        )}
       </div>
 
       {selectedQuantization && (
         <div className="flex justify-center">
           <button
             className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
-            disabled={(busy && !ready) || !selectedQuantization || ready}
+            disabled={hasBeenLoaded}
             onClick={loadModel}
           >
-            {status === 'loading' && (
+            {status === 'loading' && !hasBeenLoaded ? (
               <>
                 <Loader className="animate-spin h-4 w-4" />
                 <span>{progress.toFixed(0)}%</span>
               </>
+            ) : (
+              <span>{!hasBeenLoaded ? 'Load Model' : 'Model Ready'}</span>
             )}
-            {!ready && !busy ? <span>Load Model</span> : !ready ? null : <span>Model Ready</span>}
          </button>
        </div>
      )}
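Note: the handler above now distinguishes four statuses received from the worker. A sketch of the message shape it relies on, inferred from usage (the real WorkerMessage type lives in src/types.ts):

    type WorkerStatus = 'loading' | 'ready' | 'output' | 'error'
    interface WorkerMessage {
      status: WorkerStatus
      // progress payload while loading, classification result on 'output',
      // error details on 'error'
      output?: any
    }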
src/components/ModelSelector.tsx CHANGED

@@ -22,8 +22,15 @@ import {
 
 type SortOption = 'likes' | 'downloads' | 'createdAt' | 'name'
 
-function ModelSelector({ isFetching }: { isFetching: boolean }) {
-  const { models, setModelInfo, modelInfo, pipeline } = useModel()
+function ModelSelector() {
+  const {
+    models,
+    setModelInfo,
+    modelInfo,
+    pipeline,
+    isFetching,
+    setIsFetching
+  } = useModel()
   const [sortBy, setSortBy] = useState<SortOption>('downloads')
   const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc')
   const [showCustomInput, setShowCustomInput] = useState(false)
@@ -102,25 +109,36 @@ function ModelSelector({ isFetching }: { isFetching: boolean }) {
         baseId: modelInfoResponse.baseId,
         readme: modelInfoResponse.readme
       }
+
+      console.log('Fetched model info:', modelInfoResponse)
+
       setModelInfo(modelInfo)
       setIsCustomModel(isCustom)
+      setIsFetching(false)
     } catch (error) {
       console.error('Error fetching model info:', error)
+      setIsFetching(false)
       throw error
     }
   },
-  [setModelInfo, pipeline]
+  [setModelInfo, pipeline, setIsFetching]
 )
 
-  // Update modelInfo to first model when pipeline changes
+  // Reset custom model state when pipeline changes
   useEffect(() => {
-    if (isFetching) return
+    setIsCustomModel(false)
+    setShowCustomInput(false)
+    setCustomModelName('')
+    setCustomModelError('')
+  }, [pipeline])
 
-    if (models.length > 0 && !isCustomModel) {
-      const firstModel = models[0]
+  // Update modelInfo to first model when models are loaded and no custom model is selected
+  useEffect(() => {
+    if (models.length > 0 && !isCustomModel && !modelInfo) {
+      const firstModel = sortedModels[0]
       fetchAndSetModelInfo(firstModel.id, false)
     }
-  }, [pipeline, models, fetchAndSetModelInfo, isCustomModel, isFetching])
+  }, [models, sortedModels, fetchAndSetModelInfo, isCustomModel, modelInfo])
 
   const handleModelSelect = (modelId: string) => {
     fetchAndSetModelInfo(modelId, false)
@@ -160,8 +178,8 @@ function ModelSelector({ isFetching }: { isFetching: boolean }) {
   const handleRemoveCustomModel = () => {
     setIsCustomModel(false)
     // Load the first model from the list
-    if (models.length > 0) {
-      fetchAndSetModelInfo(models[0].id, false)
+    if (sortedModels.length > 0) {
+      fetchAndSetModelInfo(sortedModels[0].id, false)
     }
   }
 
@@ -226,7 +244,7 @@ function ModelSelector({ isFetching }: { isFetching: boolean }) {
     )
   }
 
-  if (isFetching) {
+  if (isFetching || models.length === 0) {
    return (
      <div className="relative">
        <div className="w-full px-3 py-2 border border-gray-300 rounded-md bg-white flex items-center justify-between animate-pulse h-10">
src/components/TextClassification.tsx CHANGED

@@ -22,52 +22,23 @@ const PLACEHOLDER_TEXTS: string[] = [
 
 function TextClassification() {
   const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
-  const [results, setResults] = useState<ClassificationOutput[]>([])
-  const { status, setStatus, modelInfo } = useModel()
-  const workerRef = useRef<Worker | null>(null)
+  const { activeWorker, status, setStatus, modelInfo, results, setResults, hasBeenLoaded} = useModel()
 
 
-  // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
-  useEffect(() => {
-    if (!workerRef.current) {
-      workerRef.current = getWorker('text-classification')
-    }
-
-    // Create a callback function for messages from the worker thread.
-    const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
-      const status = e.data.status
-      if (status === 'ready') {
-        setStatus('ready')
-      } else if (status === 'output') {
-        setStatus('output')
-        const result = e.data.output!
-        setResults((prevResults) => [...prevResults, result])
-        console.log(result)
-      } else if (status === 'error') {
-        setStatus('error')
-        console.error(e.data.output)
-      }
-    }
-
-    // Attach the callback function as an event listener.
-    workerRef.current?.addEventListener('message', onMessageReceived)
-
-    // Define a cleanup function for when the component is unmounted.
-    return () =>
-      workerRef.current?.removeEventListener('message', onMessageReceived)
-  }, [setStatus])
 
   const classify = useCallback(() => {
-    if (!modelInfo) return
-    setStatus('loading')
+    if (!modelInfo || !activeWorker) {
+      console.error('Model info or worker is not available')
+      return
+    }
     setResults([]) // Clear previous results
     const message: TextClassificationWorkerInput = {
       type: 'classify',
       text,
       model: modelInfo.id
     }
-    workerRef.current?.postMessage(message)
-  }, [text, modelInfo, setStatus])
+    activeWorker.postMessage(message)
+  }, [text, modelInfo, setStatus, activeWorker])
 
   const busy: boolean = status !== 'ready'
 
@@ -96,8 +67,7 @@ function TextClassification() {
           disabled={busy}
           onClick={classify}
         >
-          {status === 'ready'
-            ? !busy
+          {hasBeenLoaded ? !busy
            ? 'Classify Text'
            : 'Processing...'
            : 'Load model first'}
src/contexts/ModelContext.tsx CHANGED

@@ -26,6 +26,12 @@ interface ModelContextType {
   setSelectedQuantization: (quantization: QuantizationType) => void
   activeWorker: Worker | null
   setActiveWorker: (worker: Worker | null) => void
+  isFetching: boolean
+  setIsFetching: (isFetching: boolean) => void
+  results: any[]
+  setResults: React.Dispatch<React.SetStateAction<any[]>>
+  hasBeenLoaded: boolean
+  setHasBeenLoaded: (hasBeenLoaded: boolean) => void
 }
 
 const ModelContext = createContext<ModelContextType | undefined>(undefined)
@@ -41,6 +47,10 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
   const [selectedQuantization, setSelectedQuantization] =
     useState<QuantizationType>('int8')
   const [activeWorker, setActiveWorker] = useState<Worker | null>(null)
+  const [isFetching, setIsFetching] = useState(false)
+  const [results, setResults] = useState<any[]>([])
+  const [hasBeenLoaded, setHasBeenLoaded] = useState(false)
+
 
   // set progress to 0 when model is changed
   useEffect(() => {
@@ -63,7 +73,13 @@
         selectedQuantization,
         setSelectedQuantization,
         activeWorker,
-        setActiveWorker
+        setActiveWorker,
+        isFetching,
+        setIsFetching,
+        results,
+        setResults,
+        hasBeenLoaded,
+        setHasBeenLoaded
      }}
    >
      {children}
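Note: with the new fields exposed, any component can read the shared load/fetch state through the existing useModel() hook, e.g. (Skeleton here is a hypothetical stand-in component):

    const { isFetching, hasBeenLoaded, results } = useModel()
    if (isFetching) return <Skeleton />
    const label = hasBeenLoaded ? 'Model Ready' : 'Load Model'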
src/lib/huggingface.ts CHANGED

@@ -1,7 +1,10 @@
-import { supportedPipelines } from "../components/PipelineSelector"
-import { ModelInfoResponse, QuantizationType } from "../types"
+import { supportedPipelines } from '../components/PipelineSelector'
+import { ModelInfoResponse, QuantizationType } from '../types'
 
-const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelInfoResponse> => {
+const getModelInfo = async (
+  modelName: string,
+  pipeline: string
+): Promise<ModelInfoResponse> => {
   const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
 
   if (!token) {
@@ -23,36 +26,53 @@ const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelI
   if (!response.ok) {
     throw new Error(`Failed to fetch model info: ${response.statusText}`)
   }
 
   const modelData: ModelInfoResponse = await response.json()
 
   const requiredFiles = [
     'config.json',
     'tokenizer.json',
-    'tokenizer_config.json',
+    'tokenizer_config.json'
   ]
 
-  const siblingFiles = modelData.siblings?.map(s => s.rfilename) || []
-  const missingFiles = requiredFiles.filter(file => !siblingFiles.includes(file))
-  const hasOnnxFolder = siblingFiles.some((file) => file.endsWith('.onnx') && file.startsWith('onnx/'))
-
-  const isCompatible = missingFiles.length === 0 && hasOnnxFolder && modelData.tags.includes(pipeline)
-
+  const siblingFiles = modelData.siblings?.map((s) => s.rfilename) || []
+  const missingFiles = requiredFiles.filter(
+    (file) => !siblingFiles.includes(file)
+  )
+  const hasOnnxFolder = siblingFiles.some(
+    (file) => file.endsWith('.onnx') && file.startsWith('onnx/')
+  )
+
+  const isCompatible =
+    missingFiles.length === 0 &&
+    hasOnnxFolder &&
+    modelData.tags.includes(pipeline)
 
   let incompatibilityReason = ''
   if (!modelData.tags.includes(pipeline)) {
-    const expectedPipelines = modelData.tags.filter(tag => supportedPipelines.includes(tag)).join(', ')
-    incompatibilityReason = expectedPipelines ? `- Model can be used with ${expectedPipelines} pipelines only\n` : `- Pipeline ${pipeline} not supported by the model\n`
-  } if (missingFiles.length > 0) {
-    incompatibilityReason += `- Missing required files: ${missingFiles.join(', ')}\n`
-  } else if (!hasOnnxFolder) {
+    const expectedPipelines = modelData.tags
+      .filter((tag) => supportedPipelines.includes(tag))
+      .join(', ')
+    incompatibilityReason = expectedPipelines
+      ? `- Model can be used with ${expectedPipelines} pipelines only\n`
+      : `- Pipeline ${pipeline} not supported by the model\n`
+  }
+  if (missingFiles.length > 0) {
+    incompatibilityReason += `- Missing required files: ${missingFiles.join(
+      ', '
+    )}\n`
+  } else if (!hasOnnxFolder) {
     incompatibilityReason += '- Folder onnx/ is missing\n'
   }
-  const supportedQuantizations = siblingFiles
-    .filter((file) => file.endsWith('.onnx') && file.includes('_'))
-    .map((file) => file.split('/')[1].split('_')[1].split('.')[0])
-    .filter((q) => q !== 'quantized')
-  const uniqueSupportedQuantizations = Array.from(new Set(supportedQuantizations))
+  const supportedQuantizations = hasOnnxFolder
+    ? siblingFiles
+        .filter((file) => file.endsWith('.onnx') && file.includes('_'))
+        .map((file) => file.split('/')[1].split('_')[1].split('.')[0])
+        .filter((q) => q !== 'quantized')
+    : []
+  const uniqueSupportedQuantizations = Array.from(
+    new Set(supportedQuantizations)
+  )
   uniqueSupportedQuantizations.sort((a, b) => {
     const getNumericValue = (str: string) => {
       const match = str.match(/(\d+)/)
@@ -64,7 +84,9 @@ const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelI
   // Fetch README content
   const fetchReadme = async (modelId: string): Promise<string> => {
     try {
-      const readmeResponse = await fetch(`https://huggingface.co/${modelId}/raw/main/README.md`)
+      const readmeResponse = await fetch(
+        `https://huggingface.co/${modelId}/raw/main/README.md`
+      )
       if (readmeResponse.ok) {
         return await readmeResponse.text()
       }
@@ -74,7 +96,7 @@ const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelI
     return ''
   }
 
   const baseModel = modelData.cardData?.base_model ?? modelData.modelId
   if (baseModel && !modelData.safetensors) {
     const baseModelResponse = await fetch(
       `https://huggingface.co/api/models/${baseModel}`,
@@ -89,21 +111,22 @@ const getModelInfo = async (modelName: string, pipeline: string): Promise<ModelI
     if (baseModelResponse.ok) {
       const baseModelData: ModelInfoResponse = await baseModelResponse.json()
       const readme = await fetchReadme(baseModel)
 
       return {
         ...baseModelData,
         id: modelData.id,
         baseId: baseModel,
         isCompatible,
         incompatibilityReason,
-        supportedQuantizations: uniqueSupportedQuantizations as QuantizationType[],
+        supportedQuantizations:
+          uniqueSupportedQuantizations as QuantizationType[],
         readme
       }
     }
   }
 
   const readme = await fetchReadme(modelData.id)
 
   return {
     ...modelData,
     isCompatible,
@@ -135,7 +158,9 @@ const getModelsByPipeline = async (
     }
   )
   if (!response1.ok) {
-    throw new Error(`Failed to fetch models for pipeline: ${response1.statusText}`)
+    throw new Error(
+      `Failed to fetch models for pipeline: ${response1.statusText}`
+    )
   }
   const models1 = await response1.json()
 
@@ -150,14 +175,18 @@
     }
   )
   if (!response2.ok) {
-    throw new Error(`Failed to fetch models for pipeline: ${response2.statusText}`)
+    throw new Error(
+      `Failed to fetch models for pipeline: ${response2.statusText}`
+    )
   }
   const models2 = await response2.json()
 
   // Combine and deduplicate models based on id
-  const combinedModels = [...models1, ...models2].filter((m: ModelInfoResponse) => m.createdAt > '2022/02/03')
-  const uniqueModels = combinedModels.filter((model, index, self) =>
-    index === self.findIndex(m => m.id === model.id)
+  const combinedModels = [...models1, ...models2].filter(
+    (m: ModelInfoResponse) => m.createdAt > '2022/02/03'
+  )
+  const uniqueModels = combinedModels.filter(
+    (model, index, self) => index === self.findIndex((m) => m.id === model.id)
   )
 
   if (pipelineTag === 'text-classification') {
@@ -171,11 +200,10 @@
       )
       .slice(0, 20)
   }
 
   return uniqueModels.slice(0, 20)
 }
 
-
 const getModelsByPipelineCustom = async (
   searchString: string,
   pipelineTag: string
@@ -197,12 +225,16 @@
     }
   )
 
   if (!response.ok) {
-    throw new Error(`Failed to fetch models for pipeline: ${response.statusText}`)
+    throw new Error(
+      `Failed to fetch models for pipeline: ${response.statusText}`
+    )
   }
   const models = await response.json()
 
-  const uniqueModels = models.filter((m: ModelInfoResponse) => m.createdAt > '2022/02/03')
+  const uniqueModels = models.filter(
+    (m: ModelInfoResponse) => m.createdAt > '2022/02/03'
+  )
   if (pipelineTag === 'text-classification') {
     return uniqueModels
       .filter(
@@ -214,7 +246,7 @@
       )
       .slice(0, 20)
   }
 
   return uniqueModels.slice(0, 20)
 }
 
@@ -239,9 +271,10 @@ function getModelSize(
       bytesPerParameter = 1
       break
     case 'bnb4':
     case 'q4':
+    case 'q4f16':
       bytesPerParameter = 0.5
       break
   }
 
   const sizeInBytes = parameters * bytesPerParameter
@@ -250,5 +283,9 @@
   return sizeInMB
 }
 
-
-export { getModelInfo, getModelSize, getModelsByPipeline, getModelsByPipelineCustom }
+export {
+  getModelInfo,
+  getModelSize,
+  getModelsByPipeline,
+  getModelsByPipelineCustom
+}
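Note: supported quantizations are still derived from ONNX filenames of the form onnx/model_&lt;quant&gt;.onnx; the split chain above reduces to a helper like this (extracted here purely for illustration):

    const parseQuant = (file: string) =>
      file.split('/')[1].split('_')[1].split('.')[0]

    parseQuant('onnx/model_q4f16.onnx')     // -> 'q4f16'
    parseQuant('onnx/model_quantized.onnx') // -> 'quantized' (filtered out afterwards)
    // the new hasOnnxFolder guard short-circuits to [] when there is no onnx/
    // folder, instead of running the split chain on unexpected paths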
src/types.ts CHANGED

@@ -32,7 +32,7 @@ export interface TextClassificationWorkerInput {
 
 
 type q8 = 'q8' | 'int8' | 'bnb8' | 'uint8'
-type q4 = 'q4' | 'bnb4'
+type q4 = 'q4' | 'bnb4' | 'q4f16'
 type fp16 = 'fp16'
 type fp32 = 'fp32'
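Note: together with the getModelSize() change above, the new q4f16 member is estimated at half a byte per parameter. A rough worked example under that rule (the MB conversion is assumed to be bytes / 1024²):

    const parameters = 100_000_000 // hypothetical 100M-parameter model
    const bytesPerParameter = 0.5  // 'bnb4' | 'q4' | 'q4f16' branch
    const sizeInMB = (parameters * bytesPerParameter) / (1024 * 1024) // ≈ 47.7 MB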