feat: enhance model information display and improve context management
- Added model information display in the pipeline selection area, including model name, likes, downloads, and parameters.
- Updated the context to manage model information instead of just the model name.
- Integrated model information fetching from Hugging Face API in both TextClassification and ZeroShotClassification components.
- Modified worker scripts to accept dynamic model names for text classification and zero-shot classification tasks.
- Improved formatting for large numbers in the UI.
- Added environment variable support for Hugging Face API token.
- Updated types to include ModelInfo interface for better type safety.
- .env.example +1 -0
- .gitignore +2 -0
- .vscode/settings.json +3 -0
- package.json +3 -0
- pnpm-lock.yaml +0 -0
- src/App.tsx +54 -9
- src/components/TextClassification.tsx +64 -31
- src/components/ZeroShotClassification.tsx +95 -57
- src/contexts/ModelContext.tsx +26 -18
- src/lib/huggingface.ts +51 -0
- src/types.ts +21 -11
- src/workers/text-classification.js +13 -5
- src/workers/zero-shot.js +12 -7
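In short: each classification component now fetches model metadata from the Hugging Face Hub, publishes it through the shared model context, the App header renders it, and the worker is told which model to load from the same context value. A condensed sketch of that flow (the wrapper hook is hypothetical; the real components inline this effect, as the diffs below show):

import { useEffect } from 'react'
import { useModel } from './contexts/ModelContext'
import { getModelInfo } from './lib/huggingface'

// Hypothetical helper; TextClassification and ZeroShotClassification inline this logic.
function useHubModelInfo(modelName: string) {
  const { setModelInfo } = useModel()

  useEffect(() => {
    getModelInfo(modelName)
      .then((res) =>
        setModelInfo({
          name: modelName,
          architecture: res.config.architectures[0],
          parameters:
            res.safetensors?.parameters.F16 ??
            res.safetensors?.parameters.F32 ??
            res.safetensors?.parameters.total ??
            0,
          likes: res.likes,
          downloads: res.downloads
        })
      )
      .catch((err) => console.error('Error fetching model info:', err))
  }, [modelName, setModelInfo])
}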
.env.example
ADDED
@@ -0,0 +1 @@
+REACT_APP_HUGGINGFACE_TOKEN="hf_..."
.gitignore
CHANGED
@@ -21,3 +21,5 @@
 npm-debug.log*
 yarn-debug.log*
 yarn-error.log*
+
+.env
.vscode/settings.json
ADDED
@@ -0,0 +1,3 @@
+{
+  "CodeGPT.apiKey": "CodeGPT Plus Beta"
+}
package.json
CHANGED
@@ -13,6 +13,9 @@
     "@types/react": "^19.1.8",
     "@types/react-dom": "^19.1.6",
     "build": "^0.1.4",
+    "dotenv": "^17.0.1",
+    "lucide-react": "^0.525.0",
+    "path-browserify": "^1.0.1",
     "react": "^19.1.0",
     "react-dom": "^19.1.0",
     "react-scripts": "5.0.1",
pnpm-lock.yaml
CHANGED
The diff for this file is too large to render.
src/App.tsx
CHANGED
@@ -5,10 +5,22 @@ import TextClassification from './components/TextClassification';
 import Header from './Header';
 import Footer from './Footer';
 import { useModel } from './contexts/ModelContext';
+import { Bot, Heart, Download, Cpu } from 'lucide-react';

 function App() {
   const [pipeline, setPipeline] = useState('zero-shot-classification');
-  const { progress, status,
+  const { progress, status, modelInfo } = useModel();
+
+  const formatNumber = (num: number) => {
+    if (num >= 1000000000) {
+      return (num / 1000000000).toFixed(1) + 'B'
+    } else if (num >= 1000000) {
+      return (num / 1000000).toFixed(1) + 'M'
+    } else if (num >= 1000) {
+      return (num / 1000).toFixed(1) + 'K'
+    }
+    return num.toString();
+  };

   return (
     <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
@@ -18,9 +30,47 @@
         {/* Pipeline Selection */}
         <div className="mb-8">
           <div className="bg-white rounded-lg shadow-sm border p-6">
+            <div className="flex items-center justify-between mb-4">
+              <h2 className="text-lg font-semibold text-gray-900">
+                Choose a Pipeline
+              </h2>
+
+              {/* Model Info Display */}
+              {modelInfo.name && (
+                <div className="flex items-center space-x-4 bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-2 rounded-lg border border-blue-200">
+                  <div className="flex items-center space-x-2">
+                    <Bot className="w-4 h-4 text-blue-600" />
+                    <span className="text-sm font-medium text-gray-700 truncate max-w-80" title={modelInfo.name}>
+                      {modelInfo.name.split('/').pop()}
+                    </span>
+                  </div>
+
+                  <div className="flex items-center space-x-4 text-xs text-gray-600">
+                    {modelInfo.likes > 0 && (
+                      <div className="flex items-center space-x-1">
+                        <Heart className="w-3 h-3 text-red-500" />
+                        <span>{formatNumber(modelInfo.likes)}</span>
+                      </div>
+                    )}
+
+                    {modelInfo.downloads > 0 && (
+                      <div className="flex items-center space-x-1">
+                        <Download className="w-3 h-3 text-green-500" />
+                        <span>{formatNumber(modelInfo.downloads)}</span>
+                      </div>
+                    )}
+
+                    {modelInfo.parameters > 0 && (
+                      <div className="flex items-center space-x-1">
+                        <Cpu className="w-3 h-3 text-purple-500" />
+                        <span>{formatNumber(modelInfo.parameters)}</span>
+                      </div>
+                    )}
+                  </div>
+                </div>
+              )}
+            </div>
+
             <PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />

             {/* Model Loading Progress */}
@@ -85,11 +135,6 @@ function App() {
                   {pipeline === 'zero-shot-classification'
                     ? 'Zero-Shot Classification'
                     : 'Text-Classification'}
-                  {model && (
-                    <span className="ml-2 text-xs text-gray-500 font-normal">
-                      ({model})
-                    </span>
-                  )}
                 </h3>
                 <p className="text-sm text-gray-600 mt-1">
                   {pipeline === 'zero-shot-classification'
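A quick check of the new formatNumber helper used for the badge counts (inputs are illustrative):

formatNumber(950)           // '950'
formatNumber(1_234)         // '1.2K'
formatNumber(66_900_000)    // '66.9M'
formatNumber(1_500_000_000) // '1.5B'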
src/components/TextClassification.tsx
CHANGED
@@ -2,9 +2,11 @@ import { useState, useRef, useEffect, useCallback } from 'react';
 import {
   ClassificationOutput,
   TextClassificationWorkerInput,
-  WorkerMessage
+  WorkerMessage,
+  ModelInfo
 } from '../types';
 import { useModel } from '../contexts/ModelContext';
+import { getModelInfo } from '../lib/huggingface';


 const PLACEHOLDER_TEXTS: string[] = [
@@ -21,13 +23,41 @@
 ].sort(() => Math.random() - 0.5);

 function TextClassification() {
   const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
   const [results, setResults] = useState<ClassificationOutput[]>([])
-  const { setProgress, status, setStatus,
+  const { setProgress, status, setStatus, modelInfo, setModelInfo} = useModel()
+  useEffect(() => {
+    const modelName = 'Xenova/distilbert-base-uncased-finetuned-sst-2-english'
+    const fetchModelInfo = async () => {
+      try {
+        const modelInfoResponse = await getModelInfo(modelName)
+        console.log(modelInfoResponse)
+        let parameters = 0
+        if (modelInfoResponse.safetensors) {
+          const safetensors = modelInfoResponse.safetensors
+          parameters =
+            (safetensors.parameters.F16 ||
+              safetensors.parameters.F32 ||
+              safetensors.parameters.total ||
+              0)
+        }
+        setModelInfo({
+          name: modelName,
+          architecture: modelInfoResponse.config.architectures[0],
+          parameters,
+          likes: modelInfoResponse.likes,
+          downloads: modelInfoResponse.downloads
+        })
+      } catch (error) {
+        console.error('Error fetching model info:', error)
+      }
+    }
+
+    fetchModelInfo()
+  }, [setModelInfo])

   // Create a reference to the worker object.
   const worker = useRef<Worker | null>(null)

   // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
   useEffect(() => {
@@ -38,54 +68,57 @@ function TextClassification() {
       {
         type: 'module'
       }
     )
   }

   // Create a callback function for messages from the worker thread.
   const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
     const status = e.data.status
     if (status === 'initiate') {
       setStatus('loading')
     } else if (status === 'ready') {
       setStatus('ready')
     } else if (status === 'progress') {
       setStatus('progress')
       if (
         e.data.output.progress &&
         (e.data.output.file as string).startsWith('onnx')
       )
         setProgress(e.data.output.progress)
     } else if (status === 'output') {
       setStatus('output')
-      const result = e.data.output
+      const result = e.data.output!
       setResults((prevResults) => [...prevResults, result])
       console.log(result)
     } else if (status === 'complete') {
       setStatus('idle')
       setProgress(100)
+    } else if (status === 'error') {
+      setStatus('error')
+      console.error(e.data.output)
     }
   }

   // Attach the callback function as an event listener.
   worker.current.addEventListener('message', onMessageReceived)

   // Define a cleanup function for when the component is unmounted.
   return () =>
     worker.current?.removeEventListener('message', onMessageReceived)
  }, [])

   const classify = useCallback(() => {
     setStatus('processing')
-    setResults([])
-    const message: TextClassificationWorkerInput = { text }
+    setResults([]) // Clear previous results
+    const message: TextClassificationWorkerInput = { text, model: modelInfo.name }
     worker.current?.postMessage(message)
-  }, [text])
+  }, [text, modelInfo.name])

   const busy: boolean = status !== 'idle'

   const handleClear = (): void => {
     setResults([])
   }

   return (
     <div className="flex flex-col h-[40vh] max-h-[80vh] w-full p-4">
@@ -157,7 +190,7 @@ function TextClassification() {
       </div>
     </div>
   </div>
   )
 }

 export default TextClassification;
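The parameter count shown in the header badge comes from the safetensors block of the Hub response, using an F16 / F32 / total fallback. A minimal sketch of that fallback with an invented payload (the number is illustrative, not the real count for either model):

// Shape matches the optional `safetensors` field of ModelInfoResponse in src/lib/huggingface.ts.
const safetensorsParameters: { F16?: number; F32?: number; total?: number } = {
  F32: 66_955_010 // illustrative value; typically only one bucket is present
}

// Same fallback chain the components use:
const parameters =
  safetensorsParameters.F16 ||
  safetensorsParameters.F32 ||
  safetensorsParameters.total ||
  0 // -> 66_955_010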
src/components/ZeroShotClassification.tsx
CHANGED
@@ -1,7 +1,13 @@
 // src/App.tsx
 import { useState, useRef, useEffect, useCallback } from 'react'
 import {
+  Section,
+  WorkerMessage,
+  ZeroShotWorkerInput,
+  ModelInfo
+} from '../types'
+import { useModel } from '../contexts/ModelContext'
+import { getModelInfo } from '../lib/huggingface'

 const PLACEHOLDER_REVIEWS: string[] = [
   // battery/charging problems
@@ -28,7 +34,7 @@ const PLACEHOLDER_REVIEWS: string[] = [
   "I'm not sure what to make of this phone. It's not bad, but it's not great either. I'm on the fence about it.",
   "I hate the color of this phone. It's so ugly!",
   "This phone sucks! I'm returning it."
 ].sort(() => Math.random() - 0.5)

 const PLACEHOLDER_SECTIONS: string[] = [
   'Battery and charging problems',
@@ -36,20 +42,48 @@ const PLACEHOLDER_SECTIONS: string[] = [
   'Poor build quality',
   'Software issues',
   'Other'
 ]

 function ZeroShotClassification() {
   const [text, setText] = useState<string>(PLACEHOLDER_REVIEWS.join('\n'))

   const [sections, setSections] = useState<Section[]>(
     PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
   )

-  const { setProgress, status, setStatus,
+  const { setProgress, status, setStatus, modelInfo, setModelInfo } = useModel()
+  useEffect(() => {
+    const modelName = 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33'
+    const fetchModelInfo = async () => {
+      try {
+        const modelInfoResponse = await getModelInfo(modelName)
+        console.log(modelInfoResponse)
+        let parameters = 0
+        if (modelInfoResponse.safetensors) {
+          const safetensors = modelInfoResponse.safetensors
+          parameters =
+            safetensors.parameters.F16 ||
+            safetensors.parameters.F32 ||
+            safetensors.parameters.total ||
+            0
+        }
+        setModelInfo({
+          name: modelName,
+          architecture: modelInfoResponse.config.architectures[0],
+          parameters,
+          likes: modelInfoResponse.likes,
+          downloads: modelInfoResponse.downloads
+        })
+      } catch (error) {
+        console.error('Error fetching model info:', error)
+      }
+    }
+
+    fetchModelInfo()
+  }, [setModelInfo])

   // Create a reference to the worker object.
   const worker = useRef<Worker | null>(null)

   // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
   useEffect(() => {
@@ -60,86 +94,90 @@ function ZeroShotClassification() {
       {
         type: 'module'
       }
     )
   }

   // Create a callback function for messages from the worker thread.
   const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
     const status = e.data.status
     if (status === 'initiate') {
       setStatus('loading')
     } else if (status === 'ready') {
       setStatus('ready')
     } else if (status === 'progress') {
       setStatus('progress')
       if (
         e.data.output.progress &&
         (e.data.output.file as string).startsWith('onnx')
       )
         setProgress(e.data.output.progress)
     } else if (status === 'output') {
       setStatus('output')
-      const { sequence, labels, scores } = e.data.output
+      const { sequence, labels, scores } = e.data.output!

       // Threshold for classification
       const label = scores[0] > 0.5 ? labels[0] : 'Other'

       const sectionID =
         sections.map((x) => x.title).indexOf(label) ?? sections.length - 1
       setSections((sections) => {
         const newSections = [...sections]
         newSections[sectionID] = {
           ...newSections[sectionID],
           items: [...newSections[sectionID].items, sequence]
         }
         return newSections
       })
     } else if (status === 'complete') {
       setStatus('idle')
       setProgress(100)
+    } else if (status === 'error') {
+      setStatus('error')
+      console.error(e.data.output)
     }
   }

   // Attach the callback function as an event listener.
   worker.current.addEventListener('message', onMessageReceived)

   // Define a cleanup function for when the component is unmounted.
   return () =>
     worker.current?.removeEventListener('message', onMessageReceived)
  }, [sections])

   const classify = useCallback(() => {
     setStatus('processing')
     const message: ZeroShotWorkerInput = {
       text,
       labels: sections
         .slice(0, sections.length - 1)
-        .map((section) => section.title)
+        .map((section) => section.title),
+      model: modelInfo.name
+    }
+    worker.current?.postMessage(message)
+  }, [text, sections, modelInfo.name])

   const busy: boolean = status !== 'idle'

   const handleAddCategory = (): void => {
     setSections((sections) => {
       const newSections = [...sections]
       // add at position 2 from the end
       newSections.splice(newSections.length - 1, 0, {
         title: 'New Category',
         items: []
       })
       return newSections
     })
   }

   const handleRemoveCategory = (): void => {
     setSections((sections) => {
       const newSections = [...sections]
-      newSections.splice(newSections.length - 2, 1)
+      newSections.splice(newSections.length - 2, 1) // Remove second last element
       return newSections
     })
   }

   const handleClear = (): void => {
     setSections((sections) =>
@@ -147,16 +185,16 @@ function ZeroShotClassification() {
         ...section,
         items: []
       }))
     )
   }

   const handleSectionTitleChange = (index: number, newTitle: string): void => {
     setSections((sections) => {
       const newSections = [...sections]
       newSections[index].title = newTitle
       return newSections
     })
   }

   return (
     <div className="flex flex-col h-screen w-full p-1">
@@ -174,8 +212,8 @@ function ZeroShotClassification() {
             {!busy
               ? 'Categorize'
               : status === 'loading'
+                ? 'Model loading...'
+                : 'Processing'}
           </button>
           <div className="flex gap-1">
             <button
@@ -223,7 +261,7 @@ function ZeroShotClassification() {
       ))}
     </div>
   </div>
   )
 }

 export default ZeroShotClassification
src/contexts/ModelContext.tsx
CHANGED
@@ -1,39 +1,47 @@
 import React, { createContext, useContext, useEffect, useState } from 'react'
+import { ModelInfo } from '../types'

 interface ModelContextType {
   progress: number
   status: string
   setProgress: (progress: number) => void
   setStatus: (status: string) => void
+  modelInfo: ModelInfo
+  setModelInfo: (model: ModelInfo) => void
 }

 const ModelContext = createContext<ModelContextType | undefined>(undefined)

 export function ModelProvider({ children }: { children: React.ReactNode }) {
   const [progress, setProgress] = useState<number>(0)
   const [status, setStatus] = useState<string>('idle')
-  const [
+  const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)

   // set progress to 0 when model is changed
   useEffect(() => {
     setProgress(0)
-  }, [
+  }, [modelInfo.name])

   return (
     <ModelContext.Provider
       value={{
+        progress,
+        setProgress,
+        status,
+        setStatus,
+        modelInfo,
+        setModelInfo
       }}
     >
       {children}
     </ModelContext.Provider>
   )
 }

 export function useModel() {
   const context = useContext(ModelContext)
   if (context === undefined) {
     throw new Error('useModel must be used within a ModelProvider')
   }
   return context
 }
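A minimal consumer sketch, assuming the app stays wrapped in ModelProvider and the file sits alongside the existing components (the component itself is hypothetical):

import { useModel } from '../contexts/ModelContext'

function ModelBadge() {
  // useModel() throws if this renders outside <ModelProvider>, per the guard above.
  const { modelInfo, status, progress } = useModel()

  if (!modelInfo.name) return null
  return (
    <span title={modelInfo.name}>
      {modelInfo.name.split('/').pop()} ({status}, {progress}%)
    </span>
  )
}

export default ModelBadge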
src/lib/huggingface.ts
ADDED
@@ -0,0 +1,51 @@
+interface ModelInfoResponse {
+  id: string
+  config: {
+    architectures: string[]
+    model_type: string
+  }
+  lastModified: string
+  pipeline_tag: string
+  tags: string[]
+  transformersInfo: {
+    pipeline_tag: string
+    auto_model: string
+    processor: string
+  }
+  safetensors?: {
+    parameters: {
+      F16?: number
+      F32?: number
+      total?: number
+    }
+  }
+  likes: number
+  downloads: number
+}
+
+const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
+  const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
+
+  if (!token) {
+    throw new Error(
+      'Hugging Face token not found. Please set REACT_APP_HUGGINGFACE_TOKEN in your .env file'
+    )
+  }
+
+  const response = await fetch(
+    `https://huggingface.co/api/models/${modelName}`,
+    {
+      method: 'GET',
+      headers: {
+        Authorization: `Bearer ${token}`
+      }
+    }
+  )
+
+  if (!response.ok) {
+    throw new Error(`Failed to fetch model info: ${response.statusText}`)
+  }
+  return response.json()
+}
+
+export { getModelInfo }
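Example call with one of the model IDs used above; the fields logged are the ones the components read:

import { getModelInfo } from './huggingface'

getModelInfo('MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33')
  .then((info) => {
    console.log(info.config.architectures[0], info.likes, info.downloads)
    console.log(info.safetensors?.parameters)
  })
  .catch((err) => console.error(err))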
src/types.ts
CHANGED
@@ -1,26 +1,36 @@
 export interface Section {
   title: string
   items: string[]
 }

 export interface ClassificationOutput {
   sequence: string
   labels: string[]
   scores: number[]
 }

 export interface WorkerMessage {
   status: 'initiate' | 'ready' | 'output' | 'complete' | 'progress'
   output?: any
 }

 export interface ZeroShotWorkerInput {
   text: string
   labels: string[]
+  model: string
 }

 export interface TextClassificationWorkerInput {
   text: string
+  model: string
 }

 export type AppStatus = 'idle' | 'loading' | 'processing'
+
+export interface ModelInfo {
+  name: string
+  architecture: string
+  parameters: number
+  likes: number
+  downloads: number
+}
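For reference, example payloads matching the extended worker-input types (review text and labels are illustrative; the model IDs are the ones hard-coded in the components):

import { TextClassificationWorkerInput, ZeroShotWorkerInput } from './types'

const textClassificationMessage: TextClassificationWorkerInput = {
  text: 'I love this phone',
  model: 'Xenova/distilbert-base-uncased-finetuned-sst-2-english'
}

const zeroShotMessage: ZeroShotWorkerInput = {
  text: 'The battery drains overnight',
  labels: ['Battery and charging problems', 'Software issues'],
  model: 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33'
}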
src/workers/text-classification.js
CHANGED
@@ -3,11 +3,10 @@ import { pipeline } from '@huggingface/transformers';

 class MyTextClassificationPipeline {
   static task = 'text-classification';
-  static model = 'Xenova/bert-base-multilingual-uncased-sentiment';
   static instance = null;

-  static async getInstance(progress_callback = null) {
-    this.instance ??= pipeline(this.task,
+  static async getInstance(model, progress_callback = null) {
+    this.instance ??= pipeline(this.task, model, {
       progress_callback
     });

@@ -17,15 +16,24 @@ class MyTextClassificationPipeline {

 // Listen for messages from the main thread
 self.addEventListener('message', async (event) => {
+  const { text, model } = event.data;
+  if (!model) {
+    self.postMessage({
+      status: 'error',
+      output: 'No model provided'
+    });
+    return;
+  }
+
   // Retrieve the pipeline. When called for the first time,
   // this will load the pipeline and save it for future use.
-  const classifier = await MyTextClassificationPipeline.getInstance((x) => {
+  const classifier = await MyTextClassificationPipeline.getInstance(model, (x) => {
     // We also add a progress callback to the pipeline so that we can
     // track model loading.
     self.postMessage({ status: 'progress', output: x });
   });

+
   const split = text.split('\n');
   for (const line of split) {
src/workers/zero-shot.js
CHANGED
@@ -3,11 +3,10 @@ import { pipeline } from '@huggingface/transformers';

 class MyZeroShotClassificationPipeline {
   static task = 'zero-shot-classification';
-  static model = 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33';
   static instance = null;

-  static async getInstance(progress_callback = null) {
-    this.instance ??= pipeline(this.task,
+  static async getInstance(model, progress_callback = null) {
+    this.instance ??= pipeline(this.task, model, {
       progress_callback
     });

@@ -17,16 +16,22 @@ class MyZeroShotClassificationPipeline {

 // Listen for messages from the main thread
 self.addEventListener('message', async (event) => {
+  const { text, labels, model } = event.data;
+  if (!model) {
+    self.postMessage({
+      status: 'error',
+      output: 'No model provided'
+    });
+    return;
+  }
+
   // Retrieve the pipeline. When called for the first time,
   // this will load the pipeline and save it for future use.
-  const classifier = await MyZeroShotClassificationPipeline.getInstance((x) => {
+  const classifier = await MyZeroShotClassificationPipeline.getInstance(model, (x) => {
     // We also add a progress callback to the pipeline so that we can
     // track model loading.
     self.postMessage({ status: 'progress', output: x });
   });
-
-  const { text, labels } = event.data;
-
   const split = text.split('\n');
   for (const line of split) {
     const output = await classifier(line, labels, {