Vokturz commited on
Commit
046ca57
·
1 Parent(s): 4a70176

Add image-classification

Browse files
public/workers/image-classification.js ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* eslint-disable no-restricted-globals */
2
+ import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3'
3
+
4
+ class MyImageClassificationPipeline {
5
+ static task = 'image-classification'
6
+ static instance = null
7
+ static modelId = null
8
+
9
+ static async getInstance(model, dtype = 'fp32', progress_callback = null) {
10
+ if (this.modelId !== model) {
11
+ // Dispose of previous pipeline if model changed
12
+ if (this.instance && this.instance.dispose) {
13
+ this.instance.dispose()
14
+ }
15
+ this.instance = null
16
+ this.modelId = null
17
+ }
18
+
19
+ if (!this.instance) {
20
+ try {
21
+ // Try WebGPU first
22
+ this.instance = await pipeline(this.task, model, {
23
+ dtype,
24
+ device: 'webgpu',
25
+ progress_callback
26
+ })
27
+ } catch (webgpuError) {
28
+ // Fallback to WASM if WebGPU fails
29
+ if (progress_callback) {
30
+ progress_callback({
31
+ status: 'fallback',
32
+ message: 'WebGPU failed, falling back to WASM'
33
+ })
34
+ }
35
+ try {
36
+ this.instance = await pipeline(this.task, model, {
37
+ dtype,
38
+ device: 'wasm',
39
+ progress_callback
40
+ })
41
+ } catch (wasmError) {
42
+ throw new Error(
43
+ `Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
44
+ )
45
+ }
46
+ }
47
+ this.modelId = model
48
+ }
49
+
50
+ return this.instance
51
+ }
52
+
53
+ static dispose() {
54
+ if (this.instance && this.instance.dispose) {
55
+ this.instance.dispose()
56
+ }
57
+ this.instance = null
58
+ this.modelId = null
59
+ }
60
+ }
61
+
62
+ // Listen for messages from the main thread
63
+ self.addEventListener('message', async (event) => {
64
+ try {
65
+ const { type, image, model, dtype, topK = 5 } = event.data
66
+
67
+ if (!model) {
68
+ self.postMessage({
69
+ status: 'error',
70
+ output: 'No model provided'
71
+ })
72
+ return
73
+ }
74
+
75
+ // Get the pipeline instance
76
+ const classifier = await MyImageClassificationPipeline.getInstance(
77
+ model,
78
+ dtype,
79
+ (x) => {
80
+ self.postMessage({ status: 'loading', output: x })
81
+ }
82
+ )
83
+
84
+ if (type === 'load') {
85
+ self.postMessage({
86
+ status: 'ready',
87
+ output: `Image classification model ${model}, dtype ${dtype} loaded`
88
+ })
89
+ return
90
+ }
91
+
92
+ if (type === 'classify') {
93
+ if (!image) {
94
+ self.postMessage({
95
+ status: 'error',
96
+ output: 'No image provided for classification'
97
+ })
98
+ return
99
+ }
100
+
101
+ try {
102
+ self.postMessage({ status: 'loading' })
103
+
104
+ // Run classification
105
+ const output = await classifier(image, {
106
+ topk: topK
107
+ })
108
+
109
+ // Format predictions
110
+ const predictions = output.map((item) => ({
111
+ label: item.label,
112
+ score: item.score
113
+ }))
114
+
115
+ self.postMessage({
116
+ status: 'output',
117
+ output: {
118
+ predictions
119
+ }
120
+ })
121
+ } catch (error) {
122
+ self.postMessage({
123
+ status: 'error',
124
+ output:
125
+ error.message || 'An error occurred during image classification'
126
+ })
127
+ }
128
+ } else if (type === 'dispose') {
129
+ MyImageClassificationPipeline.dispose()
130
+ self.postMessage({ status: 'disposed' })
131
+ }
132
+ } catch (error) {
133
+ self.postMessage({
134
+ status: 'error',
135
+ output:
136
+ error.message || 'An error occurred during pipeline initialization'
137
+ })
138
+ }
139
+ })
140
+
141
+ // Handle initialization
142
+ self.postMessage({ status: 'ready' })
src/App.tsx CHANGED
@@ -7,6 +7,7 @@ import { useModel } from './contexts/ModelContext'
7
  import { getModelsByPipeline } from './lib/huggingface'
8
  import TextGeneration from './components/pipelines/TextGeneration'
9
  import FeatureExtraction from './components/pipelines/FeatureExtraction'
 
10
  import Sidebar from './components/Sidebar'
11
  import ModelReadme from './components/ModelReadme'
12
  import { PipelineLayout } from './components/PipelineLayout'
@@ -63,6 +64,7 @@ function App() {
63
  {pipeline === 'text-classification' && <TextClassification />}
64
  {pipeline === 'text-generation' && <TextGeneration />}
65
  {pipeline === 'feature-extraction' && <FeatureExtraction />}
 
66
  </div>
67
  </div>
68
  </main>
 
7
  import { getModelsByPipeline } from './lib/huggingface'
8
  import TextGeneration from './components/pipelines/TextGeneration'
9
  import FeatureExtraction from './components/pipelines/FeatureExtraction'
10
+ import ImageClassification from './components/pipelines/ImageClassification'
11
  import Sidebar from './components/Sidebar'
12
  import ModelReadme from './components/ModelReadme'
13
  import { PipelineLayout } from './components/PipelineLayout'
 
64
  {pipeline === 'text-classification' && <TextClassification />}
65
  {pipeline === 'text-generation' && <TextGeneration />}
66
  {pipeline === 'feature-extraction' && <FeatureExtraction />}
67
+ {pipeline === 'image-classification' && <ImageClassification />}
68
  </div>
69
  </div>
70
  </main>
src/components/PipelineLayout.tsx CHANGED
@@ -2,6 +2,7 @@ import { useModel } from '../contexts/ModelContext'
2
  import { TextGenerationProvider } from '../contexts/TextGenerationContext'
3
  import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
4
  import { ZeroShotClassificationProvider } from '../contexts/ZeroShotClassificationContext'
 
5
 
6
  export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
7
  const { pipeline } = useModel()
@@ -20,6 +21,11 @@ export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
20
  </ZeroShotClassificationProvider>
21
  )
22
 
 
 
 
 
 
23
  default:
24
  return <>{children}</>
25
  }
 
2
  import { TextGenerationProvider } from '../contexts/TextGenerationContext'
3
  import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
4
  import { ZeroShotClassificationProvider } from '../contexts/ZeroShotClassificationContext'
5
+ import { ImageClassificationProvider } from '../contexts/ImageClassificationContext'
6
 
7
  export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
8
  const { pipeline } = useModel()
 
21
  </ZeroShotClassificationProvider>
22
  )
23
 
24
+ case 'image-classification':
25
+ return (
26
+ <ImageClassificationProvider>{children}</ImageClassificationProvider>
27
+ )
28
+
29
  default:
30
  return <>{children}</>
31
  }
src/components/Sidebar.tsx CHANGED
@@ -6,6 +6,7 @@ import { useModel } from '../contexts/ModelContext'
6
  import TextGenerationConfig from './pipelines/TextGenerationConfig'
7
  import FeatureExtractionConfig from './pipelines/FeatureExtractionConfig'
8
  import ZeroShotClassificationConfig from './pipelines/ZeroShotClassificationConfig'
 
9
 
10
  interface SidebarProps {
11
  isOpen: boolean
@@ -92,6 +93,9 @@ const Sidebar = ({ isOpen, onClose, setIsModalOpen }: SidebarProps) => {
92
  {pipeline === 'zero-shot-classification' && (
93
  <ZeroShotClassificationConfig />
94
  )}
 
 
 
95
  </div>
96
  </div>
97
  </div>
 
6
  import TextGenerationConfig from './pipelines/TextGenerationConfig'
7
  import FeatureExtractionConfig from './pipelines/FeatureExtractionConfig'
8
  import ZeroShotClassificationConfig from './pipelines/ZeroShotClassificationConfig'
9
+ import ImageClassificationConfig from './pipelines/ImageClassificationConfig'
10
 
11
  interface SidebarProps {
12
  isOpen: boolean
 
93
  {pipeline === 'zero-shot-classification' && (
94
  <ZeroShotClassificationConfig />
95
  )}
96
+ {pipeline === 'image-classification' && (
97
+ <ImageClassificationConfig />
98
+ )}
99
  </div>
100
  </div>
101
  </div>
src/components/pipelines/ImageClassification.tsx ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useRef, useCallback, useEffect } from 'react'
2
+ import {
3
+ Upload,
4
+ Trash2,
5
+ Loader2,
6
+ X,
7
+ Eye,
8
+ EyeOff,
9
+ Image as ImageIcon
10
+ } from 'lucide-react'
11
+ import {
12
+ ImageClassificationWorkerInput,
13
+ WorkerMessage,
14
+ ImageExample
15
+ } from '../../types'
16
+ import { useModel } from '../../contexts/ModelContext'
17
+ import { useImageClassification } from '../../contexts/ImageClassificationContext'
18
+
19
+ // Sample images for quick testing (placeholder URLs)
20
+ const SAMPLE_IMAGES = [
21
+ {
22
+ name: 'Cat',
23
+ url: 'https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=300&h=300&fit=crop'
24
+ },
25
+ {
26
+ name: 'Dog',
27
+ url: 'https://images.unsplash.com/photo-1552053831-71594a27632d?w=300&h=300&fit=crop'
28
+ },
29
+ {
30
+ name: 'Car',
31
+ url: 'https://images.unsplash.com/photo-1605559424843-9e4c228bf1c2?w=300&h=300&fit=crop'
32
+ },
33
+ {
34
+ name: 'Flower',
35
+ url: 'https://images.unsplash.com/photo-1490750967868-88aa4486c946?w=300&h=300&fit=crop'
36
+ }
37
+ ]
38
+
39
+ function ImageClassification() {
40
+ const {
41
+ examples,
42
+ selectedExample,
43
+ setSelectedExample,
44
+ addExample,
45
+ removeExample,
46
+ updateExample,
47
+ clearExamples,
48
+ topK
49
+ } = useImageClassification()
50
+
51
+ const [isClassifying, setIsClassifying] = useState<boolean>(false)
52
+ const [showPreviews, setShowPreviews] = useState<boolean>(true)
53
+ const [dragOver, setDragOver] = useState<boolean>(false)
54
+ const [progress, setProgress] = useState<number | null>(null)
55
+
56
+ const {
57
+ activeWorker,
58
+ status,
59
+ modelInfo,
60
+ hasBeenLoaded,
61
+ selectedQuantization
62
+ } = useModel()
63
+
64
+ const fileInputRef = useRef<HTMLInputElement>(null)
65
+ const dropZoneRef = useRef<HTMLDivElement>(null)
66
+
67
+ const classifyImage = useCallback(
68
+ async (example: ImageExample) => {
69
+ if (!modelInfo || !activeWorker || !example.url) return
70
+
71
+ updateExample(example.id, { isLoading: true })
72
+ setIsClassifying(true)
73
+ setProgress(0)
74
+
75
+ const message: ImageClassificationWorkerInput = {
76
+ type: 'classify',
77
+ image: example.url,
78
+ model: modelInfo.id,
79
+ dtype: selectedQuantization ?? 'fp32',
80
+ topK
81
+ }
82
+
83
+ activeWorker.postMessage(message)
84
+ },
85
+ [modelInfo, activeWorker, selectedQuantization, topK, updateExample]
86
+ )
87
+
88
+ const handleFileSelect = useCallback(
89
+ (files: FileList | null) => {
90
+ if (!files) return
91
+
92
+ Array.from(files).forEach((file) => {
93
+ if (file.type.startsWith('image/')) {
94
+ addExample(file)
95
+ }
96
+ })
97
+ },
98
+ [addExample]
99
+ )
100
+
101
+ const handleDragOver = useCallback((e: React.DragEvent) => {
102
+ e.preventDefault()
103
+ setDragOver(true)
104
+ }, [])
105
+
106
+ const handleDragLeave = useCallback((e: React.DragEvent) => {
107
+ e.preventDefault()
108
+ setDragOver(false)
109
+ }, [])
110
+
111
+ const handleDrop = useCallback(
112
+ (e: React.DragEvent) => {
113
+ e.preventDefault()
114
+ setDragOver(false)
115
+ handleFileSelect(e.dataTransfer.files)
116
+ },
117
+ [handleFileSelect]
118
+ )
119
+
120
+ const handleClassifyAll = useCallback(() => {
121
+ const imagesToClassify = examples.filter(
122
+ (ex) => !ex.predictions && !ex.isLoading
123
+ )
124
+
125
+ imagesToClassify.forEach((example) => {
126
+ classifyImage(example)
127
+ })
128
+ }, [examples, classifyImage])
129
+
130
+ const handleSelectExample = useCallback(
131
+ (example: ImageExample) => {
132
+ setSelectedExample(example)
133
+ },
134
+ [setSelectedExample]
135
+ )
136
+
137
+ const handleLoadSampleImages = useCallback(async () => {
138
+ for (const sample of SAMPLE_IMAGES) {
139
+ try {
140
+ const response = await fetch(sample.url)
141
+ const blob = await response.blob()
142
+ const file = new File([blob], sample.name, { type: blob.type })
143
+ addExample(file)
144
+ } catch (error) {
145
+ console.error(`Failed to load sample image ${sample.name}:`, error)
146
+ }
147
+ }
148
+ }, [addExample])
149
+
150
+ useEffect(() => {
151
+ if (!activeWorker) return
152
+
153
+ const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
154
+ const { status, output, progress: workerProgress } = e.data
155
+
156
+ if (status === 'progress' && workerProgress !== undefined) {
157
+ setProgress(workerProgress)
158
+ } else if (status === 'output' && output?.predictions) {
159
+ // Find the example that was being processed
160
+ const processingExample = examples.find((ex) => ex.isLoading)
161
+ if (processingExample) {
162
+ updateExample(processingExample.id, {
163
+ predictions: output.predictions,
164
+ isLoading: false
165
+ })
166
+ }
167
+ setIsClassifying(false)
168
+ setProgress(null)
169
+ } else if (status === 'error') {
170
+ // Clear loading state for all examples
171
+ examples.forEach((ex) => {
172
+ if (ex.isLoading) {
173
+ updateExample(ex.id, { isLoading: false })
174
+ }
175
+ })
176
+ setIsClassifying(false)
177
+ setProgress(null)
178
+ }
179
+ }
180
+
181
+ activeWorker.addEventListener('message', onMessageReceived)
182
+ return () => activeWorker.removeEventListener('message', onMessageReceived)
183
+ }, [activeWorker, examples, updateExample])
184
+
185
+ const busy = status !== 'ready' || isClassifying
186
+
187
+ return (
188
+ <div className="flex flex-col h-full max-h-[92vh] w-full p-4">
189
+ <div className="flex items-center justify-between mb-4">
190
+ <h1 className="text-2xl font-bold">Image Classification</h1>
191
+ <div className="flex gap-2">
192
+ <button
193
+ onClick={handleLoadSampleImages}
194
+ disabled={!hasBeenLoaded || isClassifying}
195
+ className="px-3 py-2 bg-purple-100 hover:bg-purple-200 disabled:bg-gray-100 disabled:cursor-not-allowed rounded-lg transition-colors text-sm"
196
+ title="Load Sample Images"
197
+ >
198
+ Load Samples
199
+ </button>
200
+ <button
201
+ onClick={() => setShowPreviews(!showPreviews)}
202
+ className="p-2 bg-blue-100 hover:bg-blue-200 rounded-lg transition-colors"
203
+ title={showPreviews ? 'Hide Previews' : 'Show Previews'}
204
+ >
205
+ {showPreviews ? (
206
+ <EyeOff className="w-4 h-4" />
207
+ ) : (
208
+ <Eye className="w-4 h-4" />
209
+ )}
210
+ </button>
211
+ <button
212
+ onClick={clearExamples}
213
+ className="p-2 bg-red-100 hover:bg-red-200 rounded-lg transition-colors"
214
+ title="Clear All Images"
215
+ >
216
+ <Trash2 className="w-4 h-4" />
217
+ </button>
218
+ </div>
219
+ </div>
220
+
221
+ <div className="flex flex-col lg:flex-row gap-4 flex-1">
222
+ {/* Left Panel - Image Upload and List */}
223
+ <div className="lg:w-1/2 flex flex-col">
224
+ {/* Upload Area */}
225
+ <div className="mb-4">
226
+ <label className="block text-sm font-medium text-gray-700 mb-2">
227
+ Upload Images:
228
+ </label>
229
+ <div
230
+ ref={dropZoneRef}
231
+ onDragOver={handleDragOver}
232
+ onDragLeave={handleDragLeave}
233
+ onDrop={handleDrop}
234
+ className={`border-2 border-dashed rounded-lg p-6 text-center transition-colors cursor-pointer ${
235
+ dragOver
236
+ ? 'border-blue-500 bg-blue-50'
237
+ : 'border-gray-300 hover:border-gray-400'
238
+ } ${!hasBeenLoaded ? 'opacity-50 cursor-not-allowed' : ''}`}
239
+ onClick={() => hasBeenLoaded && fileInputRef.current?.click()}
240
+ >
241
+ <Upload className="w-8 h-8 mx-auto mb-2 text-gray-400" />
242
+ <p className="text-sm text-gray-600">
243
+ {dragOver
244
+ ? 'Drop images here'
245
+ : 'Click to upload or drag and drop images'}
246
+ </p>
247
+ <p className="text-xs text-gray-500 mt-1">
248
+ Supports JPG, PNG, GIF, WebP
249
+ </p>
250
+ <input
251
+ ref={fileInputRef}
252
+ type="file"
253
+ multiple
254
+ accept="image/*"
255
+ onChange={(e) => handleFileSelect(e.target.files)}
256
+ className="hidden"
257
+ disabled={!hasBeenLoaded}
258
+ />
259
+ </div>
260
+ </div>
261
+
262
+ {/* Classify Button */}
263
+ {examples.some((ex) => !ex.predictions) && (
264
+ <div className="mb-4">
265
+ <button
266
+ onClick={handleClassifyAll}
267
+ disabled={busy || !hasBeenLoaded}
268
+ className="px-6 py-2 bg-green-500 hover:bg-green-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center gap-2"
269
+ >
270
+ {isClassifying ? (
271
+ <>
272
+ <Loader2 className="w-4 h-4 animate-spin" />
273
+ Classifying...
274
+ {progress !== null && ` (${Math.round(progress * 100)}%)`}
275
+ </>
276
+ ) : (
277
+ 'Classify Images'
278
+ )}
279
+ </button>
280
+ </div>
281
+ )}
282
+
283
+ {/* Images List */}
284
+ <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg bg-white">
285
+ <div className="p-4">
286
+ <h3 className="text-sm font-medium text-gray-700 mb-3">
287
+ Images ({examples.length})
288
+ </h3>
289
+ {examples.length === 0 ? (
290
+ <div className="text-gray-500 italic text-center py-8">
291
+ No images uploaded yet. Upload some images above to get
292
+ started.
293
+ </div>
294
+ ) : (
295
+ <div className="space-y-3">
296
+ {examples.map((example) => (
297
+ <div
298
+ key={example.id}
299
+ className={`p-3 border rounded-lg cursor-pointer transition-colors ${
300
+ selectedExample?.id === example.id
301
+ ? 'border-blue-500 bg-blue-50'
302
+ : 'border-gray-200 hover:border-gray-300'
303
+ }`}
304
+ onClick={() => handleSelectExample(example)}
305
+ >
306
+ <div className="flex gap-3">
307
+ {showPreviews && (
308
+ <div className="flex-shrink-0">
309
+ <img
310
+ src={example.url}
311
+ alt={example.name}
312
+ className="w-16 h-16 object-cover rounded-lg"
313
+ />
314
+ </div>
315
+ )}
316
+ <div className="flex-1 min-w-0">
317
+ <div className="flex justify-between items-start">
318
+ <div className="flex-1 min-w-0">
319
+ <div className="text-sm font-medium text-gray-800 truncate">
320
+ {example.name}
321
+ </div>
322
+ <div className="flex items-center gap-2 mt-1">
323
+ {example.isLoading ? (
324
+ <div className="flex items-center gap-1 text-xs text-blue-600">
325
+ <Loader2 className="w-3 h-3 animate-spin" />
326
+ Classifying...
327
+ </div>
328
+ ) : example.predictions ? (
329
+ <div className="text-xs text-green-600">
330
+ ✓ Classified
331
+ </div>
332
+ ) : (
333
+ <div className="text-xs text-gray-500">
334
+ Not classified
335
+ </div>
336
+ )}
337
+ {selectedExample?.id === example.id && (
338
+ <div className="text-xs text-blue-600">
339
+ Selected
340
+ </div>
341
+ )}
342
+ </div>
343
+ </div>
344
+ <button
345
+ onClick={(e) => {
346
+ e.stopPropagation()
347
+ removeExample(example.id)
348
+ }}
349
+ className="ml-2 p-1 text-red-500 hover:text-red-700 transition-colors"
350
+ >
351
+ <X className="w-3 h-3" />
352
+ </button>
353
+ </div>
354
+ </div>
355
+ </div>
356
+ </div>
357
+ ))}
358
+ </div>
359
+ )}
360
+ </div>
361
+ </div>
362
+ </div>
363
+
364
+ {/* Right Panel - Preview and Results */}
365
+ <div className="lg:w-1/2 flex flex-col">
366
+ {/* Image Preview */}
367
+ {selectedExample && (
368
+ <div className="mb-4">
369
+ <h3 className="text-sm font-medium text-gray-700 mb-2">
370
+ Selected Image
371
+ </h3>
372
+ <div className="border border-gray-300 rounded-lg bg-white p-4">
373
+ <div className="flex flex-col items-center">
374
+ <img
375
+ src={selectedExample.url}
376
+ alt={selectedExample.name}
377
+ className="max-w-full max-h-64 object-contain rounded-lg mb-2"
378
+ />
379
+ <div className="text-sm text-gray-600 text-center">
380
+ {selectedExample.name}
381
+ </div>
382
+ </div>
383
+ </div>
384
+ </div>
385
+ )}
386
+
387
+ {/* Classification Results */}
388
+ <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg bg-white">
389
+ <div className="p-4">
390
+ <h3 className="text-sm font-medium text-gray-700 mb-3">
391
+ Classification Results
392
+ {selectedExample && ` - ${selectedExample.name}`}
393
+ </h3>
394
+ {!selectedExample ? (
395
+ <div className="text-gray-500 italic text-center py-8">
396
+ <ImageIcon className="w-12 h-12 mx-auto mb-2 text-gray-300" />
397
+ Select an image to see classification results
398
+ </div>
399
+ ) : selectedExample.isLoading ? (
400
+ <div className="text-center py-8">
401
+ <Loader2 className="w-8 h-8 animate-spin mx-auto mb-2 text-blue-500" />
402
+ <div className="text-sm text-gray-600">
403
+ Classifying image...
404
+ </div>
405
+ {progress !== null && (
406
+ <div className="text-xs text-gray-500 mt-1">
407
+ {Math.round(progress * 100)}% complete
408
+ </div>
409
+ )}
410
+ </div>
411
+ ) : !selectedExample.predictions ? (
412
+ <div className="text-gray-500 italic text-center py-8">
413
+ <button
414
+ onClick={() => classifyImage(selectedExample)}
415
+ disabled={busy || !hasBeenLoaded}
416
+ className="px-4 py-2 bg-blue-500 hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors"
417
+ >
418
+ Classify This Image
419
+ </button>
420
+ </div>
421
+ ) : (
422
+ <div className="space-y-3">
423
+ {selectedExample.predictions.map((prediction, index) => {
424
+ const confidencePercent = (prediction.score * 100).toFixed(
425
+ 1
426
+ )
427
+ const isTopPrediction = index === 0
428
+
429
+ return (
430
+ <div
431
+ key={index}
432
+ className={`p-3 border rounded-lg ${
433
+ isTopPrediction
434
+ ? 'border-green-300 bg-green-50'
435
+ : 'border-gray-200'
436
+ }`}
437
+ >
438
+ <div className="flex justify-between items-center mb-2">
439
+ <div className="flex items-center gap-2">
440
+ <span className="text-sm font-medium text-gray-800">
441
+ {prediction.label}
442
+ </span>
443
+ </div>
444
+ <span className="text-sm font-medium text-gray-600">
445
+ {confidencePercent}%
446
+ </span>
447
+ </div>
448
+ <div className="w-full bg-gray-200 rounded-full h-2">
449
+ <div
450
+ className={`h-2 rounded-full transition-all duration-300 ${
451
+ isTopPrediction
452
+ ? 'bg-green-500'
453
+ : prediction.score > 0.5
454
+ ? 'bg-blue-500'
455
+ : 'bg-gray-400'
456
+ }`}
457
+ style={{
458
+ width: `${Math.max(prediction.score * 100, 2)}%`
459
+ }}
460
+ />
461
+ </div>
462
+ </div>
463
+ )
464
+ })}
465
+ </div>
466
+ )}
467
+ </div>
468
+ </div>
469
+ </div>
470
+ </div>
471
+
472
+ {!hasBeenLoaded && (
473
+ <div className="text-center text-gray-500 text-sm mt-2">
474
+ Please load an image classification model first to start classifying
475
+ images
476
+ </div>
477
+ )}
478
+
479
+ {hasBeenLoaded && examples.length === 0 && (
480
+ <div className="text-center text-blue-600 text-sm mt-2">
481
+ 💡 Tip: Click "Load Samples" to try with example images, or upload
482
+ your own images above
483
+ </div>
484
+ )}
485
+ </div>
486
+ )
487
+ }
488
+
489
+ export default ImageClassification
src/components/pipelines/ImageClassificationConfig.tsx ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react'
2
+ import { useImageClassification } from '../../contexts/ImageClassificationContext'
3
+
4
+ const ImageClassificationConfig = () => {
5
+ const { topK, setTopK } = useImageClassification()
6
+
7
+ return (
8
+ <div className="space-y-4">
9
+ <h3 className="text-lg font-semibold text-gray-900">
10
+ Image Classification Settings
11
+ </h3>
12
+
13
+ <div className="space-y-3">
14
+ <div>
15
+ <label className="block text-sm font-medium text-gray-700 mb-1">
16
+ Top K Predictions: {topK}
17
+ </label>
18
+ <input
19
+ type="range"
20
+ min="1"
21
+ max="10"
22
+ step="1"
23
+ value={topK}
24
+ onChange={(e) => setTopK(parseInt(e.target.value))}
25
+ className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer"
26
+ />
27
+ <div className="flex justify-between text-xs text-gray-400 mt-1">
28
+ <span>1</span>
29
+ <span>5</span>
30
+ <span>10</span>
31
+ </div>
32
+ <p className="text-xs text-gray-500 mt-1">
33
+ Number of top predictions to return for each image
34
+ </p>
35
+ </div>
36
+
37
+ <div className="p-3 bg-yellow-50 border border-yellow-200 rounded-lg">
38
+ <h4 className="text-sm font-medium text-yellow-800 mb-2">💡 Tips</h4>
39
+ <div className="text-xs text-yellow-700 space-y-1">
40
+ <p>• Use Top K = 3-5 for most cases</p>
41
+ <p>• Smaller images process faster</p>
42
+ <p>• Try quantized models for speed</p>
43
+ </div>
44
+ </div>
45
+ </div>
46
+ </div>
47
+ )
48
+ }
49
+
50
+ export default ImageClassificationConfig
src/contexts/ImageClassificationContext.tsx ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { createContext, useContext, useState, useCallback } from 'react'
2
+ import { ImageExample } from '../types'
3
+
4
+ interface ImageClassificationContextType {
5
+ examples: ImageExample[]
6
+ selectedExample: ImageExample | null
7
+ setSelectedExample: (example: ImageExample | null) => void
8
+ addExample: (file: File) => void
9
+ removeExample: (id: string) => void
10
+ updateExample: (id: string, updates: Partial<ImageExample>) => void
11
+ clearExamples: () => void
12
+ topK: number
13
+ setTopK: (k: number) => void
14
+ }
15
+
16
+ const ImageClassificationContext = createContext<
17
+ ImageClassificationContextType | undefined
18
+ >(undefined)
19
+
20
+ export function useImageClassification() {
21
+ const context = useContext(ImageClassificationContext)
22
+ if (context === undefined) {
23
+ throw new Error(
24
+ 'useImageClassification must be used within an ImageClassificationProvider'
25
+ )
26
+ }
27
+ return context
28
+ }
29
+
30
+ interface ImageClassificationProviderProps {
31
+ children: React.ReactNode
32
+ }
33
+
34
+ export function ImageClassificationProvider({
35
+ children
36
+ }: ImageClassificationProviderProps) {
37
+ const [examples, setExamples] = useState<ImageExample[]>([])
38
+ const [selectedExample, setSelectedExample] = useState<ImageExample | null>(
39
+ null
40
+ )
41
+ const [topK, setTopK] = useState<number>(5)
42
+
43
+ const addExample = useCallback((file: File) => {
44
+ const id = Math.random().toString(36).substr(2, 9)
45
+ const url = URL.createObjectURL(file)
46
+
47
+ const newExample: ImageExample = {
48
+ id,
49
+ name: file.name,
50
+ url,
51
+ file,
52
+ predictions: undefined,
53
+ isLoading: false
54
+ }
55
+
56
+ setExamples((prev) => [...prev, newExample])
57
+ }, [])
58
+
59
+ const removeExample = useCallback((id: string) => {
60
+ setExamples((prev) => {
61
+ const updated = prev.filter((ex) => ex.id !== id)
62
+ // Clean up object URL to prevent memory leaks
63
+ const example = prev.find((ex) => ex.id === id)
64
+ if (example?.url) {
65
+ URL.revokeObjectURL(example.url)
66
+ }
67
+ return updated
68
+ })
69
+
70
+ // Clear selection if the selected example was removed
71
+ setSelectedExample((prev) => (prev?.id === id ? null : prev))
72
+ }, [])
73
+
74
+ const updateExample = useCallback(
75
+ (id: string, updates: Partial<ImageExample>) => {
76
+ setExamples((prev) =>
77
+ prev.map((ex) => (ex.id === id ? { ...ex, ...updates } : ex))
78
+ )
79
+
80
+ // Update selected example if it's the one being updated
81
+ setSelectedExample((prev) =>
82
+ prev?.id === id ? { ...prev, ...updates } : prev
83
+ )
84
+ },
85
+ []
86
+ )
87
+
88
+ const clearExamples = useCallback(() => {
89
+ // Clean up all object URLs to prevent memory leaks
90
+ examples.forEach((example) => {
91
+ if (example.url) {
92
+ URL.revokeObjectURL(example.url)
93
+ }
94
+ })
95
+
96
+ setExamples([])
97
+ setSelectedExample(null)
98
+ }, [examples])
99
+
100
+ const value: ImageClassificationContextType = {
101
+ examples,
102
+ selectedExample,
103
+ setSelectedExample,
104
+ addExample,
105
+ removeExample,
106
+ updateExample,
107
+ clearExamples,
108
+ topK,
109
+ setTopK
110
+ }
111
+
112
+ return (
113
+ <ImageClassificationContext.Provider value={value}>
114
+ {children}
115
+ </ImageClassificationContext.Provider>
116
+ )
117
+ }
src/lib/huggingface.ts CHANGED
@@ -34,9 +34,9 @@ const getModelInfo = async (
34
  const modelData: ModelInfoResponse = await response.json()
35
 
36
  const requiredFiles = [
37
- 'config.json',
38
- 'tokenizer.json',
39
- 'tokenizer_config.json'
40
  ]
41
 
42
  const siblingFiles = modelData.siblings?.map((s) => s.rfilename) || []
 
34
  const modelData: ModelInfoResponse = await response.json()
35
 
36
  const requiredFiles = [
37
+ 'config.json'
38
+ // 'tokenizer.json',
39
+ // 'tokenizer_config.json'
40
  ]
41
 
42
  const siblingFiles = modelData.siblings?.map((s) => s.rfilename) || []
src/lib/workerManager.ts CHANGED
@@ -17,6 +17,9 @@ export const getWorker = (pipeline: string) => {
17
  case 'feature-extraction':
18
  workerUrl = `${process.env.PUBLIC_URL}/workers/feature-extraction.js`
19
  break
 
 
 
20
  default:
21
  return null
22
  }
 
17
  case 'feature-extraction':
18
  workerUrl = `${process.env.PUBLIC_URL}/workers/feature-extraction.js`
19
  break
20
+ case 'image-classification':
21
+ workerUrl = `${process.env.PUBLIC_URL}/workers/image-classification.js`
22
+ break
23
  default:
24
  return null
25
  }
src/types.ts CHANGED
@@ -75,6 +75,28 @@ export interface FeatureExtractionWorkerInput {
75
  }
76
  }
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  export interface EmbeddingExample {
79
  id: string
80
  text: string
 
75
  }
76
  }
77
 
78
+ export interface ImageClassificationWorkerInput {
79
+ type: 'classify'
80
+ image: string | ImageData | HTMLImageElement | HTMLCanvasElement
81
+ model: string
82
+ dtype: QuantizationType
83
+ topK?: number
84
+ }
85
+
86
+ export interface ImageClassificationResult {
87
+ label: string
88
+ score: number
89
+ }
90
+
91
+ export interface ImageExample {
92
+ id: string
93
+ name: string
94
+ url: string
95
+ file?: File
96
+ predictions?: ImageClassificationResult[]
97
+ isLoading?: boolean
98
+ }
99
+
100
  export interface EmbeddingExample {
101
  id: string
102
  text: string