Vokturz commited on
Commit
96812c9
·
1 Parent(s): fb852fe

feat: enhance model information display and improve context management

Browse files

- Added model information display in the pipeline selection area, including model name, likes, downloads, and parameters.
- Updated the context to manage model information instead of just the model name.
- Integrated model information fetching from Hugging Face API in both TextClassification and ZeroShotClassification components.
- Modified worker scripts to accept dynamic model names for text classification and zero-shot classification tasks.
- Improved formatting for large numbers in the UI.
- Added environment variable support for Hugging Face API token.
- Updated types to include ModelInfo interface for better type safety.

.env.example ADDED
@@ -0,0 +1 @@
 
 
1
+ REACT_APP_HUGGINGFACE_TOKEN="hf_..."
.gitignore CHANGED
@@ -21,3 +21,5 @@
21
  npm-debug.log*
22
  yarn-debug.log*
23
  yarn-error.log*
 
 
 
21
  npm-debug.log*
22
  yarn-debug.log*
23
  yarn-error.log*
24
+
25
+ .env
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "CodeGPT.apiKey": "CodeGPT Plus Beta"
3
+ }
package.json CHANGED
@@ -13,6 +13,9 @@
13
  "@types/react": "^19.1.8",
14
  "@types/react-dom": "^19.1.6",
15
  "build": "^0.1.4",
 
 
 
16
  "react": "^19.1.0",
17
  "react-dom": "^19.1.0",
18
  "react-scripts": "5.0.1",
 
13
  "@types/react": "^19.1.8",
14
  "@types/react-dom": "^19.1.6",
15
  "build": "^0.1.4",
16
+ "dotenv": "^17.0.1",
17
+ "lucide-react": "^0.525.0",
18
+ "path-browserify": "^1.0.1",
19
  "react": "^19.1.0",
20
  "react-dom": "^19.1.0",
21
  "react-scripts": "5.0.1",
pnpm-lock.yaml CHANGED
The diff for this file is too large to render. See raw diff
 
src/App.tsx CHANGED
@@ -5,10 +5,22 @@ import TextClassification from './components/TextClassification';
5
  import Header from './Header';
6
  import Footer from './Footer';
7
  import { useModel } from './contexts/ModelContext';
 
8
 
9
  function App() {
10
  const [pipeline, setPipeline] = useState('zero-shot-classification');
11
- const { progress, status, model } = useModel();
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  return (
14
  <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
@@ -18,9 +30,47 @@ function App() {
18
  {/* Pipeline Selection */}
19
  <div className="mb-8">
20
  <div className="bg-white rounded-lg shadow-sm border p-6">
21
- <h2 className="text-lg font-semibold text-gray-900 mb-4">
22
- Choose a Pipeline
23
- </h2>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  <PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
25
 
26
  {/* Model Loading Progress */}
@@ -85,11 +135,6 @@ function App() {
85
  {pipeline === 'zero-shot-classification'
86
  ? 'Zero-Shot Classification'
87
  : 'Text-Classification'}
88
- {model && (
89
- <span className="ml-2 text-xs text-gray-500 font-normal">
90
- ({model})
91
- </span>
92
- )}
93
  </h3>
94
  <p className="text-sm text-gray-600 mt-1">
95
  {pipeline === 'zero-shot-classification'
 
5
  import Header from './Header';
6
  import Footer from './Footer';
7
  import { useModel } from './contexts/ModelContext';
8
+ import { Bot, Heart, Download, Cpu } from 'lucide-react';
9
 
10
  function App() {
11
  const [pipeline, setPipeline] = useState('zero-shot-classification');
12
+ const { progress, status, modelInfo } = useModel();
13
+
14
+ const formatNumber = (num: number) => {
15
+ if (num >= 1000000000) {
16
+ return (num / 1000000000).toFixed(1) + 'B'
17
+ } else if (num >= 1000000) {
18
+ return (num / 1000000).toFixed(1) + 'M'
19
+ } else if (num >= 1000) {
20
+ return (num / 1000).toFixed(1) + 'K'
21
+ }
22
+ return num.toString();
23
+ };
24
 
25
  return (
26
  <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
 
30
  {/* Pipeline Selection */}
31
  <div className="mb-8">
32
  <div className="bg-white rounded-lg shadow-sm border p-6">
33
+ <div className="flex items-center justify-between mb-4">
34
+ <h2 className="text-lg font-semibold text-gray-900">
35
+ Choose a Pipeline
36
+ </h2>
37
+
38
+ {/* Model Info Display */}
39
+ {modelInfo.name && (
40
+ <div className="flex items-center space-x-4 bg-gradient-to-r from-blue-50 to-indigo-50 px-4 py-2 rounded-lg border border-blue-200">
41
+ <div className="flex items-center space-x-2">
42
+ <Bot className="w-4 h-4 text-blue-600" />
43
+ <span className="text-sm font-medium text-gray-700 truncate max-w-80" title={modelInfo.name}>
44
+ {modelInfo.name.split('/').pop()}
45
+ </span>
46
+ </div>
47
+
48
+ <div className="flex items-center space-x-4 text-xs text-gray-600">
49
+ {modelInfo.likes > 0 && (
50
+ <div className="flex items-center space-x-1">
51
+ <Heart className="w-3 h-3 text-red-500" />
52
+ <span>{formatNumber(modelInfo.likes)}</span>
53
+ </div>
54
+ )}
55
+
56
+ {modelInfo.downloads > 0 && (
57
+ <div className="flex items-center space-x-1">
58
+ <Download className="w-3 h-3 text-green-500" />
59
+ <span>{formatNumber(modelInfo.downloads)}</span>
60
+ </div>
61
+ )}
62
+
63
+ {modelInfo.parameters > 0 && (
64
+ <div className="flex items-center space-x-1">
65
+ <Cpu className="w-3 h-3 text-purple-500" />
66
+ <span>{formatNumber(modelInfo.parameters)}</span>
67
+ </div>
68
+ )}
69
+ </div>
70
+ </div>
71
+ )}
72
+ </div>
73
+
74
  <PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
75
 
76
  {/* Model Loading Progress */}
 
135
  {pipeline === 'zero-shot-classification'
136
  ? 'Zero-Shot Classification'
137
  : 'Text-Classification'}
 
 
 
 
 
138
  </h3>
139
  <p className="text-sm text-gray-600 mt-1">
140
  {pipeline === 'zero-shot-classification'
src/components/TextClassification.tsx CHANGED
@@ -2,9 +2,11 @@ import { useState, useRef, useEffect, useCallback } from 'react';
2
  import {
3
  ClassificationOutput,
4
  TextClassificationWorkerInput,
5
- WorkerMessage
 
6
  } from '../types';
7
  import { useModel } from '../contexts/ModelContext';
 
8
 
9
 
10
  const PLACEHOLDER_TEXTS: string[] = [
@@ -21,13 +23,41 @@ const PLACEHOLDER_TEXTS: string[] = [
21
  ].sort(() => Math.random() - 0.5);
22
 
23
  function TextClassification() {
24
- const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'));
25
- const [results, setResults] = useState<ClassificationOutput[]>([]);
26
- const { setProgress, status, setStatus, setModel } = useModel();
27
- setModel('Xenova/bert-base-multilingual-uncased-sentiment')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  // Create a reference to the worker object.
30
- const worker = useRef<Worker | null>(null);
31
 
32
  // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
33
  useEffect(() => {
@@ -38,54 +68,57 @@ function TextClassification() {
38
  {
39
  type: 'module'
40
  }
41
- );
42
  }
43
 
44
  // Create a callback function for messages from the worker thread.
45
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
46
- const status = e.data.status;
47
  if (status === 'initiate') {
48
- setStatus('loading');
49
  } else if (status === 'ready') {
50
- setStatus('ready');
51
  } else if (status === 'progress') {
52
- setStatus('progress');
53
  if (
54
  e.data.output.progress &&
55
  (e.data.output.file as string).startsWith('onnx')
56
  )
57
- setProgress(e.data.output.progress);
58
  } else if (status === 'output') {
59
- setStatus('output');
60
- const result = e.data.output!;
61
- setResults((prevResults) => [...prevResults, result]);
62
- console.log(result);
63
  } else if (status === 'complete') {
64
- setStatus('idle');
65
- setProgress(100);
 
 
 
66
  }
67
- };
68
 
69
  // Attach the callback function as an event listener.
70
- worker.current.addEventListener('message', onMessageReceived);
71
 
72
  // Define a cleanup function for when the component is unmounted.
73
  return () =>
74
- worker.current?.removeEventListener('message', onMessageReceived);
75
- }, []);
76
 
77
  const classify = useCallback(() => {
78
- setStatus('processing');
79
- setResults([]); // Clear previous results
80
- const message: TextClassificationWorkerInput = { text };
81
- worker.current?.postMessage(message);
82
- }, [text]);
83
 
84
- const busy: boolean = status !== 'idle';
85
 
86
  const handleClear = (): void => {
87
- setResults([]);
88
- };
89
 
90
  return (
91
  <div className="flex flex-col h-[40vh] max-h-[80vh] w-full p-4">
@@ -157,7 +190,7 @@ function TextClassification() {
157
  </div>
158
  </div>
159
  </div>
160
- );
161
  }
162
 
163
  export default TextClassification;
 
2
  import {
3
  ClassificationOutput,
4
  TextClassificationWorkerInput,
5
+ WorkerMessage,
6
+ ModelInfo
7
  } from '../types';
8
  import { useModel } from '../contexts/ModelContext';
9
+ import { getModelInfo } from '../lib/huggingface';
10
 
11
 
12
  const PLACEHOLDER_TEXTS: string[] = [
 
23
  ].sort(() => Math.random() - 0.5);
24
 
25
  function TextClassification() {
26
+ const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
27
+ const [results, setResults] = useState<ClassificationOutput[]>([])
28
+ const { setProgress, status, setStatus, modelInfo, setModelInfo} = useModel()
29
+ useEffect(() => {
30
+ const modelName = 'Xenova/distilbert-base-uncased-finetuned-sst-2-english'
31
+ const fetchModelInfo = async () => {
32
+ try {
33
+ const modelInfoResponse = await getModelInfo(modelName)
34
+ console.log(modelInfoResponse)
35
+ let parameters = 0
36
+ if (modelInfoResponse.safetensors) {
37
+ const safetensors = modelInfoResponse.safetensors
38
+ parameters =
39
+ (safetensors.parameters.F16 ||
40
+ safetensors.parameters.F32 ||
41
+ safetensors.parameters.total ||
42
+ 0)
43
+ }
44
+ setModelInfo({
45
+ name: modelName,
46
+ architecture: modelInfoResponse.config.architectures[0],
47
+ parameters,
48
+ likes: modelInfoResponse.likes,
49
+ downloads: modelInfoResponse.downloads
50
+ })
51
+ } catch (error) {
52
+ console.error('Error fetching model info:', error)
53
+ }
54
+ }
55
+
56
+ fetchModelInfo()
57
+ }, [setModelInfo])
58
 
59
  // Create a reference to the worker object.
60
+ const worker = useRef<Worker | null>(null)
61
 
62
  // We use the `useEffect` hook to setup the worker as soon as the component is mounted.
63
  useEffect(() => {
 
68
  {
69
  type: 'module'
70
  }
71
+ )
72
  }
73
 
74
  // Create a callback function for messages from the worker thread.
75
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
76
+ const status = e.data.status
77
  if (status === 'initiate') {
78
+ setStatus('loading')
79
  } else if (status === 'ready') {
80
+ setStatus('ready')
81
  } else if (status === 'progress') {
82
+ setStatus('progress')
83
  if (
84
  e.data.output.progress &&
85
  (e.data.output.file as string).startsWith('onnx')
86
  )
87
+ setProgress(e.data.output.progress)
88
  } else if (status === 'output') {
89
+ setStatus('output')
90
+ const result = e.data.output!
91
+ setResults((prevResults) => [...prevResults, result])
92
+ console.log(result)
93
  } else if (status === 'complete') {
94
+ setStatus('idle')
95
+ setProgress(100)
96
+ } else if (status === 'error') {
97
+ setStatus('error')
98
+ console.error(e.data.output)
99
  }
100
+ }
101
 
102
  // Attach the callback function as an event listener.
103
+ worker.current.addEventListener('message', onMessageReceived)
104
 
105
  // Define a cleanup function for when the component is unmounted.
106
  return () =>
107
+ worker.current?.removeEventListener('message', onMessageReceived)
108
+ }, [])
109
 
110
  const classify = useCallback(() => {
111
+ setStatus('processing')
112
+ setResults([]) // Clear previous results
113
+ const message: TextClassificationWorkerInput = { text, model: modelInfo.name }
114
+ worker.current?.postMessage(message)
115
+ }, [text, modelInfo.name])
116
 
117
+ const busy: boolean = status !== 'idle'
118
 
119
  const handleClear = (): void => {
120
+ setResults([])
121
+ }
122
 
123
  return (
124
  <div className="flex flex-col h-[40vh] max-h-[80vh] w-full p-4">
 
190
  </div>
191
  </div>
192
  </div>
193
+ )
194
  }
195
 
196
  export default TextClassification;
src/components/ZeroShotClassification.tsx CHANGED
@@ -1,7 +1,13 @@
1
  // src/App.tsx
2
- import { useState, useRef, useEffect, useCallback } from 'react';
3
- import { Section, WorkerMessage, ZeroShotWorkerInput } from '../types';
4
- import { useModel } from '../contexts/ModelContext';
 
 
 
 
 
 
5
 
6
  const PLACEHOLDER_REVIEWS: string[] = [
7
  // battery/charging problems
@@ -28,7 +34,7 @@ const PLACEHOLDER_REVIEWS: string[] = [
28
  "I'm not sure what to make of this phone. It's not bad, but it's not great either. I'm on the fence about it.",
29
  "I hate the color of this phone. It's so ugly!",
30
  "This phone sucks! I'm returning it."
31
- ].sort(() => Math.random() - 0.5);
32
 
33
  const PLACEHOLDER_SECTIONS: string[] = [
34
  'Battery and charging problems',
@@ -36,20 +42,48 @@ const PLACEHOLDER_SECTIONS: string[] = [
36
  'Poor build quality',
37
  'Software issues',
38
  'Other'
39
- ];
40
 
41
  function ZeroShotClassification() {
42
- const [text, setText] = useState<string>(PLACEHOLDER_REVIEWS.join('\n'));
43
 
44
  const [sections, setSections] = useState<Section[]>(
45
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
46
- );
47
 
48
- const { setProgress, status, setStatus, setModel } = useModel();
49
- setModel('MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  // Create a reference to the worker object.
52
- const worker = useRef<Worker | null>(null);
53
 
54
  // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
55
  useEffect(() => {
@@ -60,86 +94,90 @@ function ZeroShotClassification() {
60
  {
61
  type: 'module'
62
  }
63
- );
64
  }
65
 
66
  // Create a callback function for messages from the worker thread.
67
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
68
- const status = e.data.status;
69
  if (status === 'initiate') {
70
- setStatus('loading');
71
  } else if (status === 'ready') {
72
- setStatus('ready');
73
  } else if (status === 'progress') {
74
- setStatus('progress');
75
  if (
76
  e.data.output.progress &&
77
  (e.data.output.file as string).startsWith('onnx')
78
  )
79
- setProgress(e.data.output.progress);
80
  } else if (status === 'output') {
81
- setStatus('output');
82
- const { sequence, labels, scores } = e.data.output!;
83
 
84
  // Threshold for classification
85
- const label = scores[0] > 0.5 ? labels[0] : 'Other';
86
 
87
  const sectionID =
88
- sections.map((x) => x.title).indexOf(label) ?? sections.length - 1;
89
  setSections((sections) => {
90
- const newSections = [...sections];
91
  newSections[sectionID] = {
92
  ...newSections[sectionID],
93
  items: [...newSections[sectionID].items, sequence]
94
- };
95
- return newSections;
96
- });
97
  } else if (status === 'complete') {
98
- setStatus('idle');
99
- setProgress(100);
 
 
 
100
  }
101
- };
102
 
103
  // Attach the callback function as an event listener.
104
- worker.current.addEventListener('message', onMessageReceived);
105
 
106
  // Define a cleanup function for when the component is unmounted.
107
  return () =>
108
- worker.current?.removeEventListener('message', onMessageReceived);
109
- }, [sections]);
110
 
111
  const classify = useCallback(() => {
112
- setStatus('processing');
113
  const message: ZeroShotWorkerInput = {
114
  text,
115
  labels: sections
116
  .slice(0, sections.length - 1)
117
- .map((section) => section.title)
118
- };
119
- worker.current?.postMessage(message);
120
- }, [text, sections]);
 
121
 
122
- const busy: boolean = status !== 'idle';
123
 
124
  const handleAddCategory = (): void => {
125
  setSections((sections) => {
126
- const newSections = [...sections];
127
  // add at position 2 from the end
128
  newSections.splice(newSections.length - 1, 0, {
129
  title: 'New Category',
130
  items: []
131
- });
132
- return newSections;
133
- });
134
- };
135
 
136
  const handleRemoveCategory = (): void => {
137
  setSections((sections) => {
138
- const newSections = [...sections];
139
- newSections.splice(newSections.length - 2, 1); // Remove second last element
140
- return newSections;
141
- });
142
- };
143
 
144
  const handleClear = (): void => {
145
  setSections((sections) =>
@@ -147,16 +185,16 @@ function ZeroShotClassification() {
147
  ...section,
148
  items: []
149
  }))
150
- );
151
- };
152
 
153
  const handleSectionTitleChange = (index: number, newTitle: string): void => {
154
  setSections((sections) => {
155
- const newSections = [...sections];
156
- newSections[index].title = newTitle;
157
- return newSections;
158
- });
159
- };
160
 
161
  return (
162
  <div className="flex flex-col h-screen w-full p-1">
@@ -174,8 +212,8 @@ function ZeroShotClassification() {
174
  {!busy
175
  ? 'Categorize'
176
  : status === 'loading'
177
- ? 'Model loading...'
178
- : 'Processing'}
179
  </button>
180
  <div className="flex gap-1">
181
  <button
@@ -223,7 +261,7 @@ function ZeroShotClassification() {
223
  ))}
224
  </div>
225
  </div>
226
- );
227
  }
228
 
229
- export default ZeroShotClassification;
 
1
  // src/App.tsx
2
+ import { useState, useRef, useEffect, useCallback } from 'react'
3
+ import {
4
+ Section,
5
+ WorkerMessage,
6
+ ZeroShotWorkerInput,
7
+ ModelInfo
8
+ } from '../types'
9
+ import { useModel } from '../contexts/ModelContext'
10
+ import { getModelInfo } from '../lib/huggingface'
11
 
12
  const PLACEHOLDER_REVIEWS: string[] = [
13
  // battery/charging problems
 
34
  "I'm not sure what to make of this phone. It's not bad, but it's not great either. I'm on the fence about it.",
35
  "I hate the color of this phone. It's so ugly!",
36
  "This phone sucks! I'm returning it."
37
+ ].sort(() => Math.random() - 0.5)
38
 
39
  const PLACEHOLDER_SECTIONS: string[] = [
40
  'Battery and charging problems',
 
42
  'Poor build quality',
43
  'Software issues',
44
  'Other'
45
+ ]
46
 
47
  function ZeroShotClassification() {
48
+ const [text, setText] = useState<string>(PLACEHOLDER_REVIEWS.join('\n'))
49
 
50
  const [sections, setSections] = useState<Section[]>(
51
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
52
+ )
53
 
54
+ const { setProgress, status, setStatus, modelInfo, setModelInfo } = useModel()
55
+ useEffect(() => {
56
+ const modelName = 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33'
57
+ const fetchModelInfo = async () => {
58
+ try {
59
+ const modelInfoResponse = await getModelInfo(modelName)
60
+ console.log(modelInfoResponse)
61
+ let parameters = 0
62
+ if (modelInfoResponse.safetensors) {
63
+ const safetensors = modelInfoResponse.safetensors
64
+ parameters =
65
+ safetensors.parameters.F16 ||
66
+ safetensors.parameters.F32 ||
67
+ safetensors.parameters.total ||
68
+ 0
69
+ }
70
+ setModelInfo({
71
+ name: modelName,
72
+ architecture: modelInfoResponse.config.architectures[0],
73
+ parameters,
74
+ likes: modelInfoResponse.likes,
75
+ downloads: modelInfoResponse.downloads
76
+ })
77
+ } catch (error) {
78
+ console.error('Error fetching model info:', error)
79
+ }
80
+ }
81
+
82
+ fetchModelInfo()
83
+ }, [setModelInfo])
84
 
85
  // Create a reference to the worker object.
86
+ const worker = useRef<Worker | null>(null)
87
 
88
  // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
89
  useEffect(() => {
 
94
  {
95
  type: 'module'
96
  }
97
+ )
98
  }
99
 
100
  // Create a callback function for messages from the worker thread.
101
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
102
+ const status = e.data.status
103
  if (status === 'initiate') {
104
+ setStatus('loading')
105
  } else if (status === 'ready') {
106
+ setStatus('ready')
107
  } else if (status === 'progress') {
108
+ setStatus('progress')
109
  if (
110
  e.data.output.progress &&
111
  (e.data.output.file as string).startsWith('onnx')
112
  )
113
+ setProgress(e.data.output.progress)
114
  } else if (status === 'output') {
115
+ setStatus('output')
116
+ const { sequence, labels, scores } = e.data.output!
117
 
118
  // Threshold for classification
119
+ const label = scores[0] > 0.5 ? labels[0] : 'Other'
120
 
121
  const sectionID =
122
+ sections.map((x) => x.title).indexOf(label) ?? sections.length - 1
123
  setSections((sections) => {
124
+ const newSections = [...sections]
125
  newSections[sectionID] = {
126
  ...newSections[sectionID],
127
  items: [...newSections[sectionID].items, sequence]
128
+ }
129
+ return newSections
130
+ })
131
  } else if (status === 'complete') {
132
+ setStatus('idle')
133
+ setProgress(100)
134
+ } else if (status === 'error') {
135
+ setStatus('error')
136
+ console.error(e.data.output)
137
  }
138
+ }
139
 
140
  // Attach the callback function as an event listener.
141
+ worker.current.addEventListener('message', onMessageReceived)
142
 
143
  // Define a cleanup function for when the component is unmounted.
144
  return () =>
145
+ worker.current?.removeEventListener('message', onMessageReceived)
146
+ }, [sections])
147
 
148
  const classify = useCallback(() => {
149
+ setStatus('processing')
150
  const message: ZeroShotWorkerInput = {
151
  text,
152
  labels: sections
153
  .slice(0, sections.length - 1)
154
+ .map((section) => section.title),
155
+ model: modelInfo.name
156
+ }
157
+ worker.current?.postMessage(message)
158
+ }, [text, sections, modelInfo.name])
159
 
160
+ const busy: boolean = status !== 'idle'
161
 
162
  const handleAddCategory = (): void => {
163
  setSections((sections) => {
164
+ const newSections = [...sections]
165
  // add at position 2 from the end
166
  newSections.splice(newSections.length - 1, 0, {
167
  title: 'New Category',
168
  items: []
169
+ })
170
+ return newSections
171
+ })
172
+ }
173
 
174
  const handleRemoveCategory = (): void => {
175
  setSections((sections) => {
176
+ const newSections = [...sections]
177
+ newSections.splice(newSections.length - 2, 1) // Remove second last element
178
+ return newSections
179
+ })
180
+ }
181
 
182
  const handleClear = (): void => {
183
  setSections((sections) =>
 
185
  ...section,
186
  items: []
187
  }))
188
+ )
189
+ }
190
 
191
  const handleSectionTitleChange = (index: number, newTitle: string): void => {
192
  setSections((sections) => {
193
+ const newSections = [...sections]
194
+ newSections[index].title = newTitle
195
+ return newSections
196
+ })
197
+ }
198
 
199
  return (
200
  <div className="flex flex-col h-screen w-full p-1">
 
212
  {!busy
213
  ? 'Categorize'
214
  : status === 'loading'
215
+ ? 'Model loading...'
216
+ : 'Processing'}
217
  </button>
218
  <div className="flex gap-1">
219
  <button
 
261
  ))}
262
  </div>
263
  </div>
264
+ )
265
  }
266
 
267
+ export default ZeroShotClassification
src/contexts/ModelContext.tsx CHANGED
@@ -1,39 +1,47 @@
1
- import React, { createContext, useContext, useEffect, useState } from 'react';
 
2
 
3
  interface ModelContextType {
4
- progress: number;
5
- status: string;
6
- setProgress: (progress: number) => void;
7
- setStatus: (status: string) => void;
8
- model: string;
9
- setModel: (model: string) => void;
10
  }
11
 
12
- const ModelContext = createContext<ModelContextType | undefined>(undefined);
13
 
14
  export function ModelProvider({ children }: { children: React.ReactNode }) {
15
- const [progress, setProgress] = useState<number>(0);
16
- const [status, setStatus] = useState<string>('idle');
17
- const [model, setModel] = useState<string>('');
18
 
19
  // set progress to 0 when model is changed
20
  useEffect(() => {
21
- setProgress(0);
22
- }, [model]);
23
 
24
  return (
25
  <ModelContext.Provider
26
- value={{ progress, setProgress, status, setStatus, model, setModel }}
 
 
 
 
 
 
 
27
  >
28
  {children}
29
  </ModelContext.Provider>
30
- );
31
  }
32
 
33
  export function useModel() {
34
- const context = useContext(ModelContext);
35
  if (context === undefined) {
36
- throw new Error('useModel must be used within a ModelProvider');
37
  }
38
- return context;
39
  }
 
1
+ import React, { createContext, useContext, useEffect, useState } from 'react'
2
+ import { ModelInfo } from '../types'
3
 
4
  interface ModelContextType {
5
+ progress: number
6
+ status: string
7
+ setProgress: (progress: number) => void
8
+ setStatus: (status: string) => void
9
+ modelInfo: ModelInfo
10
+ setModelInfo: (model: ModelInfo) => void
11
  }
12
 
13
+ const ModelContext = createContext<ModelContextType | undefined>(undefined)
14
 
15
  export function ModelProvider({ children }: { children: React.ReactNode }) {
16
+ const [progress, setProgress] = useState<number>(0)
17
+ const [status, setStatus] = useState<string>('idle')
18
+ const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
19
 
20
  // set progress to 0 when model is changed
21
  useEffect(() => {
22
+ setProgress(0)
23
+ }, [modelInfo.name])
24
 
25
  return (
26
  <ModelContext.Provider
27
+ value={{
28
+ progress,
29
+ setProgress,
30
+ status,
31
+ setStatus,
32
+ modelInfo,
33
+ setModelInfo
34
+ }}
35
  >
36
  {children}
37
  </ModelContext.Provider>
38
+ )
39
  }
40
 
41
  export function useModel() {
42
+ const context = useContext(ModelContext)
43
  if (context === undefined) {
44
+ throw new Error('useModel must be used within a ModelProvider')
45
  }
46
+ return context
47
  }
src/lib/huggingface.ts ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ interface ModelInfoResponse {
2
+ id: string
3
+ config: {
4
+ architectures: string[]
5
+ model_type: string
6
+ }
7
+ lastModified: string
8
+ pipeline_tag: string
9
+ tags: string[]
10
+ transformersInfo: {
11
+ pipeline_tag: string
12
+ auto_model: string
13
+ processor: string
14
+ }
15
+ safetensors?: {
16
+ parameters: {
17
+ F16?: number
18
+ F32?: number
19
+ total?: number
20
+ }
21
+ }
22
+ likes: number
23
+ downloads: number
24
+ }
25
+
26
+ const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
27
+ const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
28
+
29
+ if (!token) {
30
+ throw new Error(
31
+ 'Hugging Face token not found. Please set REACT_APP_HUGGINGFACE_TOKEN in your .env file'
32
+ )
33
+ }
34
+
35
+ const response = await fetch(
36
+ `https://huggingface.co/api/models/${modelName}`,
37
+ {
38
+ method: 'GET',
39
+ headers: {
40
+ Authorization: `Bearer ${token}`
41
+ }
42
+ }
43
+ )
44
+
45
+ if (!response.ok) {
46
+ throw new Error(`Failed to fetch model info: ${response.statusText}`)
47
+ }
48
+ return response.json()
49
+ }
50
+
51
+ export { getModelInfo }
src/types.ts CHANGED
@@ -1,26 +1,36 @@
1
  export interface Section {
2
- title: string;
3
- items: string[];
4
  }
5
 
6
  export interface ClassificationOutput {
7
- sequence: string;
8
- labels: string[];
9
- scores: number[];
10
  }
11
 
12
  export interface WorkerMessage {
13
- status: 'initiate' | 'ready' | 'output' | 'complete' | 'progress';
14
- output?: any;
15
  }
16
 
17
  export interface ZeroShotWorkerInput {
18
- text: string;
19
- labels: string[];
 
20
  }
21
 
22
  export interface TextClassificationWorkerInput {
23
- text: string;
 
24
  }
25
 
26
- export type AppStatus = 'idle' | 'loading' | 'processing';
 
 
 
 
 
 
 
 
 
1
  export interface Section {
2
+ title: string
3
+ items: string[]
4
  }
5
 
6
  export interface ClassificationOutput {
7
+ sequence: string
8
+ labels: string[]
9
+ scores: number[]
10
  }
11
 
12
  export interface WorkerMessage {
13
+ status: 'initiate' | 'ready' | 'output' | 'complete' | 'progress'
14
+ output?: any
15
  }
16
 
17
  export interface ZeroShotWorkerInput {
18
+ text: string
19
+ labels: string[]
20
+ model: string
21
  }
22
 
23
  export interface TextClassificationWorkerInput {
24
+ text: string
25
+ model: string
26
  }
27
 
28
+ export type AppStatus = 'idle' | 'loading' | 'processing'
29
+
30
+ export interface ModelInfo {
31
+ name: string
32
+ architecture: string
33
+ parameters: number
34
+ likes: number
35
+ downloads: number
36
+ }
src/workers/text-classification.js CHANGED
@@ -3,11 +3,10 @@ import { pipeline } from '@huggingface/transformers';
3
 
4
  class MyTextClassificationPipeline {
5
  static task = 'text-classification';
6
- static model = 'Xenova/bert-base-multilingual-uncased-sentiment';
7
  static instance = null;
8
 
9
- static async getInstance(progress_callback = null) {
10
- this.instance ??= pipeline(this.task, this.model, {
11
  progress_callback
12
  });
13
 
@@ -17,15 +16,24 @@ class MyTextClassificationPipeline {
17
 
18
  // Listen for messages from the main thread
19
  self.addEventListener('message', async (event) => {
 
 
 
 
 
 
 
 
 
20
  // Retrieve the pipeline. When called for the first time,
21
  // this will load the pipeline and save it for future use.
22
- const classifier = await MyTextClassificationPipeline.getInstance((x) => {
23
  // We also add a progress callback to the pipeline so that we can
24
  // track model loading.
25
  self.postMessage({ status: 'progress', output: x });
26
  });
27
 
28
- const { text } = event.data;
29
 
30
  const split = text.split('\n');
31
  for (const line of split) {
 
3
 
4
  class MyTextClassificationPipeline {
5
  static task = 'text-classification';
 
6
  static instance = null;
7
 
8
+ static async getInstance(model, progress_callback = null) {
9
+ this.instance ??= pipeline(this.task, model, {
10
  progress_callback
11
  });
12
 
 
16
 
17
  // Listen for messages from the main thread
18
  self.addEventListener('message', async (event) => {
19
+ const { text, model } = event.data;
20
+ if (!model) {
21
+ self.postMessage({
22
+ status: 'error',
23
+ output: 'No model provided'
24
+ });
25
+ return;
26
+ }
27
+
28
  // Retrieve the pipeline. When called for the first time,
29
  // this will load the pipeline and save it for future use.
30
+ const classifier = await MyTextClassificationPipeline.getInstance(model, (x) => {
31
  // We also add a progress callback to the pipeline so that we can
32
  // track model loading.
33
  self.postMessage({ status: 'progress', output: x });
34
  });
35
 
36
+
37
 
38
  const split = text.split('\n');
39
  for (const line of split) {
src/workers/zero-shot.js CHANGED
@@ -3,11 +3,10 @@ import { pipeline } from '@huggingface/transformers';
3
 
4
  class MyZeroShotClassificationPipeline {
5
  static task = 'zero-shot-classification';
6
- static model = 'MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33';
7
  static instance = null;
8
 
9
- static async getInstance(progress_callback = null) {
10
- this.instance ??= pipeline(this.task, this.model, {
11
  progress_callback
12
  });
13
 
@@ -17,16 +16,22 @@ class MyZeroShotClassificationPipeline {
17
 
18
  // Listen for messages from the main thread
19
  self.addEventListener('message', async (event) => {
 
 
 
 
 
 
 
 
 
20
  // Retrieve the pipeline. When called for the first time,
21
  // this will load the pipeline and save it for future use.
22
- const classifier = await MyZeroShotClassificationPipeline.getInstance((x) => {
23
  // We also add a progress callback to the pipeline so that we can
24
  // track model loading.
25
  self.postMessage({ status: 'progress', output: x });
26
  });
27
-
28
- const { text, labels } = event.data;
29
-
30
  const split = text.split('\n');
31
  for (const line of split) {
32
  const output = await classifier(line, labels, {
 
3
 
4
  class MyZeroShotClassificationPipeline {
5
  static task = 'zero-shot-classification';
 
6
  static instance = null;
7
 
8
+ static async getInstance(model, progress_callback = null) {
9
+ this.instance ??= pipeline(this.task, model, {
10
  progress_callback
11
  });
12
 
 
16
 
17
  // Listen for messages from the main thread
18
  self.addEventListener('message', async (event) => {
19
+ const { text, labels, model } = event.data;
20
+ if (!model) {
21
+ self.postMessage({
22
+ status: 'error',
23
+ output: 'No model provided'
24
+ });
25
+ return;
26
+ }
27
+
28
  // Retrieve the pipeline. When called for the first time,
29
  // this will load the pipeline and save it for future use.
30
+ const classifier = await MyZeroShotClassificationPipeline.getInstance(model, (x) => {
31
  // We also add a progress callback to the pipeline so that we can
32
  // track model loading.
33
  self.postMessage({ status: 'progress', output: x });
34
  });
 
 
 
35
  const split = text.split('\n');
36
  for (const line of split) {
37
  const output = await classifier(line, labels, {