Vokturz commited on
Commit
1150456
·
1 Parent(s): 96812c9

feat: update model names and add model size calculation utility

Browse files
src/App.tsx CHANGED
@@ -1,15 +1,16 @@
1
- import { useState } from 'react';
2
- import PipelineSelector from './components/PipelineSelector';
3
- import ZeroShotClassification from './components/ZeroShotClassification';
4
- import TextClassification from './components/TextClassification';
5
- import Header from './Header';
6
- import Footer from './Footer';
7
- import { useModel } from './contexts/ModelContext';
8
- import { Bot, Heart, Download, Cpu } from 'lucide-react';
 
9
 
10
  function App() {
11
- const [pipeline, setPipeline] = useState('zero-shot-classification');
12
- const { progress, status, modelInfo } = useModel();
13
 
14
  const formatNumber = (num: number) => {
15
  if (num >= 1000000000) {
@@ -19,8 +20,8 @@ function App() {
19
  } else if (num >= 1000) {
20
  return (num / 1000).toFixed(1) + 'K'
21
  }
22
- return num.toString();
23
- };
24
 
25
  return (
26
  <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
@@ -34,43 +35,59 @@ function App() {
34
  <h2 className="text-lg font-semibold text-gray-900">
35
  Choose a Pipeline
36
  </h2>
37
-
38
  {/* Model Info Display */}
39
  {modelInfo.name && (
40
- <div className="flex items-center space-x-4 bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-2 rounded-lg border border-blue-200">
 
41
  <div className="flex items-center space-x-2">
42
  <Bot className="w-4 h-4 text-blue-600" />
43
- <span className="text-sm font-medium text-gray-700 truncate max-w-80" title={modelInfo.name}>
 
 
 
44
  {modelInfo.name.split('/').pop()}
45
  </span>
46
  </div>
47
-
48
- <div className="flex items-center space-x-4 text-xs text-gray-600">
 
49
  {modelInfo.likes > 0 && (
50
  <div className="flex items-center space-x-1">
51
  <Heart className="w-3 h-3 text-red-500" />
52
  <span>{formatNumber(modelInfo.likes)}</span>
53
  </div>
54
  )}
55
-
56
  {modelInfo.downloads > 0 && (
57
  <div className="flex items-center space-x-1">
58
  <Download className="w-3 h-3 text-green-500" />
59
  <span>{formatNumber(modelInfo.downloads)}</span>
60
  </div>
61
  )}
62
-
63
  {modelInfo.parameters > 0 && (
64
  <div className="flex items-center space-x-1">
65
  <Cpu className="w-3 h-3 text-purple-500" />
66
  <span>{formatNumber(modelInfo.parameters)}</span>
67
  </div>
68
  )}
 
 
 
 
 
 
 
 
 
 
 
69
  </div>
70
  </div>
71
  )}
72
  </div>
73
-
74
  <PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
75
 
76
  {/* Model Loading Progress */}
@@ -108,7 +125,9 @@ function App() {
108
  style={{ width: `${progress.toFixed(2)}%` }}
109
  ></div>
110
  </div>
111
- <p className="text-xs text-blue-700 mt-1">{progress.toFixed(2)}%</p>
 
 
112
  </div>
113
  </div>
114
  </div>
@@ -158,7 +177,7 @@ function App() {
158
 
159
  <Footer />
160
  </div>
161
- );
162
  }
163
 
164
- export default App;
 
1
+ import { useState } from 'react'
2
+ import PipelineSelector from './components/PipelineSelector'
3
+ import ZeroShotClassification from './components/ZeroShotClassification'
4
+ import TextClassification from './components/TextClassification'
5
+ import Header from './Header'
6
+ import Footer from './Footer'
7
+ import { useModel } from './contexts/ModelContext'
8
+ import { Bot, Heart, Download, Cpu, DatabaseIcon } from 'lucide-react'
9
+ import { getModelSize } from './lib/huggingface'
10
 
11
  function App() {
12
+ const [pipeline, setPipeline] = useState('zero-shot-classification')
13
+ const { progress, status, modelInfo } = useModel()
14
 
15
  const formatNumber = (num: number) => {
16
  if (num >= 1000000000) {
 
20
  } else if (num >= 1000) {
21
  return (num / 1000).toFixed(1) + 'K'
22
  }
23
+ return num.toString()
24
+ }
25
 
26
  return (
27
  <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
 
35
  <h2 className="text-lg font-semibold text-gray-900">
36
  Choose a Pipeline
37
  </h2>
38
+
39
  {/* Model Info Display */}
40
  {modelInfo.name && (
41
+ <div className="bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-3 rounded-lg border border-blue-200 space-y-2">
42
+ {/* Model Name Row */}
43
  <div className="flex items-center space-x-2">
44
  <Bot className="w-4 h-4 text-blue-600" />
45
+ <span
46
+ className="text-sm font-medium text-gray-700 truncate max-w-100"
47
+ title={modelInfo.name}
48
+ >
49
  {modelInfo.name.split('/').pop()}
50
  </span>
51
  </div>
52
+
53
+ {/* Stats Row */}
54
+ <div className="flex items-center justify-self-end space-x-4 text-xs text-gray-600">
55
  {modelInfo.likes > 0 && (
56
  <div className="flex items-center space-x-1">
57
  <Heart className="w-3 h-3 text-red-500" />
58
  <span>{formatNumber(modelInfo.likes)}</span>
59
  </div>
60
  )}
61
+
62
  {modelInfo.downloads > 0 && (
63
  <div className="flex items-center space-x-1">
64
  <Download className="w-3 h-3 text-green-500" />
65
  <span>{formatNumber(modelInfo.downloads)}</span>
66
  </div>
67
  )}
68
+
69
  {modelInfo.parameters > 0 && (
70
  <div className="flex items-center space-x-1">
71
  <Cpu className="w-3 h-3 text-purple-500" />
72
  <span>{formatNumber(modelInfo.parameters)}</span>
73
  </div>
74
  )}
75
+
76
+ {modelInfo.parameters > 0 && (
77
+ <div className="flex items-center space-x-1">
78
+ <DatabaseIcon className="w-3 h-3 text-purple-500" />
79
+ <span>
80
+ {`~${getModelSize(modelInfo.parameters, 'INT8').toFixed(
81
+ 1
82
+ )}MB`}
83
+ </span>
84
+ </div>
85
+ )}
86
  </div>
87
  </div>
88
  )}
89
  </div>
90
+
91
  <PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
92
 
93
  {/* Model Loading Progress */}
 
125
  style={{ width: `${progress.toFixed(2)}%` }}
126
  ></div>
127
  </div>
128
+ <p className="text-xs text-blue-700 mt-1">
129
+ {progress.toFixed(2)}%
130
+ </p>
131
  </div>
132
  </div>
133
  </div>
 
177
 
178
  <Footer />
179
  </div>
180
+ )
181
  }
182
 
183
+ export default App
src/components/TextClassification.tsx CHANGED
@@ -27,7 +27,7 @@ function TextClassification() {
27
  const [results, setResults] = useState<ClassificationOutput[]>([])
28
  const { setProgress, status, setStatus, modelInfo, setModelInfo} = useModel()
29
  useEffect(() => {
30
- const modelName = 'Xenova/distilbert-base-uncased-finetuned-sst-2-english'
31
  const fetchModelInfo = async () => {
32
  try {
33
  const modelInfoResponse = await getModelInfo(modelName)
 
27
  const [results, setResults] = useState<ClassificationOutput[]>([])
28
  const { setProgress, status, setStatus, modelInfo, setModelInfo} = useModel()
29
  useEffect(() => {
30
+ const modelName = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'
31
  const fetchModelInfo = async () => {
32
  try {
33
  const modelInfoResponse = await getModelInfo(modelName)
src/components/ZeroShotClassification.tsx CHANGED
@@ -53,7 +53,7 @@ function ZeroShotClassification() {
53
 
54
  const { setProgress, status, setStatus, modelInfo, setModelInfo } = useModel()
55
  useEffect(() => {
56
- const modelName = 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33'
57
  const fetchModelInfo = async () => {
58
  try {
59
  const modelInfoResponse = await getModelInfo(modelName)
 
53
 
54
  const { setProgress, status, setStatus, modelInfo, setModelInfo } = useModel()
55
  useEffect(() => {
56
+ const modelName = 'lxyuan/distilbert-base-multilingual-cased-sentiments-student'
57
  const fetchModelInfo = async () => {
58
  try {
59
  const modelInfoResponse = await getModelInfo(modelName)
src/lib/huggingface.ts CHANGED
@@ -48,4 +48,38 @@ const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
48
  return response.json()
49
  }
50
 
51
- export { getModelInfo }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return response.json()
49
  }
50
 
51
+ // Define the possible quantization types for clarity and type safety
52
+ type QuantizationType = 'FP32' | 'FP16' | 'INT8' | 'Q4'
53
+ function getModelSize(
54
+ parameters: number,
55
+ quantization: QuantizationType
56
+ ): number {
57
+ let bytesPerParameter: number
58
+
59
+ switch (quantization) {
60
+ case 'FP32':
61
+ // 32-bit floating point uses 4 bytes
62
+ bytesPerParameter = 4
63
+ break
64
+ case 'FP16':
65
+ bytesPerParameter = 2
66
+ break
67
+ case 'INT8':
68
+ bytesPerParameter = 1
69
+ break
70
+ case 'Q4':
71
+ bytesPerParameter = 0.5
72
+ const theoreticalSize = (parameters * bytesPerParameter) / (1024 * 1024)
73
+ return theoreticalSize
74
+ }
75
+
76
+ // There are 1,024 * 1,024 bytes in a megabyte
77
+ const sizeInBytes = parameters * bytesPerParameter
78
+ const sizeInMB = sizeInBytes / (1024 * 1024)
79
+
80
+ return sizeInMB
81
+ }
82
+
83
+
84
+ export { getModelInfo, getModelSize }
85
+