Refactor model configuration to use config object
Browse files
- public/workers/image-classification.js +2 -4
- public/workers/text-generation.js +7 -20
- src/components/pipelines/ImageClassification.tsx +3 -3
- src/components/pipelines/ImageClassificationConfig.tsx +4 -4
- src/components/pipelines/TextGeneration.tsx +3 -11
- src/contexts/ImageClassificationContext.tsx +11 -5
- src/types.ts +10 -6
public/workers/image-classification.js
CHANGED
@@ -62,7 +62,7 @@ class MyImageClassificationPipeline {
|
|
62 |
// Listen for messages from the main thread
|
63 |
self.addEventListener('message', async (event) => {
|
64 |
try {
|
65 |
-
const { type, image, model, dtype, topK } = event.data
|
66 |
|
67 |
if (!model) {
|
68 |
self.postMessage({
|
@@ -100,9 +100,7 @@ self.addEventListener('message', async (event) => {
|
|
100 |
|
101 |
try {
|
102 |
// Run classification
|
103 |
-
const output = await classifier(image, {
|
104 |
-
top_k: topK
|
105 |
-
})
|
106 |
|
107 |
// Format predictions
|
108 |
const predictions = output.map((item) => ({
|
|
|
62 |
// Listen for messages from the main thread
|
63 |
self.addEventListener('message', async (event) => {
|
64 |
try {
|
65 |
+
const { type, image, model, dtype, config } = event.data
|
66 |
|
67 |
if (!model) {
|
68 |
self.postMessage({
|
|
|
100 |
|
101 |
try {
|
102 |
// Run classification
|
103 |
+
const output = await classifier(image, config)
|
|
|
|
|
104 |
|
105 |
// Format predictions
|
106 |
const predictions = output.map((item) => ({
|
public/workers/text-generation.js
CHANGED
@@ -49,20 +49,8 @@ class MyTextGenerationPipeline {
|
|
49 |
// Listen for messages from the main thread
|
50 |
self.addEventListener('message', async (event) => {
|
51 |
try {
|
52 |
-
const {
|
53 |
-
|
54 |
-
model,
|
55 |
-
dtype,
|
56 |
-
messages,
|
57 |
-
prompt,
|
58 |
-
hasChatTemplate,
|
59 |
-
temperature,
|
60 |
-
max_new_tokens,
|
61 |
-
top_p,
|
62 |
-
top_k,
|
63 |
-
do_sample,
|
64 |
-
stop_words
|
65 |
-
} = event.data
|
66 |
|
67 |
if (type === 'stop') {
|
68 |
MyTextGenerationPipeline.stopGeneration()
|
@@ -108,12 +96,11 @@ self.addEventListener('message', async (event) => {
|
|
108 |
}
|
109 |
|
110 |
const options = {
|
111 |
-
max_new_tokens: max_new_tokens || 100,
|
112 |
-
temperature: temperature || 0.7,
|
113 |
-
do_sample: do_sample !== false,
|
114 |
-
...(top_p && { top_p }),
|
115 |
-
...(top_k && { top_k })
|
116 |
-
...(stop_words && stop_words.length > 0 && { stop_words })
|
117 |
}
|
118 |
|
119 |
// Create an AbortController for this generation
|
|
|
49 |
// Listen for messages from the main thread
|
50 |
self.addEventListener('message', async (event) => {
|
51 |
try {
|
52 |
+
const { type, model, dtype, messages, prompt, hasChatTemplate, config } =
|
53 |
+
event.data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
if (type === 'stop') {
|
56 |
MyTextGenerationPipeline.stopGeneration()
|
|
|
96 |
}
|
97 |
|
98 |
const options = {
|
99 |
+
max_new_tokens: config.max_new_tokens || 100,
|
100 |
+
temperature: config.temperature || 0.7,
|
101 |
+
do_sample: config.do_sample !== false,
|
102 |
+
...(config.top_p && { top_p }),
|
103 |
+
...(config.top_k && { top_k })
|
|
|
104 |
}
|
105 |
|
106 |
// Create an AbortController for this generation
|
src/components/pipelines/ImageClassification.tsx
CHANGED
@@ -45,7 +45,7 @@ function ImageClassification() {
|
|
45 |
removeExample,
|
46 |
updateExample,
|
47 |
clearExamples,
|
48 |
-
|
49 |
} = useImageClassification()
|
50 |
|
51 |
const [isClassifying, setIsClassifying] = useState<boolean>(false)
|
@@ -75,12 +75,12 @@ function ImageClassification() {
|
|
75 |
image: example.url,
|
76 |
model: modelInfo.id,
|
77 |
dtype: selectedQuantization ?? 'fp32',
|
78 |
-
|
79 |
}
|
80 |
|
81 |
activeWorker.postMessage(message)
|
82 |
},
|
83 |
-
[modelInfo, activeWorker, selectedQuantization, topK, updateExample]
|
84 |
)
|
85 |
|
86 |
const handleFileSelect = useCallback(
|
|
|
45 |
removeExample,
|
46 |
updateExample,
|
47 |
clearExamples,
|
48 |
+
config
|
49 |
} = useImageClassification()
|
50 |
|
51 |
const [isClassifying, setIsClassifying] = useState<boolean>(false)
|
|
|
75 |
image: example.url,
|
76 |
model: modelInfo.id,
|
77 |
dtype: selectedQuantization ?? 'fp32',
|
78 |
+
config
|
79 |
}
|
80 |
|
81 |
activeWorker.postMessage(message)
|
82 |
},
|
83 |
+
[modelInfo, activeWorker, selectedQuantization, config, updateExample]
|
84 |
)
|
85 |
|
86 |
const handleFileSelect = useCallback(
|
src/components/pipelines/ImageClassificationConfig.tsx
CHANGED
@@ -3,7 +3,7 @@ import { useImageClassification } from '../../contexts/ImageClassificationContex
|
|
3 |
import { Slider } from '../ui/slider'
|
4 |
|
5 |
const ImageClassificationConfig = () => {
|
6 |
-
const { topK, setTopK } = useImageClassification()
|
7 |
|
8 |
return (
|
9 |
<div className="space-y-4">
|
@@ -14,14 +14,14 @@ const ImageClassificationConfig = () => {
|
|
14 |
<div className="space-y-3">
|
15 |
<div>
|
16 |
<label className="block text-sm font-medium text-foreground/80 mb-1">
|
17 |
-
Top K Predictions: {topK}
|
18 |
</label>
|
19 |
<Slider
|
20 |
-
defaultValue={[topK]}
|
21 |
min={1}
|
22 |
max={10}
|
23 |
step={1}
|
24 |
-
onValueChange={(value) => setTopK(value[0])}
|
25 |
className="w-full rounded-lg"
|
26 |
/>
|
27 |
<div className="flex justify-between text-xs text-muted-foreground/60 mt-1">
|
|
|
3 |
import { Slider } from '../ui/slider'
|
4 |
|
5 |
const ImageClassificationConfig = () => {
|
6 |
+
const { config, setConfig } = useImageClassification()
|
7 |
|
8 |
return (
|
9 |
<div className="space-y-4">
|
|
|
14 |
<div className="space-y-3">
|
15 |
<div>
|
16 |
<label className="block text-sm font-medium text-foreground/80 mb-1">
|
17 |
+
Top K Predictions: {config.top_k}
|
18 |
</label>
|
19 |
<Slider
|
20 |
+
defaultValue={[config.top_k]}
|
21 |
min={1}
|
22 |
max={10}
|
23 |
step={1}
|
24 |
+
onValueChange={(value) => setConfig({ top_k: value[0] })}
|
25 |
className="w-full rounded-lg"
|
26 |
/>
|
27 |
<div className="flex justify-between text-xs text-muted-foreground/60 mt-1">
|
src/components/pipelines/TextGeneration.tsx
CHANGED
@@ -58,12 +58,8 @@ function TextGeneration() {
|
|
58 |
messages: updatedMessages,
|
59 |
hasChatTemplate: modelInfo.hasChatTemplate,
|
60 |
model: modelInfo.id,
|
61 |
-
|
62 |
-
|
63 |
-
top_p: config.topP,
|
64 |
-
top_k: config.topK,
|
65 |
-
do_sample: config.doSample,
|
66 |
-
dtype: selectedQuantization ?? 'fp32'
|
67 |
}
|
68 |
|
69 |
activeWorker.postMessage(message)
|
@@ -87,11 +83,7 @@ function TextGeneration() {
|
|
87 |
prompt: prompt.trim(),
|
88 |
hasChatTemplate: modelInfo.hasChatTemplate,
|
89 |
model: modelInfo.id,
|
90 |
-
|
91 |
-
max_new_tokens: config.maxTokens,
|
92 |
-
top_p: config.topP,
|
93 |
-
top_k: config.topK,
|
94 |
-
do_sample: config.doSample,
|
95 |
dtype: selectedQuantization ?? 'fp32'
|
96 |
}
|
97 |
|
|
|
58 |
messages: updatedMessages,
|
59 |
hasChatTemplate: modelInfo.hasChatTemplate,
|
60 |
model: modelInfo.id,
|
61 |
+
dtype: selectedQuantization ?? 'fp32',
|
62 |
+
config
|
|
|
|
|
|
|
|
|
63 |
}
|
64 |
|
65 |
activeWorker.postMessage(message)
|
|
|
83 |
prompt: prompt.trim(),
|
84 |
hasChatTemplate: modelInfo.hasChatTemplate,
|
85 |
model: modelInfo.id,
|
86 |
+
config,
|
|
|
|
|
|
|
|
|
87 |
dtype: selectedQuantization ?? 'fp32'
|
88 |
}
|
89 |
|
src/contexts/ImageClassificationContext.tsx
CHANGED
@@ -1,6 +1,10 @@
|
|
1 |
import React, { createContext, useContext, useState, useCallback } from 'react'
|
2 |
import { ImageExample } from '../types'
|
3 |
|
|
|
|
|
|
|
|
|
4 |
interface ImageClassificationContextType {
|
5 |
examples: ImageExample[]
|
6 |
selectedExample: ImageExample | null
|
@@ -9,8 +13,8 @@ interface ImageClassificationContextType {
|
|
9 |
removeExample: (id: string) => void
|
10 |
updateExample: (id: string, updates: Partial<ImageExample>) => void
|
11 |
clearExamples: () => void
|
12 |
-
|
13 |
-
|
14 |
}
|
15 |
|
16 |
const ImageClassificationContext = createContext<
|
@@ -38,7 +42,9 @@ export function ImageClassificationProvider({
|
|
38 |
const [selectedExample, setSelectedExample] = useState<ImageExample | null>(
|
39 |
null
|
40 |
)
|
41 |
-
const [topK, setTopK] = useState<number>(5)
|
|
|
|
|
42 |
|
43 |
const addExample = useCallback((file: File) => {
|
44 |
const id = Math.random().toString(36).substr(2, 9)
|
@@ -105,8 +111,8 @@ export function ImageClassificationProvider({
|
|
105 |
removeExample,
|
106 |
updateExample,
|
107 |
clearExamples,
|
108 |
-
|
109 |
-
|
110 |
}
|
111 |
|
112 |
return (
|
|
|
1 |
import React, { createContext, useContext, useState, useCallback } from 'react'
|
2 |
import { ImageExample } from '../types'
|
3 |
|
4 |
+
interface ImageClassificationConfig {
|
5 |
+
top_k: number
|
6 |
+
}
|
7 |
+
|
8 |
interface ImageClassificationContextType {
|
9 |
examples: ImageExample[]
|
10 |
selectedExample: ImageExample | null
|
|
|
13 |
removeExample: (id: string) => void
|
14 |
updateExample: (id: string, updates: Partial<ImageExample>) => void
|
15 |
clearExamples: () => void
|
16 |
+
config: ImageClassificationConfig
|
17 |
+
setConfig: React.Dispatch<React.SetStateAction<ImageClassificationConfig>>
|
18 |
}
|
19 |
|
20 |
const ImageClassificationContext = createContext<
|
|
|
42 |
const [selectedExample, setSelectedExample] = useState<ImageExample | null>(
|
43 |
null
|
44 |
)
|
45 |
+
const [config, setConfig] = useState<ImageClassificationConfig>({
|
46 |
+
top_k: 5
|
47 |
+
})
|
48 |
|
49 |
const addExample = useCallback((file: File) => {
|
50 |
const id = Math.random().toString(36).substr(2, 9)
|
|
|
111 |
removeExample,
|
112 |
updateExample,
|
113 |
clearExamples,
|
114 |
+
config,
|
115 |
+
setConfig
|
116 |
}
|
117 |
|
118 |
return (
|
src/types.ts
CHANGED
@@ -56,11 +56,13 @@ export interface TextGenerationWorkerInput {
|
|
56 |
messages?: ChatMessage[]
|
57 |
hasChatTemplate: boolean
|
58 |
model: string
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
64 |
dtype: QuantizationType
|
65 |
}
|
66 |
|
@@ -80,7 +82,9 @@ export interface ImageClassificationWorkerInput {
|
|
80 |
image: string | ImageData | HTMLImageElement | HTMLCanvasElement
|
81 |
model: string
|
82 |
dtype: QuantizationType
|
83 |
-
|
|
|
|
|
84 |
}
|
85 |
|
86 |
export interface ImageClassificationResult {
|
|
|
56 |
messages?: ChatMessage[]
|
57 |
hasChatTemplate: boolean
|
58 |
model: string
|
59 |
+
config?: {
|
60 |
+
temperature?: number
|
61 |
+
max_new_tokens?: number
|
62 |
+
top_p?: number
|
63 |
+
top_k?: number
|
64 |
+
do_sample?: boolean
|
65 |
+
}
|
66 |
dtype: QuantizationType
|
67 |
}
|
68 |
|
|
|
82 |
image: string | ImageData | HTMLImageElement | HTMLCanvasElement
|
83 |
model: string
|
84 |
dtype: QuantizationType
|
85 |
+
config: {
|
86 |
+
top_k?: number
|
87 |
+
}
|
88 |
}
|
89 |
|
90 |
export interface ImageClassificationResult {
|