Vokturz commited on
Commit
31283f8
·
1 Parent(s): 4d213de

remove pooling max since is not supported by transformers.js, and fix

Browse files
src/components/ModelLoader.tsx CHANGED
@@ -1,13 +1,13 @@
1
  import { useEffect, useCallback, useState } from 'react'
2
- import { ChevronDown, Loader, X } from 'lucide-react'
3
  import { QuantizationType, WorkerMessage } from '../types'
4
  import { useModel } from '../contexts/ModelContext'
5
  import { getWorker } from '../lib/workerManager'
6
  import { Alert, AlertDescription } from './ui/alert'
7
 
8
  const ModelLoader = () => {
9
- const [isError, setIsError] = useState(false)
10
- const [errorMessage, setErrorMessage] = useState('')
11
  const {
12
  modelInfo,
13
  selectedQuantization,
@@ -75,16 +75,23 @@ const ModelLoader = () => {
75
  output.file.startsWith('onnx')
76
  ) {
77
  setProgress(output.progress)
 
 
 
 
 
 
 
78
  }
79
  } else if (status === 'error') {
80
  setStatus('error')
81
  const error = e.data.output
82
  console.error(error)
83
- setErrorMessage(error.split('.')[0] + '. See console for details.')
84
- setIsError(true)
85
  setTimeout(() => {
86
- setIsError(false)
87
- setErrorMessage('')
88
  }, 3000)
89
  }
90
  }
@@ -106,6 +113,15 @@ const ModelLoader = () => {
106
  setHasBeenLoaded
107
  ])
108
 
 
 
 
 
 
 
 
 
 
109
  const loadModel = useCallback(() => {
110
  if (!modelInfo || !selectedQuantization) return
111
 
@@ -164,7 +180,7 @@ const ModelLoader = () => {
164
  >
165
  {status === 'loading' && !hasBeenLoaded ? (
166
  <>
167
- <Loader className="animate-spin h-4 w-4" />
168
  <span>{progress.toFixed(0)}%</span>
169
  </>
170
  ) : (
@@ -174,10 +190,12 @@ const ModelLoader = () => {
174
  </div>
175
  )}
176
  </div>
177
- {isError && (
178
  <div className="fixed bottom-0 right-0 m-2">
179
- <Alert variant="destructive">
180
- <AlertDescription>{errorMessage}</AlertDescription>
 
 
181
  </Alert>
182
  </div>
183
  )}
 
1
  import { useEffect, useCallback, useState } from 'react'
2
+ import { ChevronDown, Loader2, X } from 'lucide-react'
3
  import { QuantizationType, WorkerMessage } from '../types'
4
  import { useModel } from '../contexts/ModelContext'
5
  import { getWorker } from '../lib/workerManager'
6
  import { Alert, AlertDescription } from './ui/alert'
7
 
8
  const ModelLoader = () => {
9
+ const [showAlert, setShowAlert] = useState(false)
10
+ const [alertMessage, setAlertMessage] = useState<React.ReactNode>('')
11
  const {
12
  modelInfo,
13
  selectedQuantization,
 
75
  output.file.startsWith('onnx')
76
  ) {
77
  setProgress(output.progress)
78
+ setShowAlert(true)
79
+ setAlertMessage(
80
+ <div className="flex items-center">
81
+ <Loader2 className="animate-spin h-4 w-4 mr-2" />
82
+ Loading Model
83
+ </div>
84
+ )
85
  }
86
  } else if (status === 'error') {
87
  setStatus('error')
88
  const error = e.data.output
89
  console.error(error)
90
+ setAlertMessage(error.split('.')[0] + '. See console for details.')
91
+ setShowAlert(true)
92
  setTimeout(() => {
93
+ setShowAlert(false)
94
+ setAlertMessage('')
95
  }, 3000)
96
  }
97
  }
 
113
  setHasBeenLoaded
114
  ])
115
 
116
+ useEffect(() => {
117
+ if (progress === 100) {
118
+ setTimeout(() => {
119
+ setShowAlert(false)
120
+ setAlertMessage('')
121
+ }, 2000)
122
+ }
123
+ }, [progress])
124
+
125
  const loadModel = useCallback(() => {
126
  if (!modelInfo || !selectedQuantization) return
127
 
 
180
  >
181
  {status === 'loading' && !hasBeenLoaded ? (
182
  <>
183
+ <Loader2 className="animate-spin h-4 w-4" />
184
  <span>{progress.toFixed(0)}%</span>
185
  </>
186
  ) : (
 
190
  </div>
191
  )}
192
  </div>
193
+ {showAlert && (
194
  <div className="fixed bottom-0 right-0 m-2">
195
+ <Alert
196
+ variant={`${typeof alertMessage === 'string' ? 'destructive' : 'default'}`}
197
+ >
198
+ <AlertDescription>{alertMessage}</AlertDescription>
199
  </Alert>
200
  </div>
201
  )}
src/components/PipelineSelector.tsx CHANGED
@@ -5,7 +5,7 @@ import {
5
  SelectItem,
6
  SelectTrigger,
7
  SelectValue
8
- } from '@/components/ui/select' // Adjust the import path as needed
9
 
10
  export const supportedPipelines = [
11
  'feature-extraction',
 
5
  SelectItem,
6
  SelectTrigger,
7
  SelectValue
8
+ } from '@/components/ui/select'
9
 
10
  export const supportedPipelines = [
11
  'feature-extraction',
src/components/pipelines/FeatureExtraction.tsx CHANGED
@@ -46,7 +46,13 @@ function FeatureExtraction() {
46
 
47
  const [newExampleText, setNewExampleText] = useState<string>('')
48
  const [isExtracting, setIsExtracting] = useState<boolean>(false)
49
- const [showVisualization, setShowVisualization] = useState<boolean>(true)
 
 
 
 
 
 
50
  const [progress, setProgress] = useState<{
51
  completed: number
52
  total: number
@@ -60,6 +66,18 @@ function FeatureExtraction() {
60
  selectedQuantization
61
  } = useModel()
62
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  const chartRef = useRef<SVGSVGElement>(null)
64
 
65
  // PCA reduction to 2D for visualization
@@ -215,7 +233,6 @@ function FeatureExtraction() {
215
  })
216
  }
217
  })
218
- console.log({ examples })
219
  setIsExtracting(false)
220
  setProgress(null)
221
  } else if (status === 'error') {
 
46
 
47
  const [newExampleText, setNewExampleText] = useState<string>('')
48
  const [isExtracting, setIsExtracting] = useState<boolean>(false)
49
+ const [showVisualization, setShowVisualization] = useState<boolean>(() => {
50
+ if (typeof window !== 'undefined') {
51
+ return window.innerWidth >= 768
52
+ }
53
+ return true
54
+ })
55
+
56
  const [progress, setProgress] = useState<{
57
  completed: number
58
  total: number
 
66
  selectedQuantization
67
  } = useModel()
68
 
69
+ useEffect(() => {
70
+ const handleResize = () => {
71
+ setShowVisualization(window.innerWidth >= 768)
72
+ }
73
+
74
+ window.addEventListener('resize', handleResize)
75
+
76
+ return () => {
77
+ window.removeEventListener('resize', handleResize)
78
+ }
79
+ }, [])
80
+
81
  const chartRef = useRef<SVGSVGElement>(null)
82
 
83
  // PCA reduction to 2D for visualization
 
233
  })
234
  }
235
  })
 
236
  setIsExtracting(false)
237
  setProgress(null)
238
  } else if (status === 'error') {
src/components/pipelines/FeatureExtractionConfig.tsx CHANGED
@@ -1,5 +1,12 @@
1
  import React from 'react'
2
  import { useFeatureExtraction } from '../../contexts/FeatureExtractionContext'
 
 
 
 
 
 
 
3
 
4
  const FeatureExtractionConfig = () => {
5
  const { config, setConfig } = useFeatureExtraction()
@@ -15,20 +22,36 @@ const FeatureExtractionConfig = () => {
15
  <label className="block text-sm font-medium text-foreground/80 mb-1">
16
  Pooling Strategy
17
  </label>
18
- <select
19
  value={config.pooling}
20
- onChange={(e) =>
21
  setConfig((prev) => ({
22
  ...prev,
23
- pooling: e.target.value as 'mean' | 'cls' | 'max'
24
  }))
25
  }
26
- className="w-full px-3 py-2 border border-input rounded-md shadow-xs focus:outline-hidden focus:ring-2 focus:ring-ring focus:border-ring text-sm"
27
  >
28
- <option value="mean">Mean Pooling</option>
29
- <option value="cls">CLS Token</option>
30
- <option value="max">Max Pooling</option>
31
- </select>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  <p className="text-xs text-muted-foreground mt-1">
33
  How to aggregate token embeddings into sentence embeddings
34
  </p>
@@ -66,10 +89,6 @@ const FeatureExtractionConfig = () => {
66
  <strong>CLS Token:</strong> Use the [CLS] token embedding (if
67
  available)
68
  </p>
69
- <p>
70
- <strong>Max Pooling:</strong> Take element-wise maximum across
71
- tokens
72
- </p>
73
  </div>
74
  </div>
75
  </div>
 
1
  import React from 'react'
2
  import { useFeatureExtraction } from '../../contexts/FeatureExtractionContext'
3
+ import {
4
+ Select,
5
+ SelectContent,
6
+ SelectItem,
7
+ SelectTrigger,
8
+ SelectValue
9
+ } from '@/components/ui/select'
10
 
11
  const FeatureExtractionConfig = () => {
12
  const { config, setConfig } = useFeatureExtraction()
 
22
  <label className="block text-sm font-medium text-foreground/80 mb-1">
23
  Pooling Strategy
24
  </label>
25
+ <Select
26
  value={config.pooling}
27
+ onValueChange={(value) =>
28
  setConfig((prev) => ({
29
  ...prev,
30
+ pooling: value as 'mean' | 'cls'
31
  }))
32
  }
 
33
  >
34
+ <SelectTrigger className="w-full text-sm xl:text-base">
35
+ <SelectValue placeholder="Select a pooling strategy" />
36
+ </SelectTrigger>
37
+ <SelectContent>
38
+ <SelectItem
39
+ key="mean"
40
+ value="mean"
41
+ className="text-sm data-[state=checked]:font-bold"
42
+ >
43
+ Mean Pooling
44
+ </SelectItem>
45
+ <SelectItem
46
+ key="cls"
47
+ value="cls"
48
+ className="text-sm data-[state=checked]:font-bold"
49
+ >
50
+ CLS Token
51
+ </SelectItem>
52
+ </SelectContent>
53
+ </Select>
54
+
55
  <p className="text-xs text-muted-foreground mt-1">
56
  How to aggregate token embeddings into sentence embeddings
57
  </p>
 
89
  <strong>CLS Token:</strong> Use the [CLS] token embedding (if
90
  available)
91
  </p>
 
 
 
 
92
  </div>
93
  </div>
94
  </div>
src/contexts/FeatureExtractionContext.tsx CHANGED
@@ -2,7 +2,7 @@ import React, { createContext, useContext, useState, useCallback } from 'react'
2
  import { EmbeddingExample, SimilarityResult } from '../types'
3
 
4
  interface FeatureExtractionConfig {
5
- pooling: 'mean' | 'cls' | 'max'
6
  normalize: boolean
7
  }
8
 
@@ -10,7 +10,9 @@ interface FeatureExtractionContextType {
10
  examples: EmbeddingExample[]
11
  setExamples: React.Dispatch<React.SetStateAction<EmbeddingExample[]>>
12
  selectedExample: EmbeddingExample | null
13
- setSelectedExample: React.Dispatch<React.SetStateAction<EmbeddingExample | null>>
 
 
14
  similarities: SimilarityResult[]
15
  setSimilarities: React.Dispatch<React.SetStateAction<SimilarityResult[]>>
16
  config: FeatureExtractionConfig
@@ -22,12 +24,16 @@ interface FeatureExtractionContextType {
22
  clearExamples: () => void
23
  }
24
 
25
- const FeatureExtractionContext = createContext<FeatureExtractionContextType | undefined>(undefined)
 
 
26
 
27
  export const useFeatureExtraction = () => {
28
  const context = useContext(FeatureExtractionContext)
29
  if (!context) {
30
- throw new Error('useFeatureExtraction must be used within a FeatureExtractionProvider')
 
 
31
  }
32
  return context
33
  }
@@ -58,9 +64,12 @@ const cosineSimilarity = (a: number[], b: number[]): number => {
58
  return dotProduct / (normA * normB)
59
  }
60
 
61
- export const FeatureExtractionProvider: React.FC<{ children: React.ReactNode }> = ({ children }) => {
 
 
62
  const [examples, setExamples] = useState<EmbeddingExample[]>([])
63
- const [selectedExample, setSelectedExample] = useState<EmbeddingExample | null>(null)
 
64
  const [similarities, setSimilarities] = useState<SimilarityResult[]>([])
65
  const [config, setConfig] = useState<FeatureExtractionConfig>({
66
  pooling: 'mean',
@@ -74,39 +83,55 @@ export const FeatureExtractionProvider: React.FC<{ children: React.ReactNode }>
74
  embedding: undefined,
75
  isLoading: false
76
  }
77
- setExamples(prev => [...prev, newExample])
78
  }, [])
79
 
80
- const removeExample = useCallback((id: string) => {
81
- setExamples(prev => prev.filter(example => example.id !== id))
82
- if (selectedExample?.id === id) {
83
- setSelectedExample(null)
84
- setSimilarities([])
85
- }
86
- }, [selectedExample])
87
-
88
- const updateExample = useCallback((id: string, updates: Partial<EmbeddingExample>) => {
89
- setExamples(prev => prev.map(example =>
90
- example.id === id ? { ...example, ...updates } : example
91
- ))
92
- }, [])
93
-
94
- const calculateSimilarities = useCallback((targetExample: EmbeddingExample) => {
95
- if (!targetExample.embedding) {
96
- setSimilarities([])
97
- return
98
- }
99
 
100
- const newSimilarities: SimilarityResult[] = examples
101
- .filter(example => example.id !== targetExample.id && example.embedding)
102
- .map(example => ({
103
- exampleId: example.id,
104
- similarity: cosineSimilarity(targetExample.embedding!, example.embedding!)
105
- }))
106
- .sort((a, b) => b.similarity - a.similarity)
 
 
 
107
 
108
- setSimilarities(newSimilarities)
109
- }, [examples])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  const clearExamples = useCallback(() => {
112
  setExamples([])
 
2
  import { EmbeddingExample, SimilarityResult } from '../types'
3
 
4
  interface FeatureExtractionConfig {
5
+ pooling: 'mean' | 'cls'
6
  normalize: boolean
7
  }
8
 
 
10
  examples: EmbeddingExample[]
11
  setExamples: React.Dispatch<React.SetStateAction<EmbeddingExample[]>>
12
  selectedExample: EmbeddingExample | null
13
+ setSelectedExample: React.Dispatch<
14
+ React.SetStateAction<EmbeddingExample | null>
15
+ >
16
  similarities: SimilarityResult[]
17
  setSimilarities: React.Dispatch<React.SetStateAction<SimilarityResult[]>>
18
  config: FeatureExtractionConfig
 
24
  clearExamples: () => void
25
  }
26
 
27
+ const FeatureExtractionContext = createContext<
28
+ FeatureExtractionContextType | undefined
29
+ >(undefined)
30
 
31
  export const useFeatureExtraction = () => {
32
  const context = useContext(FeatureExtractionContext)
33
  if (!context) {
34
+ throw new Error(
35
+ 'useFeatureExtraction must be used within a FeatureExtractionProvider'
36
+ )
37
  }
38
  return context
39
  }
 
64
  return dotProduct / (normA * normB)
65
  }
66
 
67
+ export const FeatureExtractionProvider: React.FC<{
68
+ children: React.ReactNode
69
+ }> = ({ children }) => {
70
  const [examples, setExamples] = useState<EmbeddingExample[]>([])
71
+ const [selectedExample, setSelectedExample] =
72
+ useState<EmbeddingExample | null>(null)
73
  const [similarities, setSimilarities] = useState<SimilarityResult[]>([])
74
  const [config, setConfig] = useState<FeatureExtractionConfig>({
75
  pooling: 'mean',
 
83
  embedding: undefined,
84
  isLoading: false
85
  }
86
+ setExamples((prev) => [...prev, newExample])
87
  }, [])
88
 
89
+ const removeExample = useCallback(
90
+ (id: string) => {
91
+ setExamples((prev) => prev.filter((example) => example.id !== id))
92
+ if (selectedExample?.id === id) {
93
+ setSelectedExample(null)
94
+ setSimilarities([])
95
+ }
96
+ },
97
+ [selectedExample]
98
+ )
 
 
 
 
 
 
 
 
 
99
 
100
+ const updateExample = useCallback(
101
+ (id: string, updates: Partial<EmbeddingExample>) => {
102
+ setExamples((prev) =>
103
+ prev.map((example) =>
104
+ example.id === id ? { ...example, ...updates } : example
105
+ )
106
+ )
107
+ },
108
+ []
109
+ )
110
 
111
+ const calculateSimilarities = useCallback(
112
+ (targetExample: EmbeddingExample) => {
113
+ if (!targetExample.embedding) {
114
+ setSimilarities([])
115
+ return
116
+ }
117
+
118
+ const newSimilarities: SimilarityResult[] = examples
119
+ .filter(
120
+ (example) => example.id !== targetExample.id && example.embedding
121
+ )
122
+ .map((example) => ({
123
+ exampleId: example.id,
124
+ similarity: cosineSimilarity(
125
+ targetExample.embedding!,
126
+ example.embedding!
127
+ )
128
+ }))
129
+ .sort((a, b) => b.similarity - a.similarity)
130
+
131
+ setSimilarities(newSimilarities)
132
+ },
133
+ [examples]
134
+ )
135
 
136
  const clearExamples = useCallback(() => {
137
  setExamples([])
src/types.ts CHANGED
@@ -75,7 +75,7 @@ export interface FeatureExtractionWorkerInput {
75
  model: string
76
  dtype: QuantizationType
77
  config: {
78
- pooling: 'mean' | 'cls' | 'max'
79
  normalize: boolean
80
  }
81
  }
 
75
  model: string
76
  dtype: QuantizationType
77
  config: {
78
+ pooling: 'mean' | 'cls'
79
  normalize: boolean
80
  }
81
  }