wip: enhance model selection and display across components
Browse files- src/App.tsx +23 -12
- src/components/ModelSelector.tsx +177 -22
- src/components/PipelineSelector.tsx +4 -0
- src/components/TextClassification.tsx +10 -10
- src/components/ZeroShotClassification.tsx +4 -2
- src/contexts/ModelContext.tsx +12 -2
- src/lib/huggingface.ts +34 -25
- src/types.ts +29 -0
src/App.tsx
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import { useState } from 'react'
|
2 |
import PipelineSelector from './components/PipelineSelector'
|
3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
4 |
import TextClassification from './components/TextClassification'
|
@@ -6,11 +6,11 @@ import Header from './Header'
|
|
6 |
import Footer from './Footer'
|
7 |
import { useModel } from './contexts/ModelContext'
|
8 |
import { Bot, Heart, Download, Cpu, DatabaseIcon } from 'lucide-react'
|
9 |
-
import { getModelSize } from './lib/huggingface'
|
|
|
10 |
|
11 |
function App() {
|
12 |
-
const
|
13 |
-
const { progress, status, modelInfo } = useModel()
|
14 |
|
15 |
const formatNumber = (num: number) => {
|
16 |
if (num >= 1000000000) {
|
@@ -23,6 +23,14 @@ function App() {
|
|
23 |
return num.toString()
|
24 |
}
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
return (
|
27 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
28 |
<Header />
|
@@ -32,10 +40,6 @@ function App() {
|
|
32 |
<div className="mb-8">
|
33 |
<div className="bg-white rounded-lg shadow-sm border p-6">
|
34 |
<div className="flex items-center justify-between mb-4">
|
35 |
-
<h2 className="text-lg font-semibold text-gray-900">
|
36 |
-
Choose a Pipeline
|
37 |
-
</h2>
|
38 |
-
|
39 |
{/* Model Info Display */}
|
40 |
{modelInfo.name && (
|
41 |
<div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-2">
|
@@ -77,9 +81,10 @@ function App() {
|
|
77 |
<div className="flex items-center space-x-1">
|
78 |
<DatabaseIcon className="w-3 h-3 text-purple-500" />
|
79 |
<span>
|
80 |
-
{`~${getModelSize(
|
81 |
-
|
82 |
-
|
|
|
83 |
</span>
|
84 |
</div>
|
85 |
)}
|
@@ -88,7 +93,13 @@ function App() {
|
|
88 |
)}
|
89 |
</div>
|
90 |
|
91 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
{/* Model Loading Progress */}
|
94 |
{status === 'progress' && (
|
|
|
1 |
+
import { useEffect, useState } from 'react'
|
2 |
import PipelineSelector from './components/PipelineSelector'
|
3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
4 |
import TextClassification from './components/TextClassification'
|
|
|
6 |
import Footer from './Footer'
|
7 |
import { useModel } from './contexts/ModelContext'
|
8 |
import { Bot, Heart, Download, Cpu, DatabaseIcon } from 'lucide-react'
|
9 |
+
import { getModelsByPipeline, getModelSize } from './lib/huggingface'
|
10 |
+
import ModelSelector from './components/ModelSelector'
|
11 |
|
12 |
function App() {
|
13 |
+
const { pipeline, setPipeline, progress, status, modelInfo, setModels } = useModel()
|
|
|
14 |
|
15 |
const formatNumber = (num: number) => {
|
16 |
if (num >= 1000000000) {
|
|
|
23 |
return num.toString()
|
24 |
}
|
25 |
|
26 |
+
useEffect(() => {
|
27 |
+
const fetchModels = async () => {
|
28 |
+
const fetchedModels = await getModelsByPipeline(pipeline);
|
29 |
+
setModels(fetchedModels);
|
30 |
+
};
|
31 |
+
fetchModels();
|
32 |
+
}, [setModels, pipeline]);
|
33 |
+
|
34 |
return (
|
35 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
36 |
<Header />
|
|
|
40 |
<div className="mb-8">
|
41 |
<div className="bg-white rounded-lg shadow-sm border p-6">
|
42 |
<div className="flex items-center justify-between mb-4">
|
|
|
|
|
|
|
|
|
43 |
{/* Model Info Display */}
|
44 |
{modelInfo.name && (
|
45 |
<div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-2">
|
|
|
81 |
<div className="flex items-center space-x-1">
|
82 |
<DatabaseIcon className="w-3 h-3 text-purple-500" />
|
83 |
<span>
|
84 |
+
{`~${getModelSize(
|
85 |
+
modelInfo.parameters,
|
86 |
+
'INT8'
|
87 |
+
).toFixed(1)}MB`}
|
88 |
</span>
|
89 |
</div>
|
90 |
)}
|
|
|
93 |
)}
|
94 |
</div>
|
95 |
|
96 |
+
<div className="flex flex-row items-center space-x-4">
|
97 |
+
<span className="text-lg font-semibold text-gray-900">
|
98 |
+
Choose a Pipeline
|
99 |
+
</span>
|
100 |
+
<PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
|
101 |
+
</div>
|
102 |
+
<ModelSelector />
|
103 |
|
104 |
{/* Model Loading Progress */}
|
105 |
{status === 'progress' && (
|
src/components/ModelSelector.tsx
CHANGED
@@ -1,25 +1,180 @@
|
|
1 |
-
import React from 'react'
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
const ModelSelector: React.FC<ModelSelectorProps> = ({
|
10 |
-
model,
|
11 |
-
setModel,
|
12 |
-
models
|
13 |
-
}) => {
|
14 |
return (
|
15 |
-
<
|
16 |
-
{
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import React, { useEffect, useState } from 'react'
|
2 |
+
import { useModel } from '../contexts/ModelContext'
|
3 |
+
import { getModelInfo } from '../lib/huggingface'
|
4 |
+
import { Heart, Download, ChevronDown } from 'lucide-react'
|
5 |
|
6 |
+
const ModelSelector: React.FC = () => {
|
7 |
+
const { models, setModelInfo, modelInfo } = useModel()
|
8 |
+
const [isOpen, setIsOpen] = useState(false)
|
9 |
+
const [modelStats, setModelStats] = useState<
|
10 |
+
Record<string, { likes: number; downloads: number; createdAt: string }>
|
11 |
+
>({})
|
12 |
+
|
13 |
+
const formatNumber = (num: number) => {
|
14 |
+
if (num >= 1000000000) {
|
15 |
+
return (num / 1000000000).toFixed(1) + 'B'
|
16 |
+
} else if (num >= 1000000) {
|
17 |
+
return (num / 1000000).toFixed(1) + 'M'
|
18 |
+
} else if (num >= 1000) {
|
19 |
+
return (num / 1000).toFixed(1) + 'K'
|
20 |
+
}
|
21 |
+
return num.toString()
|
22 |
+
}
|
23 |
+
|
24 |
+
// Separate function to fetch only stats without updating selected model
|
25 |
+
const fetchModelStats = async (modelId: string) => {
|
26 |
+
try {
|
27 |
+
const modelInfoResponse = await getModelInfo(modelId)
|
28 |
+
|
29 |
+
setModelStats((prev) => ({
|
30 |
+
...prev,
|
31 |
+
[modelId]: {
|
32 |
+
likes: modelInfoResponse.likes || 0,
|
33 |
+
downloads: modelInfoResponse.downloads || 0,
|
34 |
+
createdAt: modelInfoResponse.createdAt || ''
|
35 |
+
}
|
36 |
+
}))
|
37 |
+
} catch (error) {
|
38 |
+
console.error('Error fetching model stats:', error)
|
39 |
+
}
|
40 |
+
}
|
41 |
+
|
42 |
+
// Function to fetch full model info and set as selected
|
43 |
+
const fetchModelAndSetInfo = async (modelId: string) => {
|
44 |
+
try {
|
45 |
+
const modelInfoResponse = await getModelInfo(modelId)
|
46 |
+
let parameters = 0
|
47 |
+
if (modelInfoResponse.safetensors) {
|
48 |
+
const safetensors = modelInfoResponse.safetensors
|
49 |
+
parameters =
|
50 |
+
safetensors.parameters.F16 ||
|
51 |
+
safetensors.parameters.F32 ||
|
52 |
+
safetensors.parameters.total ||
|
53 |
+
0
|
54 |
+
}
|
55 |
+
|
56 |
+
// Transform ModelInfoResponse to ModelInfo
|
57 |
+
const modelInfo = {
|
58 |
+
id: modelId,
|
59 |
+
name: modelInfoResponse.id || modelId,
|
60 |
+
architecture: modelInfoResponse.config?.architectures?.[0] || 'Unknown',
|
61 |
+
parameters,
|
62 |
+
likes: modelInfoResponse.likes || 0,
|
63 |
+
downloads: modelInfoResponse.downloads || 0,
|
64 |
+
createdAt: modelInfoResponse.createdAt || ''
|
65 |
+
}
|
66 |
+
|
67 |
+
// Also update stats
|
68 |
+
setModelStats((prev) => ({
|
69 |
+
...prev,
|
70 |
+
[modelId]: {
|
71 |
+
likes: modelInfoResponse.likes || 0,
|
72 |
+
downloads: modelInfoResponse.downloads || 0,
|
73 |
+
createdAt: modelInfoResponse.createdAt || ''
|
74 |
+
}
|
75 |
+
}))
|
76 |
+
|
77 |
+
console.log(modelInfoResponse)
|
78 |
+
|
79 |
+
setModelInfo(modelInfo)
|
80 |
+
} catch (error) {
|
81 |
+
console.error('Error fetching model info:', error)
|
82 |
+
}
|
83 |
+
}
|
84 |
+
|
85 |
+
// Fetch stats for all models when component mounts (without setting as selected)
|
86 |
+
useEffect(() => {
|
87 |
+
models.forEach((model) => {
|
88 |
+
if (!modelStats[model.id]) {
|
89 |
+
fetchModelStats(model.id)
|
90 |
+
}
|
91 |
+
})
|
92 |
+
}, [models])
|
93 |
+
|
94 |
+
// Only fetch full info when a model is actually selected
|
95 |
+
useEffect(() => {
|
96 |
+
if (!modelInfo.id) return
|
97 |
+
// Only fetch if we don't already have the full info
|
98 |
+
if (!modelStats[modelInfo.id]) {
|
99 |
+
fetchModelAndSetInfo(modelInfo.id)
|
100 |
+
}
|
101 |
+
}, [modelInfo.id])
|
102 |
+
|
103 |
+
const handleModelSelect = (modelId: string) => {
|
104 |
+
fetchModelAndSetInfo(modelId)
|
105 |
+
setIsOpen(false)
|
106 |
+
}
|
107 |
|
|
|
|
|
|
|
|
|
|
|
108 |
return (
|
109 |
+
<div className="relative">
|
110 |
+
{/* Custom Dropdown Button */}
|
111 |
+
<button
|
112 |
+
onClick={() => setIsOpen(!isOpen)}
|
113 |
+
className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent bg-white text-left flex items-center justify-between"
|
114 |
+
>
|
115 |
+
<span className="truncate">{modelInfo.id || 'Select a model'}</span>
|
116 |
+
<ChevronDown
|
117 |
+
className={`w-4 h-4 transition-transform ${
|
118 |
+
isOpen ? 'rotate-180' : ''
|
119 |
+
}`}
|
120 |
+
/>
|
121 |
+
</button>
|
122 |
+
|
123 |
+
{/* Custom Dropdown Options */}
|
124 |
+
{isOpen && (
|
125 |
+
<div className="absolute z-10 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg max-h-60 overflow-auto">
|
126 |
+
{models.map((model) => (
|
127 |
+
<div
|
128 |
+
key={model.id}
|
129 |
+
onClick={() => handleModelSelect(model.id)}
|
130 |
+
className="px-3 py-2 hover:bg-gray-50 cursor-pointer border-b border-gray-100 last:border-b-0"
|
131 |
+
>
|
132 |
+
<div className="flex items-center justify-between">
|
133 |
+
<span className="text-sm font-medium truncate flex-1 mr-2">
|
134 |
+
{model.id}
|
135 |
+
</span>
|
136 |
+
|
137 |
+
{/* Stats Display */}
|
138 |
+
{modelStats[model.id] &&
|
139 |
+
(modelStats[model.id].likes > 0 ||
|
140 |
+
modelStats[model.id].downloads > 0) && (
|
141 |
+
<div className="flex items-center space-x-3 text-xs text-gray-500 flex-shrink-0">
|
142 |
+
{modelStats[model.id].likes > 0 && (
|
143 |
+
<div className="flex items-center space-x-1">
|
144 |
+
<Heart className="w-3 h-3 text-red-500" />
|
145 |
+
<span>
|
146 |
+
{formatNumber(modelStats[model.id].likes)}
|
147 |
+
</span>
|
148 |
+
</div>
|
149 |
+
)}
|
150 |
+
|
151 |
+
{modelStats[model.id].downloads > 0 && (
|
152 |
+
<div className="flex items-center space-x-1">
|
153 |
+
<Download className="w-3 h-3 text-green-500" />
|
154 |
+
<span>
|
155 |
+
{formatNumber(modelStats[model.id].downloads)}
|
156 |
+
</span>
|
157 |
+
</div>
|
158 |
+
)}
|
159 |
+
{modelStats[model.id].createdAt !== '' && (
|
160 |
+
<span className="text-xs text-gray-400">
|
161 |
+
{modelStats[model.id].createdAt.split('T')[0]}
|
162 |
+
</span>
|
163 |
+
)}
|
164 |
+
</div>
|
165 |
+
)}
|
166 |
+
</div>
|
167 |
+
</div>
|
168 |
+
))}
|
169 |
+
</div>
|
170 |
+
)}
|
171 |
+
|
172 |
+
{/* Click outside to close */}
|
173 |
+
{isOpen && (
|
174 |
+
<div className="fixed inset-0 z-0" onClick={() => setIsOpen(false)} />
|
175 |
+
)}
|
176 |
+
</div>
|
177 |
+
)
|
178 |
+
}
|
179 |
+
|
180 |
+
export default ModelSelector
|
src/components/PipelineSelector.tsx
CHANGED
@@ -3,6 +3,10 @@ import React from 'react';
|
|
3 |
const pipelines = [
|
4 |
'zero-shot-classification',
|
5 |
'text-classification',
|
|
|
|
|
|
|
|
|
6 |
'image-classification',
|
7 |
'question-answering',
|
8 |
'translation'
|
|
|
3 |
const pipelines = [
|
4 |
'zero-shot-classification',
|
5 |
'text-classification',
|
6 |
+
'text-generation',
|
7 |
+
'summarization',
|
8 |
+
'feature-extraction',
|
9 |
+
'sentiment-analysis',
|
10 |
'image-classification',
|
11 |
'question-answering',
|
12 |
'translation'
|
src/components/TextClassification.tsx
CHANGED
@@ -3,7 +3,6 @@ import {
|
|
3 |
ClassificationOutput,
|
4 |
TextClassificationWorkerInput,
|
5 |
WorkerMessage,
|
6 |
-
ModelInfo
|
7 |
} from '../types';
|
8 |
import { useModel } from '../contexts/ModelContext';
|
9 |
import { getModelInfo } from '../lib/huggingface';
|
@@ -25,13 +24,14 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
25 |
function TextClassification() {
|
26 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
27 |
const [results, setResults] = useState<ClassificationOutput[]>([])
|
28 |
-
const { setProgress, status, setStatus, modelInfo, setModelInfo} = useModel()
|
|
|
|
|
29 |
useEffect(() => {
|
30 |
-
|
31 |
const fetchModelInfo = async () => {
|
32 |
try {
|
33 |
-
const modelInfoResponse = await getModelInfo(
|
34 |
-
console.log(modelInfoResponse)
|
35 |
let parameters = 0
|
36 |
if (modelInfoResponse.safetensors) {
|
37 |
const safetensors = modelInfoResponse.safetensors
|
@@ -42,8 +42,8 @@ function TextClassification() {
|
|
42 |
0)
|
43 |
}
|
44 |
setModelInfo({
|
45 |
-
|
46 |
-
architecture: modelInfoResponse.config
|
47 |
parameters,
|
48 |
likes: modelInfoResponse.likes,
|
49 |
downloads: modelInfoResponse.downloads
|
@@ -54,7 +54,7 @@ function TextClassification() {
|
|
54 |
}
|
55 |
|
56 |
fetchModelInfo()
|
57 |
-
}, [setModelInfo])
|
58 |
|
59 |
// Create a reference to the worker object.
|
60 |
const worker = useRef<Worker | null>(null)
|
@@ -110,9 +110,9 @@ function TextClassification() {
|
|
110 |
const classify = useCallback(() => {
|
111 |
setStatus('processing')
|
112 |
setResults([]) // Clear previous results
|
113 |
-
const message: TextClassificationWorkerInput = { text, model: modelInfo.
|
114 |
worker.current?.postMessage(message)
|
115 |
-
}, [text, modelInfo.
|
116 |
|
117 |
const busy: boolean = status !== 'idle'
|
118 |
|
|
|
3 |
ClassificationOutput,
|
4 |
TextClassificationWorkerInput,
|
5 |
WorkerMessage,
|
|
|
6 |
} from '../types';
|
7 |
import { useModel } from '../contexts/ModelContext';
|
8 |
import { getModelInfo } from '../lib/huggingface';
|
|
|
24 |
function TextClassification() {
|
25 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
26 |
const [results, setResults] = useState<ClassificationOutput[]>([])
|
27 |
+
const { setProgress, status, setStatus, modelInfo, setModelInfo, models, setModels} = useModel()
|
28 |
+
|
29 |
+
|
30 |
useEffect(() => {
|
31 |
+
if (!modelInfo.id) return;
|
32 |
const fetchModelInfo = async () => {
|
33 |
try {
|
34 |
+
const modelInfoResponse = await getModelInfo(modelInfo.id)
|
|
|
35 |
let parameters = 0
|
36 |
if (modelInfoResponse.safetensors) {
|
37 |
const safetensors = modelInfoResponse.safetensors
|
|
|
42 |
0)
|
43 |
}
|
44 |
setModelInfo({
|
45 |
+
...modelInfo,
|
46 |
+
architecture: modelInfoResponse.config?.architectures[0] ?? '',
|
47 |
parameters,
|
48 |
likes: modelInfoResponse.likes,
|
49 |
downloads: modelInfoResponse.downloads
|
|
|
54 |
}
|
55 |
|
56 |
fetchModelInfo()
|
57 |
+
}, [modelInfo.id, setModelInfo])
|
58 |
|
59 |
// Create a reference to the worker object.
|
60 |
const worker = useRef<Worker | null>(null)
|
|
|
110 |
const classify = useCallback(() => {
|
111 |
setStatus('processing')
|
112 |
setResults([]) // Clear previous results
|
113 |
+
const message: TextClassificationWorkerInput = { text, model: modelInfo.id }
|
114 |
worker.current?.postMessage(message)
|
115 |
+
}, [text, modelInfo.id])
|
116 |
|
117 |
const busy: boolean = status !== 'idle'
|
118 |
|
src/components/ZeroShotClassification.tsx
CHANGED
@@ -68,11 +68,13 @@ function ZeroShotClassification() {
|
|
68 |
0
|
69 |
}
|
70 |
setModelInfo({
|
|
|
71 |
name: modelName,
|
72 |
-
architecture: modelInfoResponse.config
|
73 |
parameters,
|
74 |
likes: modelInfoResponse.likes,
|
75 |
-
downloads: modelInfoResponse.downloads
|
|
|
76 |
})
|
77 |
} catch (error) {
|
78 |
console.error('Error fetching model info:', error)
|
|
|
68 |
0
|
69 |
}
|
70 |
setModelInfo({
|
71 |
+
id: modelInfoResponse.id,
|
72 |
name: modelName,
|
73 |
+
architecture: modelInfoResponse.config?.architectures[0] ?? '',
|
74 |
parameters,
|
75 |
likes: modelInfoResponse.likes,
|
76 |
+
downloads: modelInfoResponse.downloads,
|
77 |
+
createdAt: modelInfoResponse.createdAt,
|
78 |
})
|
79 |
} catch (error) {
|
80 |
console.error('Error fetching model info:', error)
|
src/contexts/ModelContext.tsx
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import React, { createContext, useContext, useEffect, useState } from 'react'
|
2 |
-
import { ModelInfo } from '../types'
|
3 |
|
4 |
interface ModelContextType {
|
5 |
progress: number
|
@@ -8,6 +8,10 @@ interface ModelContextType {
|
|
8 |
setStatus: (status: string) => void
|
9 |
modelInfo: ModelInfo
|
10 |
setModelInfo: (model: ModelInfo) => void
|
|
|
|
|
|
|
|
|
11 |
}
|
12 |
|
13 |
const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
@@ -16,6 +20,8 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
16 |
const [progress, setProgress] = useState<number>(0)
|
17 |
const [status, setStatus] = useState<string>('idle')
|
18 |
const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
|
|
|
|
|
19 |
|
20 |
// set progress to 0 when model is changed
|
21 |
useEffect(() => {
|
@@ -30,7 +36,11 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
30 |
status,
|
31 |
setStatus,
|
32 |
modelInfo,
|
33 |
-
setModelInfo
|
|
|
|
|
|
|
|
|
34 |
}}
|
35 |
>
|
36 |
{children}
|
|
|
1 |
import React, { createContext, useContext, useEffect, useState } from 'react'
|
2 |
+
import { ModelInfo, ModelInfoResponse } from '../types'
|
3 |
|
4 |
interface ModelContextType {
|
5 |
progress: number
|
|
|
8 |
setStatus: (status: string) => void
|
9 |
modelInfo: ModelInfo
|
10 |
setModelInfo: (model: ModelInfo) => void
|
11 |
+
pipeline: string
|
12 |
+
setPipeline: (pipeline: string) => void
|
13 |
+
models: ModelInfoResponse[]
|
14 |
+
setModels: (models: ModelInfoResponse[]) => void
|
15 |
}
|
16 |
|
17 |
const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
|
|
20 |
const [progress, setProgress] = useState<number>(0)
|
21 |
const [status, setStatus] = useState<string>('idle')
|
22 |
const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
|
23 |
+
const [models, setModels] = useState<ModelInfoResponse[]>([] as ModelInfoResponse[])
|
24 |
+
const [pipeline, setPipeline] = useState<string>('zero-shot-classification')
|
25 |
|
26 |
// set progress to 0 when model is changed
|
27 |
useEffect(() => {
|
|
|
36 |
status,
|
37 |
setStatus,
|
38 |
modelInfo,
|
39 |
+
setModelInfo,
|
40 |
+
models,
|
41 |
+
setModels,
|
42 |
+
pipeline,
|
43 |
+
setPipeline,
|
44 |
}}
|
45 |
>
|
46 |
{children}
|
src/lib/huggingface.ts
CHANGED
@@ -1,27 +1,5 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
config: {
|
4 |
-
architectures: string[]
|
5 |
-
model_type: string
|
6 |
-
}
|
7 |
-
lastModified: string
|
8 |
-
pipeline_tag: string
|
9 |
-
tags: string[]
|
10 |
-
transformersInfo: {
|
11 |
-
pipeline_tag: string
|
12 |
-
auto_model: string
|
13 |
-
processor: string
|
14 |
-
}
|
15 |
-
safetensors?: {
|
16 |
-
parameters: {
|
17 |
-
F16?: number
|
18 |
-
F32?: number
|
19 |
-
total?: number
|
20 |
-
}
|
21 |
-
}
|
22 |
-
likes: number
|
23 |
-
downloads: number
|
24 |
-
}
|
25 |
|
26 |
const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
|
27 |
const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
|
@@ -48,6 +26,37 @@ const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
|
|
48 |
return response.json()
|
49 |
}
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
// Define the possible quantization types for clarity and type safety
|
52 |
type QuantizationType = 'FP32' | 'FP16' | 'INT8' | 'Q4'
|
53 |
function getModelSize(
|
@@ -81,5 +90,5 @@ function getModelSize(
|
|
81 |
}
|
82 |
|
83 |
|
84 |
-
export { getModelInfo, getModelSize }
|
85 |
|
|
|
1 |
+
import { Mode } from "fs"
|
2 |
+
import { ModelInfoResponse } from "../types"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
|
5 |
const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
|
|
|
26 |
return response.json()
|
27 |
}
|
28 |
|
29 |
+
const getModelsByPipeline = async (
|
30 |
+
pipeline_tag: string
|
31 |
+
): Promise<ModelInfoResponse[]> => {
|
32 |
+
const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
|
33 |
+
|
34 |
+
if (!token) {
|
35 |
+
throw new Error(
|
36 |
+
'Hugging Face token not found. Please set REACT_APP_HUGGINGFACE_TOKEN in your .env file'
|
37 |
+
)
|
38 |
+
}
|
39 |
+
|
40 |
+
const response = await fetch(
|
41 |
+
`https://huggingface.co/api/models?filter=${pipeline_tag}&filter=transformers.js&sort=downloads`,
|
42 |
+
{
|
43 |
+
method: 'GET',
|
44 |
+
headers: {
|
45 |
+
Authorization: `Bearer ${token}`
|
46 |
+
}
|
47 |
+
}
|
48 |
+
)
|
49 |
+
|
50 |
+
if (!response.ok) {
|
51 |
+
throw new Error(`Failed to fetch models for pipeline: ${response.statusText}`)
|
52 |
+
}
|
53 |
+
const models = await response.json()
|
54 |
+
if (pipeline_tag === 'text-classification') {
|
55 |
+
return models.filter((model: ModelInfoResponse) => !model.tags.includes('reranker') && !model.id.includes('reranker')).slice(0, 10)
|
56 |
+
}
|
57 |
+
return models.slice(0, 10)
|
58 |
+
}
|
59 |
+
|
60 |
// Define the possible quantization types for clarity and type safety
|
61 |
type QuantizationType = 'FP32' | 'FP16' | 'INT8' | 'Q4'
|
62 |
function getModelSize(
|
|
|
90 |
}
|
91 |
|
92 |
|
93 |
+
export { getModelInfo, getModelSize, getModelsByPipeline }
|
94 |
|
src/types.ts
CHANGED
@@ -28,9 +28,38 @@ export interface TextClassificationWorkerInput {
|
|
28 |
export type AppStatus = 'idle' | 'loading' | 'processing'
|
29 |
|
30 |
export interface ModelInfo {
|
|
|
31 |
name: string
|
32 |
architecture: string
|
33 |
parameters: number
|
34 |
likes: number
|
35 |
downloads: number
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
}
|
|
|
28 |
export type AppStatus = 'idle' | 'loading' | 'processing'
|
29 |
|
30 |
export interface ModelInfo {
|
31 |
+
id: string
|
32 |
name: string
|
33 |
architecture: string
|
34 |
parameters: number
|
35 |
likes: number
|
36 |
downloads: number
|
37 |
+
createdAt: string
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
export interface ModelInfoResponse {
|
42 |
+
id: string
|
43 |
+
createdAt: string
|
44 |
+
config?: {
|
45 |
+
architectures: string[]
|
46 |
+
model_type: string
|
47 |
+
}
|
48 |
+
lastModified: string
|
49 |
+
pipeline_tag: string
|
50 |
+
tags: string[]
|
51 |
+
transformersInfo: {
|
52 |
+
pipeline_tag: string
|
53 |
+
auto_model: string
|
54 |
+
processor: string
|
55 |
+
}
|
56 |
+
safetensors?: {
|
57 |
+
parameters: {
|
58 |
+
F16?: number
|
59 |
+
F32?: number
|
60 |
+
total?: number
|
61 |
+
}
|
62 |
+
}
|
63 |
+
likes: number
|
64 |
+
downloads: number
|
65 |
}
|