Vokturz committed on
Commit
2656c1e
·
1 Parent(s): 79671b7

feat: add dtype support for quantization in model inputs and remove debug logs

Browse files
src/components/ModelLoader.tsx CHANGED
@@ -64,8 +64,6 @@ const ModelLoader = () => {
64
 
65
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
66
  const { status, output } = e.data
67
- console.log('Received output from worker', e.data)
68
-
69
  if (status === 'ready') {
70
  setStatus('ready')
71
  if (e.data.output) console.log(e.data.output)
 
64
 
65
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
66
  const { status, output } = e.data
 
 
67
  if (status === 'ready') {
68
  setStatus('ready')
69
  if (e.data.output) console.log(e.data.output)
src/components/TextClassification.tsx CHANGED
@@ -3,7 +3,7 @@ import {
3
  TextClassificationWorkerInput,
4
  } from '../types'
5
  import { useModel } from '../contexts/ModelContext'
6
- import { set } from 'lodash'
7
  const PLACEHOLDER_TEXTS: string[] = [
8
  'I absolutely love this product! It exceeded all my expectations.',
9
  "This is the worst purchase I've ever made. Complete waste of money.",
@@ -19,7 +19,7 @@ const PLACEHOLDER_TEXTS: string[] = [
19
 
20
  function TextClassification() {
21
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
22
- const { activeWorker, status, modelInfo, results, setResults, hasBeenLoaded} = useModel()
23
 
24
  const classify = useCallback(() => {
25
  if (!modelInfo || !activeWorker) {
@@ -30,10 +30,11 @@ function TextClassification() {
30
  const message: TextClassificationWorkerInput = {
31
  type: 'classify',
32
  text,
33
- model: modelInfo.id
 
34
  }
35
  activeWorker.postMessage(message)
36
- }, [text, modelInfo, activeWorker, set])
37
 
38
  const busy: boolean = status !== 'ready'
39
 
 
3
  TextClassificationWorkerInput,
4
  } from '../types'
5
  import { useModel } from '../contexts/ModelContext'
6
+
7
  const PLACEHOLDER_TEXTS: string[] = [
8
  'I absolutely love this product! It exceeded all my expectations.',
9
  "This is the worst purchase I've ever made. Complete waste of money.",
 
19
 
20
  function TextClassification() {
21
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'))
22
+ const { activeWorker, status, modelInfo, results, setResults, hasBeenLoaded, selectedQuantization} = useModel()
23
 
24
  const classify = useCallback(() => {
25
  if (!modelInfo || !activeWorker) {
 
30
  const message: TextClassificationWorkerInput = {
31
  type: 'classify',
32
  text,
33
+ model: modelInfo.id,
34
+ dtype: selectedQuantization ?? 'fp32'
35
  }
36
  activeWorker.postMessage(message)
37
+ }, [text, modelInfo, activeWorker, selectedQuantization, setResults])
38
 
39
  const busy: boolean = status !== 'ready'
40
 
src/components/TextGeneration.tsx CHANGED
@@ -30,7 +30,7 @@ function TextGeneration() {
30
  // Generation state
31
  const [isGenerating, setIsGenerating] = useState<boolean>(false)
32
 
33
- const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
34
  const messagesEndRef = useRef<HTMLDivElement>(null)
35
 
36
  const scrollToBottom = () => {
@@ -73,10 +73,11 @@ function TextGeneration() {
73
  top_p: topP,
74
  top_k: topK,
75
  do_sample: doSample,
 
76
  }
77
 
78
  activeWorker.postMessage(message)
79
- }, [currentMessage, messages, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
80
 
81
  const handleGenerateText = useCallback(() => {
82
  if (!prompt.trim() || !modelInfo || !activeWorker || isGenerating) {
@@ -94,11 +95,12 @@ function TextGeneration() {
94
  max_new_tokens: maxTokens,
95
  top_p: topP,
96
  top_k: topK,
97
- do_sample: doSample
 
98
  }
99
 
100
  activeWorker.postMessage(message)
101
- }, [prompt, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
102
 
103
  useEffect(() => {
104
  if (!activeWorker) return
 
30
  // Generation state
31
  const [isGenerating, setIsGenerating] = useState<boolean>(false)
32
 
33
+ const { activeWorker, status, modelInfo, hasBeenLoaded, selectedQuantization } = useModel()
34
  const messagesEndRef = useRef<HTMLDivElement>(null)
35
 
36
  const scrollToBottom = () => {
 
73
  top_p: topP,
74
  top_k: topK,
75
  do_sample: doSample,
76
+ dtype: selectedQuantization ?? 'fp32'
77
  }
78
 
79
  activeWorker.postMessage(message)
80
+ }, [currentMessage, messages, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating, selectedQuantization])
81
 
82
  const handleGenerateText = useCallback(() => {
83
  if (!prompt.trim() || !modelInfo || !activeWorker || isGenerating) {
 
95
  max_new_tokens: maxTokens,
96
  top_p: topP,
97
  top_k: topK,
98
+ do_sample: doSample,
99
+ dtype: selectedQuantization ?? 'fp32'
100
  }
101
 
102
  activeWorker.postMessage(message)
103
+ }, [prompt, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating, selectedQuantization])
104
 
105
  useEffect(() => {
106
  if (!activeWorker) return
src/components/ZeroShotClassification.tsx CHANGED
@@ -48,7 +48,7 @@ function ZeroShotClassification() {
48
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
49
  )
50
 
51
- const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
52
 
53
  const classify = useCallback(() => {
54
  if (!modelInfo || !activeWorker) {
@@ -70,10 +70,11 @@ function ZeroShotClassification() {
70
  labels: sections
71
  .slice(0, sections.length - 1)
72
  .map((section) => section.title),
73
- model: modelInfo.id
 
74
  }
75
  activeWorker.postMessage(message)
76
- }, [text, sections, modelInfo, activeWorker])
77
 
78
  // Handle worker messages
79
  useEffect(() => {
 
48
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
49
  )
50
 
51
+ const { activeWorker, status, modelInfo, hasBeenLoaded, selectedQuantization } = useModel()
52
 
53
  const classify = useCallback(() => {
54
  if (!modelInfo || !activeWorker) {
 
70
  labels: sections
71
  .slice(0, sections.length - 1)
72
  .map((section) => section.title),
73
+ model: modelInfo.id,
74
+ dtype: selectedQuantization ?? 'fp32'
75
  }
76
  activeWorker.postMessage(message)
77
+ }, [text, sections, modelInfo, activeWorker, selectedQuantization])
78
 
79
  // Handle worker messages
80
  useEffect(() => {
src/types.ts CHANGED
@@ -39,12 +39,14 @@ export interface ZeroShotWorkerInput {
39
  text: string
40
  labels: string[]
41
  model: string
 
42
  }
43
 
44
  export interface TextClassificationWorkerInput {
45
  type: 'classify'
46
  text: string
47
  model: string
 
48
  }
49
 
50
  export interface TextGenerationWorkerInput {
@@ -58,6 +60,7 @@ export interface TextGenerationWorkerInput {
58
  top_p?: number
59
  top_k?: number
60
  do_sample?: boolean
 
61
  }
62
 
63
  const q8Types = ['q8', 'int8', 'bnb8', 'uint8'] as const
 
39
  text: string
40
  labels: string[]
41
  model: string
42
+ dtype: QuantizationType
43
  }
44
 
45
  export interface TextClassificationWorkerInput {
46
  type: 'classify'
47
  text: string
48
  model: string
49
+ dtype: QuantizationType
50
  }
51
 
52
  export interface TextGenerationWorkerInput {
 
60
  top_p?: number
61
  top_k?: number
62
  do_sample?: boolean
63
+ dtype: QuantizationType
64
  }
65
 
66
  const q8Types = ['q8', 'int8', 'bnb8', 'uint8'] as const