improve text-classification components
Browse files- public/workers/text-classification.js +4 -4
- src/components/ModelCode.tsx +3 -1
- src/components/PipelineLayout.tsx +4 -0
- src/components/PipelineSelector.tsx +3 -3
- src/components/Sidebar.tsx +2 -0
- src/components/pipelines/TextClassification.tsx +51 -14
- src/components/pipelines/TextClassificationConfig.tsx +51 -0
- src/contexts/TextClassificationContext.tsx +47 -0
- src/types.ts +3 -0
public/workers/text-classification.js
CHANGED
@@ -41,7 +41,7 @@ class MyTextClassificationPipeline {
|
|
41 |
// Listen for messages from the main thread
|
42 |
self.addEventListener('message', async (event) => {
|
43 |
try {
|
44 |
-
const { type, model, dtype, text } = event.data
|
45 |
|
46 |
if (!model) {
|
47 |
self.postMessage({
|
@@ -76,13 +76,13 @@ self.addEventListener('message', async (event) => {
|
|
76 |
const split = text.split('\n')
|
77 |
for (const line of split) {
|
78 |
if (line.trim()) {
|
79 |
-
const output = await classifier(line)
|
80 |
self.postMessage({
|
81 |
status: 'output',
|
82 |
output: {
|
83 |
sequence: line,
|
84 |
-
labels:
|
85 |
-
scores:
|
86 |
}
|
87 |
})
|
88 |
}
|
|
|
41 |
// Listen for messages from the main thread
|
42 |
self.addEventListener('message', async (event) => {
|
43 |
try {
|
44 |
+
const { type, model, dtype, text, config } = event.data
|
45 |
|
46 |
if (!model) {
|
47 |
self.postMessage({
|
|
|
76 |
const split = text.split('\n')
|
77 |
for (const line of split) {
|
78 |
if (line.trim()) {
|
79 |
+
const output = await classifier(line, config)
|
80 |
self.postMessage({
|
81 |
status: 'output',
|
82 |
output: {
|
83 |
sequence: line,
|
84 |
+
labels: output.map((item) => item.label),
|
85 |
+
scores: output.map((item) => item.score)
|
86 |
}
|
87 |
})
|
88 |
}
|
src/components/ModelCode.tsx
CHANGED
@@ -38,7 +38,9 @@ const ModelCode = ({ isCodeModalOpen, setIsCodeModalOpen }: ModelCodeProps) => {
|
|
38 |
case 'text-classification':
|
39 |
classType = 'classifier'
|
40 |
exampleData = 'I love this product!'
|
41 |
-
config = {
|
|
|
|
|
42 |
break
|
43 |
case 'text-generation':
|
44 |
classType = 'generator'
|
|
|
38 |
case 'text-classification':
|
39 |
classType = 'classifier'
|
40 |
exampleData = 'I love this product!'
|
41 |
+
config = {
|
42 |
+
top_k: 1
|
43 |
+
}
|
44 |
break
|
45 |
case 'text-generation':
|
46 |
classType = 'generator'
|
src/components/PipelineLayout.tsx
CHANGED
@@ -3,6 +3,7 @@ import { TextGenerationProvider } from '../contexts/TextGenerationContext'
|
|
3 |
import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
|
4 |
import { ZeroShotClassificationProvider } from '../contexts/ZeroShotClassificationContext'
|
5 |
import { ImageClassificationProvider } from '../contexts/ImageClassificationContext'
|
|
|
6 |
|
7 |
export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
|
8 |
const { pipeline } = useModel()
|
@@ -26,6 +27,9 @@ export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
|
|
26 |
<ImageClassificationProvider>{children}</ImageClassificationProvider>
|
27 |
)
|
28 |
|
|
|
|
|
|
|
29 |
default:
|
30 |
return <>{children}</>
|
31 |
}
|
|
|
3 |
import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
|
4 |
import { ZeroShotClassificationProvider } from '../contexts/ZeroShotClassificationContext'
|
5 |
import { ImageClassificationProvider } from '../contexts/ImageClassificationContext'
|
6 |
+
import { TextClassificationProvider } from '../contexts/TextClassificationContext'
|
7 |
|
8 |
export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
|
9 |
const { pipeline } = useModel()
|
|
|
27 |
<ImageClassificationProvider>{children}</ImageClassificationProvider>
|
28 |
)
|
29 |
|
30 |
+
case 'text-classification':
|
31 |
+
return <TextClassificationProvider>{children}</TextClassificationProvider>
|
32 |
+
|
33 |
default:
|
34 |
return <>{children}</>
|
35 |
}
|
src/components/PipelineSelector.tsx
CHANGED
@@ -12,9 +12,9 @@ export const supportedPipelines = [
|
|
12 |
'image-classification',
|
13 |
'text-generation',
|
14 |
'zero-shot-classification',
|
15 |
-
'text-classification'
|
16 |
-
'summarization',
|
17 |
-
'translation'
|
18 |
]
|
19 |
|
20 |
interface PipelineSelectorProps {
|
|
|
12 |
'image-classification',
|
13 |
'text-generation',
|
14 |
'zero-shot-classification',
|
15 |
+
'text-classification'
|
16 |
+
// 'summarization',
|
17 |
+
// 'translation'
|
18 |
]
|
19 |
|
20 |
interface PipelineSelectorProps {
|
src/components/Sidebar.tsx
CHANGED
@@ -7,6 +7,7 @@ import TextGenerationConfig from './pipelines/TextGenerationConfig'
|
|
7 |
import FeatureExtractionConfig from './pipelines/FeatureExtractionConfig'
|
8 |
import ZeroShotClassificationConfig from './pipelines/ZeroShotClassificationConfig'
|
9 |
import ImageClassificationConfig from './pipelines/ImageClassificationConfig'
|
|
|
10 |
import { Button } from '@/components/ui/button'
|
11 |
|
12 |
interface SidebarProps {
|
@@ -102,6 +103,7 @@ const Sidebar = ({
|
|
102 |
{pipeline === 'image-classification' && (
|
103 |
<ImageClassificationConfig />
|
104 |
)}
|
|
|
105 |
</div>
|
106 |
</div>
|
107 |
</div>
|
|
|
7 |
import FeatureExtractionConfig from './pipelines/FeatureExtractionConfig'
|
8 |
import ZeroShotClassificationConfig from './pipelines/ZeroShotClassificationConfig'
|
9 |
import ImageClassificationConfig from './pipelines/ImageClassificationConfig'
|
10 |
+
import TextClassificationConfig from './pipelines/TextClassificationConfig'
|
11 |
import { Button } from '@/components/ui/button'
|
12 |
|
13 |
interface SidebarProps {
|
|
|
103 |
{pipeline === 'image-classification' && (
|
104 |
<ImageClassificationConfig />
|
105 |
)}
|
106 |
+
{pipeline === 'text-classification' && <TextClassificationConfig />}
|
107 |
</div>
|
108 |
</div>
|
109 |
</div>
|
src/components/pipelines/TextClassification.tsx
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
import { useState, useCallback, useEffect } from 'react'
|
2 |
-
import {
|
|
|
|
|
|
|
|
|
3 |
import { useModel } from '../../contexts/ModelContext'
|
|
|
4 |
|
5 |
const PLACEHOLDER_TEXTS: string[] = [
|
6 |
'I absolutely love this product! It exceeded all my expectations.',
|
@@ -18,7 +23,7 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
18 |
function TextClassification() {
|
19 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
20 |
const [numberExamples, setNumberExamples] = useState(PLACEHOLDER_TEXTS.length)
|
21 |
-
const [results, setResults] = useState<
|
22 |
const {
|
23 |
activeWorker,
|
24 |
status,
|
@@ -27,6 +32,7 @@ function TextClassification() {
|
|
27 |
hasBeenLoaded,
|
28 |
selectedQuantization
|
29 |
} = useModel()
|
|
|
30 |
|
31 |
useEffect(() => {
|
32 |
if (modelInfo?.widgetData) {
|
@@ -51,10 +57,11 @@ function TextClassification() {
|
|
51 |
type: 'classify',
|
52 |
text,
|
53 |
model: modelInfo.id,
|
54 |
-
dtype: selectedQuantization ?? 'fp32'
|
|
|
55 |
}
|
56 |
activeWorker.postMessage(message)
|
57 |
-
}, [text, modelInfo, activeWorker, selectedQuantization, setResults])
|
58 |
|
59 |
// Handle worker messages
|
60 |
useEffect(() => {
|
@@ -65,7 +72,7 @@ function TextClassification() {
|
|
65 |
if (status === 'output') {
|
66 |
setStatus('output')
|
67 |
const result = e.data.output!
|
68 |
-
setResults((prev:
|
69 |
}
|
70 |
}
|
71 |
|
@@ -135,17 +142,47 @@ function TextClassification() {
|
|
135 |
<div className="space-y-3">
|
136 |
{results.map((result, index) => (
|
137 |
<div key={index} className="p-3 rounded-sm border-2">
|
138 |
-
<div className="
|
139 |
-
<span className="font-semibold text-sm">
|
140 |
-
{result.labels[0]}
|
141 |
-
</span>
|
142 |
-
<span className="text-sm font-mono">
|
143 |
-
{(result.scores[0] * 100).toFixed(1)}%
|
144 |
-
</span>
|
145 |
-
</div>
|
146 |
-
<div className="text-sm text-gray-700">
|
147 |
{result.sequence}
|
148 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
</div>
|
150 |
))}
|
151 |
</div>
|
|
|
1 |
import { useState, useCallback, useEffect } from 'react'
|
2 |
+
import {
|
3 |
+
ClassificationOutput,
|
4 |
+
TextClassificationWorkerInput,
|
5 |
+
WorkerMessage
|
6 |
+
} from '../../types'
|
7 |
import { useModel } from '../../contexts/ModelContext'
|
8 |
+
import { useTextClassification } from '../../contexts/TextClassificationContext'
|
9 |
|
10 |
const PLACEHOLDER_TEXTS: string[] = [
|
11 |
'I absolutely love this product! It exceeded all my expectations.',
|
|
|
23 |
function TextClassification() {
|
24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
|
25 |
const [numberExamples, setNumberExamples] = useState(PLACEHOLDER_TEXTS.length)
|
26 |
+
const [results, setResults] = useState<ClassificationOutput[]>([])
|
27 |
const {
|
28 |
activeWorker,
|
29 |
status,
|
|
|
32 |
hasBeenLoaded,
|
33 |
selectedQuantization
|
34 |
} = useModel()
|
35 |
+
const { config } = useTextClassification()
|
36 |
|
37 |
useEffect(() => {
|
38 |
if (modelInfo?.widgetData) {
|
|
|
57 |
type: 'classify',
|
58 |
text,
|
59 |
model: modelInfo.id,
|
60 |
+
dtype: selectedQuantization ?? 'fp32',
|
61 |
+
config
|
62 |
}
|
63 |
activeWorker.postMessage(message)
|
64 |
+
}, [text, modelInfo, activeWorker, selectedQuantization, config, setResults])
|
65 |
|
66 |
// Handle worker messages
|
67 |
useEffect(() => {
|
|
|
72 |
if (status === 'output') {
|
73 |
setStatus('output')
|
74 |
const result = e.data.output!
|
75 |
+
setResults((prev: ClassificationOutput[]) => [...prev, result])
|
76 |
}
|
77 |
}
|
78 |
|
|
|
142 |
<div className="space-y-3">
|
143 |
{results.map((result, index) => (
|
144 |
<div key={index} className="p-3 rounded-sm border-2">
|
145 |
+
<div className="text-sm text-gray-700 mb-3">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
{result.sequence}
|
147 |
</div>
|
148 |
+
<div className="space-y-2">
|
149 |
+
{result.labels.map(
|
150 |
+
(label: string, labelIndex: number) => {
|
151 |
+
const score = result.scores[labelIndex]
|
152 |
+
const isTopPrediction = labelIndex === 0
|
153 |
+
|
154 |
+
return (
|
155 |
+
<div
|
156 |
+
key={labelIndex}
|
157 |
+
className={`flex justify-between items-center p-2 rounded ${
|
158 |
+
isTopPrediction
|
159 |
+
? 'bg-blue-50 border-l-4 border-blue-500'
|
160 |
+
: 'bg-gray-50'
|
161 |
+
}`}
|
162 |
+
>
|
163 |
+
<span
|
164 |
+
className={`font-medium text-sm ${
|
165 |
+
isTopPrediction
|
166 |
+
? 'text-blue-700'
|
167 |
+
: 'text-gray-700'
|
168 |
+
}`}
|
169 |
+
>
|
170 |
+
{label}
|
171 |
+
</span>
|
172 |
+
<span
|
173 |
+
className={`text-sm font-mono ${
|
174 |
+
isTopPrediction
|
175 |
+
? 'text-blue-600'
|
176 |
+
: 'text-gray-600'
|
177 |
+
}`}
|
178 |
+
>
|
179 |
+
{(score * 100).toFixed(1)}%
|
180 |
+
</span>
|
181 |
+
</div>
|
182 |
+
)
|
183 |
+
}
|
184 |
+
)}
|
185 |
+
</div>
|
186 |
</div>
|
187 |
))}
|
188 |
</div>
|
src/components/pipelines/TextClassificationConfig.tsx
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import React from 'react'
|
2 |
+
import { useTextClassification } from '../../contexts/TextClassificationContext'
|
3 |
+
import { Slider } from '../ui/slider'
|
4 |
+
|
5 |
+
const TextClassificationConfig = () => {
|
6 |
+
const { config, setConfig } = useTextClassification()
|
7 |
+
|
8 |
+
return (
|
9 |
+
<div className="space-y-4">
|
10 |
+
<h3 className="text-lg font-semibold text-foreground">
|
11 |
+
Text Classification Settings
|
12 |
+
</h3>
|
13 |
+
|
14 |
+
<div className="space-y-3">
|
15 |
+
<div>
|
16 |
+
<label className="block text-sm font-medium text-foreground/80 mb-1">
|
17 |
+
Top K Predictions: {config.top_k}
|
18 |
+
</label>
|
19 |
+
<Slider
|
20 |
+
defaultValue={[config.top_k]}
|
21 |
+
min={1}
|
22 |
+
max={10}
|
23 |
+
step={1}
|
24 |
+
onValueChange={(value) => setConfig({ top_k: value[0] })}
|
25 |
+
className="w-full rounded-lg"
|
26 |
+
/>
|
27 |
+
<div className="flex justify-between text-xs text-muted-foreground/60 mt-1">
|
28 |
+
<span>1</span>
|
29 |
+
<span>4</span>
|
30 |
+
<span>7</span>
|
31 |
+
<span>10</span>
|
32 |
+
</div>
|
33 |
+
<p className="text-xs text-muted-foreground mt-1">
|
34 |
+
Number of top predictions to return for each text
|
35 |
+
</p>
|
36 |
+
</div>
|
37 |
+
|
38 |
+
<div className="p-3 bg-chart-4/10 border border-chart-4/20 rounded-lg">
|
39 |
+
<h4 className="text-sm font-medium text-chart-4 mb-2">💡 Tips</h4>
|
40 |
+
<div className="text-xs text-chart-4 space-y-1">
|
41 |
+
<p>• Use Top K = 1-3 for most cases</p>
|
42 |
+
<p>• Higher values show more detailed rankings</p>
|
43 |
+
<p>• Try quantized models for faster processing</p>
|
44 |
+
</div>
|
45 |
+
</div>
|
46 |
+
</div>
|
47 |
+
</div>
|
48 |
+
)
|
49 |
+
}
|
50 |
+
|
51 |
+
export default TextClassificationConfig
|
src/contexts/TextClassificationContext.tsx
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import React, { createContext, useContext, useState } from 'react'
|
2 |
+
|
3 |
+
interface TextClassificationConfig {
|
4 |
+
top_k: number
|
5 |
+
}
|
6 |
+
|
7 |
+
interface TextClassificationContextType {
|
8 |
+
config: TextClassificationConfig
|
9 |
+
setConfig: React.Dispatch<React.SetStateAction<TextClassificationConfig>>
|
10 |
+
}
|
11 |
+
|
12 |
+
const TextClassificationContext = createContext<
|
13 |
+
TextClassificationContextType | undefined
|
14 |
+
>(undefined)
|
15 |
+
|
16 |
+
export function useTextClassification() {
|
17 |
+
const context = useContext(TextClassificationContext)
|
18 |
+
if (context === undefined) {
|
19 |
+
throw new Error(
|
20 |
+
'useTextClassification must be used within a TextClassificationProvider'
|
21 |
+
)
|
22 |
+
}
|
23 |
+
return context
|
24 |
+
}
|
25 |
+
|
26 |
+
interface TextClassificationProviderProps {
|
27 |
+
children: React.ReactNode
|
28 |
+
}
|
29 |
+
|
30 |
+
export function TextClassificationProvider({
|
31 |
+
children
|
32 |
+
}: TextClassificationProviderProps) {
|
33 |
+
const [config, setConfig] = useState<TextClassificationConfig>({
|
34 |
+
top_k: 1
|
35 |
+
})
|
36 |
+
|
37 |
+
const value: TextClassificationContextType = {
|
38 |
+
config,
|
39 |
+
setConfig
|
40 |
+
}
|
41 |
+
|
42 |
+
return (
|
43 |
+
<TextClassificationContext.Provider value={value}>
|
44 |
+
{children}
|
45 |
+
</TextClassificationContext.Provider>
|
46 |
+
)
|
47 |
+
}
|
src/types.ts
CHANGED
@@ -48,6 +48,9 @@ export interface TextClassificationWorkerInput {
|
|
48 |
text: string
|
49 |
model: string
|
50 |
dtype: QuantizationType
|
|
|
|
|
|
|
51 |
}
|
52 |
|
53 |
export interface TextGenerationWorkerInput {
|
|
|
48 |
text: string
|
49 |
model: string
|
50 |
dtype: QuantizationType
|
51 |
+
config?: {
|
52 |
+
top_k?: number
|
53 |
+
}
|
54 |
}
|
55 |
|
56 |
export interface TextGenerationWorkerInput {
|