feat: add dtype support for quantization in model inputs and remove debug logs
src/components/ModelLoader.tsx CHANGED

@@ -64,8 +64,6 @@ const ModelLoader = () => {
 
   const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
     const { status, output } = e.data
-    console.log('Received output from worker', e.data)
-
     if (status === 'ready') {
       setStatus('ready')
       if (e.data.output) console.log(e.data.output)
src/components/TextClassification.tsx CHANGED

@@ -3,7 +3,7 @@ import {
   TextClassificationWorkerInput,
 } from '../types'
 import { useModel } from '../contexts/ModelContext'
-
+
 const PLACEHOLDER_TEXTS: string[] = [
   'I absolutely love this product! It exceeded all my expectations.',
   "This is the worst purchase I've ever made. Complete waste of money.",
@@ -19,7 +19,7 @@ const PLACEHOLDER_TEXTS: string[] = [
 
 function TextClassification() {
   const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
-  const { activeWorker, status, modelInfo, results, setResults, hasBeenLoaded} = useModel()
+  const { activeWorker, status, modelInfo, results, setResults, hasBeenLoaded, selectedQuantization} = useModel()
 
   const classify = useCallback(() => {
     if (!modelInfo || !activeWorker) {
@@ -30,10 +30,11 @@ function TextClassification() {
     const message: TextClassificationWorkerInput = {
       type: 'classify',
       text,
-      model: modelInfo.id
+      model: modelInfo.id,
+      dtype: selectedQuantization ?? 'fp32'
     }
     activeWorker.postMessage(message)
-  }, [text, modelInfo, activeWorker, setResults])
+  }, [text, modelInfo, activeWorker, selectedQuantization, setResults])
 
   const busy: boolean = status !== 'ready'
 
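For context, this is roughly what the worker on the other end of postMessage might do with the new field, assuming the app uses Transformers.js, whose pipeline() factory accepts a dtype option. The file name, the 'complete' status, and the overall structure are illustrative assumptions, not code from this repo:

// worker.ts — hypothetical consumer of TextClassificationWorkerInput (a sketch).
import { pipeline } from '@huggingface/transformers'
import type { TextClassificationWorkerInput } from './types' // path assumed

self.onmessage = async (e: MessageEvent<TextClassificationWorkerInput>) => {
  const { text, model, dtype } = e.data
  // dtype selects which quantization of the weights Transformers.js loads;
  // this assumes the repo's QuantizationType values match its accepted strings.
  const classifier = await pipeline('text-classification', model, { dtype })
  // One input per line, matching the component's PLACEHOLDER_TEXTS.join('\n').
  const output = await classifier(text.split('\n'))
  self.postMessage({ status: 'complete', output }) // status name assumed
}

Rebuilding the pipeline on every message is wasteful; a cache keyed by model and dtype (sketched after the TextGeneration diff below) is the usual fix.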
src/components/TextGeneration.tsx CHANGED

@@ -30,7 +30,7 @@ function TextGeneration() {
   // Generation state
   const [isGenerating, setIsGenerating] = useState<boolean>(false)
 
-  const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
+  const { activeWorker, status, modelInfo, hasBeenLoaded, selectedQuantization } = useModel()
   const messagesEndRef = useRef<HTMLDivElement>(null)
 
   const scrollToBottom = () => {
@@ -73,10 +73,11 @@ function TextGeneration() {
       top_p: topP,
       top_k: topK,
       do_sample: doSample,
+      dtype: selectedQuantization ?? 'fp32'
     }
 
     activeWorker.postMessage(message)
-  }, [currentMessage, messages, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
+  }, [currentMessage, messages, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating, selectedQuantization])
 
   const handleGenerateText = useCallback(() => {
     if (!prompt.trim() || !modelInfo || !activeWorker || isGenerating) {
@@ -94,11 +95,12 @@ function TextGeneration() {
       max_new_tokens: maxTokens,
       top_p: topP,
       top_k: topK,
-      do_sample: doSample
+      do_sample: doSample,
+      dtype: selectedQuantization ?? 'fp32'
     }
 
     activeWorker.postMessage(message)
-  }, [prompt, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
+  }, [prompt, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating, selectedQuantization])
 
   useEffect(() => {
     if (!activeWorker) return
src/components/ZeroShotClassification.tsx CHANGED

@@ -48,7 +48,7 @@ function ZeroShotClassification() {
     PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
   )
 
-  const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
+  const { activeWorker, status, modelInfo, hasBeenLoaded, selectedQuantization } = useModel()
 
   const classify = useCallback(() => {
     if (!modelInfo || !activeWorker) {
@@ -70,10 +70,11 @@ function ZeroShotClassification() {
       labels: sections
         .slice(0, sections.length - 1)
         .map((section) => section.title),
-      model: modelInfo.id
+      model: modelInfo.id,
+      dtype: selectedQuantization ?? 'fp32'
     }
     activeWorker.postMessage(message)
-  }, [text, sections, modelInfo, activeWorker])
+  }, [text, sections, modelInfo, activeWorker, selectedQuantization])
 
   // Handle worker messages
   useEffect(() => {
src/types.ts CHANGED

@@ -39,12 +39,14 @@ export interface ZeroShotWorkerInput {
   text: string
   labels: string[]
   model: string
+  dtype: QuantizationType
 }
 
 export interface TextClassificationWorkerInput {
   type: 'classify'
   text: string
   model: string
+  dtype: QuantizationType
 }
 
 export interface TextGenerationWorkerInput {
@@ -58,6 +60,7 @@ export interface TextGenerationWorkerInput {
   top_p?: number
   top_k?: number
   do_sample?: boolean
+  dtype: QuantizationType
 }
 
 const q8Types = ['q8', 'int8', 'bnb8', 'uint8'] as const