Vokturz commited on
Commit
5b8fd7e
·
1 Parent(s): d26ad91

improve text-classification components

Browse files
public/workers/text-classification.js CHANGED
@@ -41,7 +41,7 @@ class MyTextClassificationPipeline {
41
  // Listen for messages from the main thread
42
  self.addEventListener('message', async (event) => {
43
  try {
44
- const { type, model, dtype, text } = event.data
45
 
46
  if (!model) {
47
  self.postMessage({
@@ -76,13 +76,13 @@ self.addEventListener('message', async (event) => {
76
  const split = text.split('\n')
77
  for (const line of split) {
78
  if (line.trim()) {
79
- const output = await classifier(line)
80
  self.postMessage({
81
  status: 'output',
82
  output: {
83
  sequence: line,
84
- labels: [output[0].label],
85
- scores: [output[0].score]
86
  }
87
  })
88
  }
 
41
  // Listen for messages from the main thread
42
  self.addEventListener('message', async (event) => {
43
  try {
44
+ const { type, model, dtype, text, config } = event.data
45
 
46
  if (!model) {
47
  self.postMessage({
 
76
  const split = text.split('\n')
77
  for (const line of split) {
78
  if (line.trim()) {
79
+ const output = await classifier(line, config)
80
  self.postMessage({
81
  status: 'output',
82
  output: {
83
  sequence: line,
84
+ labels: output.map((item) => item.label),
85
+ scores: output.map((item) => item.score)
86
  }
87
  })
88
  }
src/components/ModelCode.tsx CHANGED
@@ -38,7 +38,9 @@ const ModelCode = ({ isCodeModalOpen, setIsCodeModalOpen }: ModelCodeProps) => {
38
  case 'text-classification':
39
  classType = 'classifier'
40
  exampleData = 'I love this product!'
41
- config = {}
 
 
42
  break
43
  case 'text-generation':
44
  classType = 'generator'
 
38
  case 'text-classification':
39
  classType = 'classifier'
40
  exampleData = 'I love this product!'
41
+ config = {
42
+ top_k: 1
43
+ }
44
  break
45
  case 'text-generation':
46
  classType = 'generator'
src/components/PipelineLayout.tsx CHANGED
@@ -3,6 +3,7 @@ import { TextGenerationProvider } from '../contexts/TextGenerationContext'
3
  import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
4
  import { ZeroShotClassificationProvider } from '../contexts/ZeroShotClassificationContext'
5
  import { ImageClassificationProvider } from '../contexts/ImageClassificationContext'
 
6
 
7
  export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
8
  const { pipeline } = useModel()
@@ -26,6 +27,9 @@ export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
26
  <ImageClassificationProvider>{children}</ImageClassificationProvider>
27
  )
28
 
 
 
 
29
  default:
30
  return <>{children}</>
31
  }
 
3
  import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
4
  import { ZeroShotClassificationProvider } from '../contexts/ZeroShotClassificationContext'
5
  import { ImageClassificationProvider } from '../contexts/ImageClassificationContext'
6
+ import { TextClassificationProvider } from '../contexts/TextClassificationContext'
7
 
8
  export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
9
  const { pipeline } = useModel()
 
27
  <ImageClassificationProvider>{children}</ImageClassificationProvider>
28
  )
29
 
30
+ case 'text-classification':
31
+ return <TextClassificationProvider>{children}</TextClassificationProvider>
32
+
33
  default:
34
  return <>{children}</>
35
  }
src/components/PipelineSelector.tsx CHANGED
@@ -12,9 +12,9 @@ export const supportedPipelines = [
12
  'image-classification',
13
  'text-generation',
14
  'zero-shot-classification',
15
- 'text-classification',
16
- 'summarization',
17
- 'translation'
18
  ]
19
 
20
  interface PipelineSelectorProps {
 
12
  'image-classification',
13
  'text-generation',
14
  'zero-shot-classification',
15
+ 'text-classification'
16
+ // 'summarization',
17
+ // 'translation'
18
  ]
19
 
20
  interface PipelineSelectorProps {
src/components/Sidebar.tsx CHANGED
@@ -7,6 +7,7 @@ import TextGenerationConfig from './pipelines/TextGenerationConfig'
7
  import FeatureExtractionConfig from './pipelines/FeatureExtractionConfig'
8
  import ZeroShotClassificationConfig from './pipelines/ZeroShotClassificationConfig'
9
  import ImageClassificationConfig from './pipelines/ImageClassificationConfig'
 
10
  import { Button } from '@/components/ui/button'
11
 
12
  interface SidebarProps {
@@ -102,6 +103,7 @@ const Sidebar = ({
102
  {pipeline === 'image-classification' && (
103
  <ImageClassificationConfig />
104
  )}
 
105
  </div>
106
  </div>
107
  </div>
 
7
  import FeatureExtractionConfig from './pipelines/FeatureExtractionConfig'
8
  import ZeroShotClassificationConfig from './pipelines/ZeroShotClassificationConfig'
9
  import ImageClassificationConfig from './pipelines/ImageClassificationConfig'
10
+ import TextClassificationConfig from './pipelines/TextClassificationConfig'
11
  import { Button } from '@/components/ui/button'
12
 
13
  interface SidebarProps {
 
103
  {pipeline === 'image-classification' && (
104
  <ImageClassificationConfig />
105
  )}
106
+ {pipeline === 'text-classification' && <TextClassificationConfig />}
107
  </div>
108
  </div>
109
  </div>
src/components/pipelines/TextClassification.tsx CHANGED
@@ -1,6 +1,11 @@
1
  import { useState, useCallback, useEffect } from 'react'
2
- import { TextClassificationWorkerInput, WorkerMessage } from '../../types'
 
 
 
 
3
  import { useModel } from '../../contexts/ModelContext'
 
4
 
5
  const PLACEHOLDER_TEXTS: string[] = [
6
  'I absolutely love this product! It exceeded all my expectations.',
@@ -18,7 +23,7 @@ const PLACEHOLDER_TEXTS: string[] = [
18
  function TextClassification() {
19
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
20
  const [numberExamples, setNumberExamples] = useState(PLACEHOLDER_TEXTS.length)
21
- const [results, setResults] = useState<any[]>([])
22
  const {
23
  activeWorker,
24
  status,
@@ -27,6 +32,7 @@ function TextClassification() {
27
  hasBeenLoaded,
28
  selectedQuantization
29
  } = useModel()
 
30
 
31
  useEffect(() => {
32
  if (modelInfo?.widgetData) {
@@ -51,10 +57,11 @@ function TextClassification() {
51
  type: 'classify',
52
  text,
53
  model: modelInfo.id,
54
- dtype: selectedQuantization ?? 'fp32'
 
55
  }
56
  activeWorker.postMessage(message)
57
- }, [text, modelInfo, activeWorker, selectedQuantization, setResults])
58
 
59
  // Handle worker messages
60
  useEffect(() => {
@@ -65,7 +72,7 @@ function TextClassification() {
65
  if (status === 'output') {
66
  setStatus('output')
67
  const result = e.data.output!
68
- setResults((prev: any[]) => [...prev, result])
69
  }
70
  }
71
 
@@ -135,17 +142,47 @@ function TextClassification() {
135
  <div className="space-y-3">
136
  {results.map((result, index) => (
137
  <div key={index} className="p-3 rounded-sm border-2">
138
- <div className="flex justify-between items-start mb-2">
139
- <span className="font-semibold text-sm">
140
- {result.labels[0]}
141
- </span>
142
- <span className="text-sm font-mono">
143
- {(result.scores[0] * 100).toFixed(1)}%
144
- </span>
145
- </div>
146
- <div className="text-sm text-gray-700">
147
  {result.sequence}
148
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  </div>
150
  ))}
151
  </div>
 
1
  import { useState, useCallback, useEffect } from 'react'
2
+ import {
3
+ ClassificationOutput,
4
+ TextClassificationWorkerInput,
5
+ WorkerMessage
6
+ } from '../../types'
7
  import { useModel } from '../../contexts/ModelContext'
8
+ import { useTextClassification } from '../../contexts/TextClassificationContext'
9
 
10
  const PLACEHOLDER_TEXTS: string[] = [
11
  'I absolutely love this product! It exceeded all my expectations.',
 
23
  function TextClassification() {
24
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
25
  const [numberExamples, setNumberExamples] = useState(PLACEHOLDER_TEXTS.length)
26
+ const [results, setResults] = useState<ClassificationOutput[]>([])
27
  const {
28
  activeWorker,
29
  status,
 
32
  hasBeenLoaded,
33
  selectedQuantization
34
  } = useModel()
35
+ const { config } = useTextClassification()
36
 
37
  useEffect(() => {
38
  if (modelInfo?.widgetData) {
 
57
  type: 'classify',
58
  text,
59
  model: modelInfo.id,
60
+ dtype: selectedQuantization ?? 'fp32',
61
+ config
62
  }
63
  activeWorker.postMessage(message)
64
+ }, [text, modelInfo, activeWorker, selectedQuantization, config, setResults])
65
 
66
  // Handle worker messages
67
  useEffect(() => {
 
72
  if (status === 'output') {
73
  setStatus('output')
74
  const result = e.data.output!
75
+ setResults((prev: ClassificationOutput[]) => [...prev, result])
76
  }
77
  }
78
 
 
142
  <div className="space-y-3">
143
  {results.map((result, index) => (
144
  <div key={index} className="p-3 rounded-sm border-2">
145
+ <div className="text-sm text-gray-700 mb-3">
 
 
 
 
 
 
 
 
146
  {result.sequence}
147
  </div>
148
+ <div className="space-y-2">
149
+ {result.labels.map(
150
+ (label: string, labelIndex: number) => {
151
+ const score = result.scores[labelIndex]
152
+ const isTopPrediction = labelIndex === 0
153
+
154
+ return (
155
+ <div
156
+ key={labelIndex}
157
+ className={`flex justify-between items-center p-2 rounded ${
158
+ isTopPrediction
159
+ ? 'bg-blue-50 border-l-4 border-blue-500'
160
+ : 'bg-gray-50'
161
+ }`}
162
+ >
163
+ <span
164
+ className={`font-medium text-sm ${
165
+ isTopPrediction
166
+ ? 'text-blue-700'
167
+ : 'text-gray-700'
168
+ }`}
169
+ >
170
+ {label}
171
+ </span>
172
+ <span
173
+ className={`text-sm font-mono ${
174
+ isTopPrediction
175
+ ? 'text-blue-600'
176
+ : 'text-gray-600'
177
+ }`}
178
+ >
179
+ {(score * 100).toFixed(1)}%
180
+ </span>
181
+ </div>
182
+ )
183
+ }
184
+ )}
185
+ </div>
186
  </div>
187
  ))}
188
  </div>
src/components/pipelines/TextClassificationConfig.tsx ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react'
2
+ import { useTextClassification } from '../../contexts/TextClassificationContext'
3
+ import { Slider } from '../ui/slider'
4
+
5
+ const TextClassificationConfig = () => {
6
+ const { config, setConfig } = useTextClassification()
7
+
8
+ return (
9
+ <div className="space-y-4">
10
+ <h3 className="text-lg font-semibold text-foreground">
11
+ Text Classification Settings
12
+ </h3>
13
+
14
+ <div className="space-y-3">
15
+ <div>
16
+ <label className="block text-sm font-medium text-foreground/80 mb-1">
17
+ Top K Predictions: {config.top_k}
18
+ </label>
19
+ <Slider
20
+ defaultValue={[config.top_k]}
21
+ min={1}
22
+ max={10}
23
+ step={1}
24
+ onValueChange={(value) => setConfig({ top_k: value[0] })}
25
+ className="w-full rounded-lg"
26
+ />
27
+ <div className="flex justify-between text-xs text-muted-foreground/60 mt-1">
28
+ <span>1</span>
29
+ <span>4</span>
30
+ <span>7</span>
31
+ <span>10</span>
32
+ </div>
33
+ <p className="text-xs text-muted-foreground mt-1">
34
+ Number of top predictions to return for each text
35
+ </p>
36
+ </div>
37
+
38
+ <div className="p-3 bg-chart-4/10 border border-chart-4/20 rounded-lg">
39
+ <h4 className="text-sm font-medium text-chart-4 mb-2">💡 Tips</h4>
40
+ <div className="text-xs text-chart-4 space-y-1">
41
+ <p>• Use Top K = 1-3 for most cases</p>
42
+ <p>• Higher values show more detailed rankings</p>
43
+ <p>• Try quantized models for faster processing</p>
44
+ </div>
45
+ </div>
46
+ </div>
47
+ </div>
48
+ )
49
+ }
50
+
51
+ export default TextClassificationConfig
src/contexts/TextClassificationContext.tsx ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { createContext, useContext, useState } from 'react'
2
+
3
+ interface TextClassificationConfig {
4
+ top_k: number
5
+ }
6
+
7
+ interface TextClassificationContextType {
8
+ config: TextClassificationConfig
9
+ setConfig: React.Dispatch<React.SetStateAction<TextClassificationConfig>>
10
+ }
11
+
12
+ const TextClassificationContext = createContext<
13
+ TextClassificationContextType | undefined
14
+ >(undefined)
15
+
16
+ export function useTextClassification() {
17
+ const context = useContext(TextClassificationContext)
18
+ if (context === undefined) {
19
+ throw new Error(
20
+ 'useTextClassification must be used within a TextClassificationProvider'
21
+ )
22
+ }
23
+ return context
24
+ }
25
+
26
+ interface TextClassificationProviderProps {
27
+ children: React.ReactNode
28
+ }
29
+
30
+ export function TextClassificationProvider({
31
+ children
32
+ }: TextClassificationProviderProps) {
33
+ const [config, setConfig] = useState<TextClassificationConfig>({
34
+ top_k: 1
35
+ })
36
+
37
+ const value: TextClassificationContextType = {
38
+ config,
39
+ setConfig
40
+ }
41
+
42
+ return (
43
+ <TextClassificationContext.Provider value={value}>
44
+ {children}
45
+ </TextClassificationContext.Provider>
46
+ )
47
+ }
src/types.ts CHANGED
@@ -48,6 +48,9 @@ export interface TextClassificationWorkerInput {
48
  text: string
49
  model: string
50
  dtype: QuantizationType
 
 
 
51
  }
52
 
53
  export interface TextGenerationWorkerInput {
 
48
  text: string
49
  model: string
50
  dtype: QuantizationType
51
+ config?: {
52
+ top_k?: number
53
+ }
54
  }
55
 
56
  export interface TextGenerationWorkerInput {