feat: Add text generation functionality
- public/workers/text-generation.js +139 -0
- src/App.tsx +2 -0
- src/components/ModelLoader.tsx +1 -1
- src/components/ModelSelector.tsx +2 -4
- src/components/TextGeneration.tsx +494 -0
- src/lib/workerManager.ts +17 -19
- src/types.ts +26 -2
public/workers/text-generation.js
ADDED
@@ -0,0 +1,139 @@
+/* eslint-disable no-restricted-globals */
+import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3'
+
+class MyTextGenerationPipeline {
+  static task = 'text-generation'
+  static instance = null
+  static currentGeneration = null
+
+  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
+    this.instance = pipeline(this.task, model, {
+      dtype,
+      device: 'webgpu',
+      progress_callback
+    })
+    return this.instance
+  }
+
+  static stopGeneration() {
+    if (this.currentGeneration) {
+      this.currentGeneration.abort()
+      this.currentGeneration = null
+    }
+  }
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (event) => {
+  try {
+    const {
+      type,
+      model,
+      dtype,
+      messages,
+      prompt,
+      hasChatTemplate,
+      temperature,
+      max_new_tokens,
+      top_p,
+      top_k,
+      do_sample,
+      stop_words
+    } = event.data
+
+    if (type === 'stop') {
+      MyTextGenerationPipeline.stopGeneration()
+      self.postMessage({ status: 'ready' })
+      return
+    }
+
+    if (!model) {
+      self.postMessage({
+        status: 'error',
+        output: 'No model provided'
+      })
+      return
+    }
+
+    // Retrieve the pipeline. This will download the model if not already cached.
+    const generator = await MyTextGenerationPipeline.getInstance(
+      model,
+      dtype,
+      (x) => {
+        self.postMessage({ status: 'loading', output: x })
+      }
+    )
+
+    if (type === 'load') {
+      self.postMessage({
+        status: 'ready',
+        output: `Model ${model}, dtype ${dtype} loaded`
+      })
+      return
+    }
+
+    if (type === 'generate') {
+      let inputText = ''
+
+      if (hasChatTemplate && messages && messages.length > 0) {
+        inputText = messages
+      } else if (!hasChatTemplate && prompt) {
+        inputText = prompt
+      } else {
+        self.postMessage({ status: 'ready' })
+        return
+      }
+
+      const options = {
+        max_new_tokens: max_new_tokens || 100,
+        temperature: temperature || 0.7,
+        do_sample: do_sample !== false,
+        ...(top_p && { top_p }),
+        ...(top_k && { top_k }),
+        ...(stop_words && stop_words.length > 0 && { stop_words })
+      }
+
+      // Create an AbortController for this generation
+      const abortController = new AbortController()
+      MyTextGenerationPipeline.currentGeneration = abortController
+
+      try {
+        const output = await generator(inputText, {
+          ...options,
+          signal: abortController.signal
+        })
+
+        if (hasChatTemplate) {
+          // For chat mode, extract only the assistant's response
+          self.postMessage({
+            status: 'output',
+            output: output[0].generated_text.slice(-1)[0]
+          })
+        } else {
+          self.postMessage({
+            status: 'output',
+            output: {
+              role: 'assistant',
+              content: output[0].generated_text
+            }
+          })
+        }
+
+        self.postMessage({ status: 'ready' })
+      } catch (error) {
+        if (error.name === 'AbortError') {
+          self.postMessage({ status: 'ready' })
+        } else {
+          throw error
+        }
+      } finally {
+        MyTextGenerationPipeline.currentGeneration = null
+      }
+    }
+  } catch (error) {
+    self.postMessage({
+      status: 'error',
+      output: error.message || 'An error occurred during text generation'
+    })
+  }
+})
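For orientation, here is a minimal sketch of how the main thread drives this worker; it is illustrative only, not part of the commit, and the model id and dtype values are hypothetical. The worker accepts `load`, `generate`, and `stop` messages and reports back with `loading`, `ready`, `output`, and `error` statuses.

// Hypothetical main-thread usage of public/workers/text-generation.js
const worker = new Worker('/workers/text-generation.js', { type: 'module' })

worker.addEventListener('message', (e: MessageEvent) => {
  const { status, output } = e.data
  if (status === 'loading') console.log('download progress:', output)
  else if (status === 'output') console.log('assistant:', output.content)
  else if (status === 'error') console.error('worker error:', output)
})

// Warm the model first, then generate against the cached pipeline.
worker.postMessage({ type: 'load', model: 'HuggingFaceTB/SmolLM2-135M-Instruct', dtype: 'q4' })
worker.postMessage({
  type: 'generate',
  model: 'HuggingFaceTB/SmolLM2-135M-Instruct',
  dtype: 'q4',
  hasChatTemplate: true,
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' }
  ],
  max_new_tokens: 64
})
// worker.postMessage({ type: 'stop' }) // aborts an in-flight generation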
src/App.tsx
CHANGED
@@ -8,6 +8,7 @@ import { getModelsByPipeline } from './lib/huggingface'
 import ModelSelector from './components/ModelSelector'
 import ModelInfo from './components/ModelInfo'
 import ModelReadme from './components/ModelReadme'
+import TextGeneration from './components/TextGeneration'

 function App() {
   const { pipeline, setPipeline, setModels, setModelInfo, modelInfo, setIsFetching} = useModel()
@@ -76,6 +77,7 @@ function App() {
           <ZeroShotClassification />
         )}
         {pipeline === 'text-classification' && <TextClassification />}
+        {pipeline === 'text-generation' && <TextGeneration />}
       </div>
     </main>
   </div>
src/components/ModelLoader.tsx
CHANGED
@@ -23,7 +23,7 @@ const ModelLoader = () => {

   useEffect(() => {
     setHasBeenLoaded(false)
-  }, [selectedQuantization])
+  }, [selectedQuantization, setHasBeenLoaded])

   useEffect(() => {
     if (!modelInfo) return
src/components/ModelSelector.tsx
CHANGED
@@ -107,11 +107,9 @@ function ModelSelector() {
       incompatibilityReason: modelInfoResponse.incompatibilityReason,
       supportedQuantizations: modelInfoResponse.supportedQuantizations,
       baseId: modelInfoResponse.baseId,
-      readme: modelInfoResponse.readme
+      readme: modelInfoResponse.readme,
+      hasChatTemplate: Boolean(modelInfoResponse.config?.tokenizer_config?.chat_template)
     }
-
-    console.log('Fetched model info:', modelInfoResponse)
-
     setModelInfo(modelInfo)
     setIsCustomModel(isCustom)
     setIsFetching(false)
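For reference, the new `hasChatTemplate` flag keys off the `tokenizer_config.chat_template` field of the fetched model info (the `ModelInfoResponse` type in src/types.ts below declares this shape). An abbreviated sketch with hypothetical values:

// Abbreviated, hypothetical model-info payload: only the fields the check reads
const modelInfoResponse = {
  id: 'HuggingFaceTB/SmolLM2-135M-Instruct', // hypothetical example id
  config: {
    tokenizer_config: {
      chat_template: '{% for message in messages %}...{% endfor %}' // truncated Jinja template
    }
  }
}
// Base (non-chat) models typically omit chat_template, so the flag comes out false for them.
const hasChatTemplate = Boolean(modelInfoResponse.config?.tokenizer_config?.chat_template) // true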
src/components/TextGeneration.tsx
ADDED
@@ -0,0 +1,494 @@
+import { useState, useRef, useEffect, useCallback } from 'react'
+import { Send, Settings, Trash2, Loader2, X } from 'lucide-react'
+import { Dialog, Transition, Switch } from '@headlessui/react'
+import { Fragment } from 'react'
+import {
+  ChatMessage,
+  TextGenerationWorkerInput,
+  WorkerMessage
+} from '../types'
+import { useModel } from '../contexts/ModelContext'
+
+function TextGeneration() {
+  const [messages, setMessages] = useState<ChatMessage[]>([
+    { role: 'system', content: 'You are a helpful assistant.' }
+  ])
+  const [currentMessage, setCurrentMessage] = useState<string>('')
+  const [showSettings, setShowSettings] = useState<boolean>(false)
+
+  // For simple text generation
+  const [prompt, setPrompt] = useState<string>('')
+  const [generatedText, setGeneratedText] = useState<string>('')
+
+  // Generation parameters
+  const [temperature, setTemperature] = useState<number>(0.7)
+  const [maxTokens, setMaxTokens] = useState<number>(100)
+  const [topP, setTopP] = useState<number>(0.9)
+  const [topK, setTopK] = useState<number>(50)
+  const [doSample, setDoSample] = useState<boolean>(true)
+
+  // Generation state
+  const [isGenerating, setIsGenerating] = useState<boolean>(false)
+
+  const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
+  const messagesEndRef = useRef<HTMLDivElement>(null)
+
+  const scrollToBottom = () => {
+    messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
+  }
+
+  useEffect(() => {
+    scrollToBottom()
+  }, [messages, generatedText])
+
+  const stopGeneration = useCallback(() => {
+    if (activeWorker && isGenerating) {
+      activeWorker.postMessage({ type: 'stop' })
+      setIsGenerating(false)
+    }
+  }, [activeWorker, isGenerating])
+
+  const handleSendMessage = useCallback(() => {
+    if (!currentMessage.trim() || !modelInfo || !activeWorker || isGenerating) {
+      return
+    }
+
+    const userMessage: ChatMessage = {
+      role: 'user',
+      content: currentMessage.trim()
+    }
+
+    const updatedMessages = [...messages, userMessage]
+    setMessages(updatedMessages)
+    setCurrentMessage('')
+    setIsGenerating(true)
+
+    const message: TextGenerationWorkerInput = {
+      type: 'generate',
+      messages: updatedMessages,
+      hasChatTemplate: modelInfo.hasChatTemplate,
+      model: modelInfo.id,
+      temperature,
+      max_new_tokens: maxTokens,
+      top_p: topP,
+      top_k: topK,
+      do_sample: doSample,
+    }
+
+    activeWorker.postMessage(message)
+  }, [currentMessage, messages, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
+
+  const handleGenerateText = useCallback(() => {
+    if (!prompt.trim() || !modelInfo || !activeWorker || isGenerating) {
+      return
+    }
+
+    setIsGenerating(true)
+
+    const message: TextGenerationWorkerInput = {
+      type: 'generate',
+      prompt: prompt.trim(),
+      hasChatTemplate: modelInfo.hasChatTemplate,
+      model: modelInfo.id,
+      temperature,
+      max_new_tokens: maxTokens,
+      top_p: topP,
+      top_k: topK,
+      do_sample: doSample
+    }
+
+    activeWorker.postMessage(message)
+  }, [prompt, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
+
+  useEffect(() => {
+    if (!activeWorker) return
+
+    const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
+      const { status, output } = e.data
+
+      if (status === 'output' && output) {
+        setIsGenerating(false)
+        if (modelInfo?.hasChatTemplate) {
+          // Chat mode
+          const assistantMessage: ChatMessage = {
+            role: 'assistant',
+            content: output.content
+          }
+          setMessages(prev => [...prev, assistantMessage])
+        } else {
+          // Simple text generation mode
+          setGeneratedText(output.content)
+        }
+      } else if (status === 'ready') {
+        setIsGenerating(false)
+      } else if (status === 'error') {
+        setIsGenerating(false)
+      }
+    }
+
+    activeWorker.addEventListener('message', onMessageReceived)
+
+    return () => {
+      activeWorker.removeEventListener('message', onMessageReceived)
+    }
+  }, [activeWorker, modelInfo?.hasChatTemplate])
+
+  const handleKeyPress = (e: React.KeyboardEvent) => {
+    if (e.key === 'Enter' && !e.shiftKey) {
+      e.preventDefault()
+      if (modelInfo?.hasChatTemplate) {
+        handleSendMessage()
+      } else {
+        handleGenerateText()
+      }
+    }
+  }
+
+  const clearChat = () => {
+    if (modelInfo?.hasChatTemplate) {
+      setMessages([{ role: 'system', content: 'You are a helpful assistant.' }])
+    } else {
+      setPrompt('')
+      setGeneratedText('')
+    }
+  }
+
+  const updateSystemMessage = (content: string) => {
+    setMessages(prev => [
+      { role: 'system', content },
+      ...prev.filter(msg => msg.role !== 'system')
+    ])
+  }
+
+  const busy = status !== 'ready' || isGenerating
+  const hasChatTemplate = modelInfo?.hasChatTemplate
+
+  return (
+    <div className="flex flex-col h-[70vh] max-h-[100vh] w-full p-4">
+      <div className="flex items-center justify-between mb-4">
+        <h1 className="text-2xl font-bold">
+          {hasChatTemplate ? 'Chat with AI' : 'Text Generation'}
+        </h1>
+        <div className="flex gap-2">
+          <button
+            onClick={() => setShowSettings(true)}
+            className="p-2 bg-gray-100 hover:bg-gray-200 rounded-lg transition-colors"
+            title="Settings"
+          >
+            <Settings className="w-4 h-4" />
+          </button>
+          <button
+            onClick={clearChat}
+            className="p-2 bg-red-100 hover:bg-red-200 rounded-lg transition-colors"
+            title={hasChatTemplate ? "Clear Chat" : "Clear Text"}
+          >
+            <Trash2 className="w-4 h-4" />
+          </button>
+          {isGenerating && (
+            <button
+              onClick={stopGeneration}
+              className="p-2 bg-orange-100 hover:bg-orange-200 rounded-lg transition-colors"
+              title="Stop Generation"
+            >
+              <X className="w-4 h-4" />
+            </button>
+          )}
+        </div>
+      </div>
+
+      {/* Settings Dialog using Headless UI */}
+      <Transition appear show={showSettings} as={Fragment}>
+        <Dialog as="div" className="relative z-10" onClose={() => setShowSettings(false)}>
+          <Transition.Child
+            as={Fragment}
+            enter="ease-out duration-300"
+            enterFrom="opacity-0"
+            enterTo="opacity-100"
+            leave="ease-in duration-200"
+            leaveFrom="opacity-100"
+            leaveTo="opacity-0"
+          >
+            <div className="fixed inset-0 bg-black bg-opacity-25" />
+          </Transition.Child>
+
+          <div className="fixed inset-0 overflow-y-auto">
+            <div className="flex min-h-full items-center justify-center p-4 text-center">
+              <Transition.Child
+                as={Fragment}
+                enter="ease-out duration-300"
+                enterFrom="opacity-0 scale-95"
+                enterTo="opacity-100 scale-100"
+                leave="ease-in duration-200"
+                leaveFrom="opacity-100 scale-100"
+                leaveTo="opacity-0 scale-95"
+              >
+                <Dialog.Panel className="w-full max-w-2xl transform overflow-hidden rounded-2xl bg-white p-6 text-left align-middle shadow-xl transition-all">
+                  <Dialog.Title
+                    as="h3"
+                    className="text-lg font-medium leading-6 text-gray-900 mb-4"
+                  >
+                    Generation Settings
+                  </Dialog.Title>
+
+                  <div className="space-y-6">
+                    {/* Generation Parameters */}
+                    <div>
+                      <h4 className="font-semibold text-gray-800 mb-3">Parameters</h4>
+                      <div className="grid grid-cols-2 md:grid-cols-3 gap-4">
+                        <div>
+                          <label className="block text-sm font-medium text-gray-700 mb-1">
+                            Temperature: {temperature}
+                          </label>
+                          <input
+                            type="range"
+                            min="0.1"
+                            max="2.0"
+                            step="0.1"
+                            value={temperature}
+                            onChange={(e) => setTemperature(parseFloat(e.target.value))}
+                            className="w-full"
+                          />
+                        </div>
+
+                        <div>
+                          <label className="block text-sm font-medium text-gray-700 mb-1">
+                            Max Tokens: {maxTokens}
+                          </label>
+                          <input
+                            type="range"
+                            min="10"
+                            max="500"
+                            step="10"
+                            value={maxTokens}
+                            onChange={(e) => setMaxTokens(parseInt(e.target.value))}
+                            className="w-full"
+                          />
+                        </div>
+
+                        <div>
+                          <label className="block text-sm font-medium text-gray-700 mb-1">
+                            Top P: {topP}
+                          </label>
+                          <input
+                            type="range"
+                            min="0.1"
+                            max="1.0"
+                            step="0.1"
+                            value={topP}
+                            onChange={(e) => setTopP(parseFloat(e.target.value))}
+                            className="w-full"
+                          />
+                        </div>
+
+                        <div>
+                          <label className="block text-sm font-medium text-gray-700 mb-1">
+                            Top K: {topK}
+                          </label>
+                          <input
+                            type="range"
+                            min="1"
+                            max="100"
+                            step="1"
+                            value={topK}
+                            onChange={(e) => setTopK(parseInt(e.target.value))}
+                            className="w-full"
+                          />
+                        </div>
+
+                        <div className="flex items-center">
+                          <Switch
+                            checked={doSample}
+                            onChange={setDoSample}
+                            className={`${
+                              doSample ? 'bg-blue-600' : 'bg-gray-200'
+                            } relative inline-flex h-6 w-11 items-center rounded-full`}
+                          >
+                            <span className="sr-only">Enable sampling</span>
+                            <span
+                              className={`${
+                                doSample ? 'translate-x-6' : 'translate-x-1'
+                              } inline-block h-4 w-4 transform rounded-full bg-white transition`}
+                            />
+                          </Switch>
+                          <label className="ml-2 text-sm font-medium text-gray-700">
+                            Do Sample
+                          </label>
+                        </div>
+                      </div>
+                    </div>
+
+                    {/* System Message for Chat */}
+                    {hasChatTemplate && (
+                      <div>
+                        <h4 className="font-semibold text-gray-800 mb-3">System Message</h4>
+                        <textarea
+                          value={messages.find(m => m.role === 'system')?.content || ''}
+                          onChange={(e) => updateSystemMessage(e.target.value)}
+                          className="w-full p-2 border border-gray-300 rounded-md text-sm"
+                          rows={3}
+                          placeholder="Enter system message..."
+                        />
+                      </div>
+                    )}
+                  </div>
+
+                  <div className="mt-6 flex justify-end">
+                    <button
+                      type="button"
+                      className="inline-flex justify-center rounded-md border border-transparent bg-blue-100 px-4 py-2 text-sm font-medium text-blue-900 hover:bg-blue-200 focus:outline-none focus-visible:ring-2 focus-visible:ring-blue-500 focus-visible:ring-offset-2"
+                      onClick={() => setShowSettings(false)}
+                    >
+                      Close
+                    </button>
+                  </div>
+                </Dialog.Panel>
+              </Transition.Child>
+            </div>
+          </div>
+        </Dialog>
+      </Transition>
+
+      {hasChatTemplate ? (
+        // Chat Layout
+        <>
+          {/* Chat Messages */}
+          <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg p-4 mb-4 bg-white">
+            <div className="space-y-4">
+              {messages.filter(msg => msg.role !== 'system').map((message, index) => (
+                <div
+                  key={index}
+                  className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'}`}
+                >
+                  <div
+                    className={`max-w-[80%] p-3 rounded-lg ${
+                      message.role === 'user'
+                        ? 'bg-blue-500 text-white'
+                        : 'bg-gray-100 text-gray-800'
+                    }`}
+                  >
+                    <div className="text-xs font-medium mb-1 opacity-70">
+                      {message.role === 'user' ? 'You' : 'Assistant'}
+                    </div>
+                    <div className="whitespace-pre-wrap">{message.content}</div>
+                  </div>
+                </div>
+              ))}
+              {isGenerating && (
+                <div className="flex justify-start">
+                  <div className="bg-gray-100 text-gray-800 p-3 rounded-lg">
+                    <div className="text-xs font-medium mb-1 opacity-70">Assistant</div>
+                    <div className="flex items-center space-x-2">
+                      <Loader2 className="w-4 h-4 animate-spin" />
+                      <div>Loading...</div>
+                    </div>
+                  </div>
+                </div>
+              )}
+            </div>
+            <div ref={messagesEndRef} />
+          </div>
+
+          {/* Chat Input Area */}
+          <div className="flex gap-2">
+            <textarea
+              value={currentMessage}
+              onChange={(e) => setCurrentMessage(e.target.value)}
+              onKeyPress={handleKeyPress}
+              placeholder="Type your message... (Press Enter to send, Shift+Enter for new line)"
+              className="flex-1 p-3 border border-gray-300 rounded-lg resize-none focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
+              rows={2}
+              disabled={!hasBeenLoaded || isGenerating}
+            />
+            <button
+              onClick={handleSendMessage}
+              disabled={!currentMessage.trim() || busy || !hasBeenLoaded}
+              className="px-4 py-2 bg-blue-500 hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center justify-center"
+            >
+              {isGenerating ? (
+                <Loader2 className="w-4 h-4 animate-spin" />
+              ) : (
+                <Send className="w-4 h-4" />
+              )}
+            </button>
+          </div>
+        </>
+      ) : (
+        // Simple Text Generation Layout
+        <>
+          {/* Prompt Input */}
+          <div className="mb-4">
+            <label className="block text-sm font-medium text-gray-700 mb-2">
+              Enter your prompt:
+            </label>
+            <textarea
+              value={prompt}
+              onChange={(e) => setPrompt(e.target.value)}
+              onKeyPress={handleKeyPress}
+              placeholder="Enter your text prompt here... (Press Enter to generate, Shift+Enter for new line)"
+              className="w-full p-3 border border-gray-300 rounded-lg resize-none focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
+              rows={4}
+              disabled={!hasBeenLoaded || isGenerating}
+            />
+          </div>
+
+          {/* Generate Button */}
+          <div className="mb-4">
+            <button
+              onClick={handleGenerateText}
+              disabled={!prompt.trim() || busy || !hasBeenLoaded}
+              className="px-6 py-2 bg-green-500 hover:bg-green-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center gap-2"
+            >
+              {isGenerating ? (
+                <>
+                  <Loader2 className="w-4 h-4 animate-spin" />
+                  Generating...
+                </>
+              ) : (
+                <>
+                  <Send className="w-4 h-4" />
+                  Generate Text
+                </>
+              )}
+            </button>
+          </div>
+
+          {/* Generated Text Output */}
+          <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg p-4 bg-white">
+            <div className="mb-2">
+              <label className="block text-sm font-medium text-gray-700">
+                Generated Text:
+              </label>
+            </div>
+            {generatedText ? (
+              <div className="whitespace-pre-wrap text-gray-800 bg-gray-50 p-3 rounded border">
+                {generatedText}
+              </div>
+            ) : (
+              <div className="text-gray-500 italic flex items-center gap-2">
+                {isGenerating ? (
+                  <>
+                    <Loader2 className="w-4 h-4 animate-spin" />
+                    Generating text...
+                  </>
+                ) : (
+                  'Generated text will appear here'
+                )}
+              </div>
+            )}
+            <div ref={messagesEndRef} />
+          </div>
+        </>
+      )}
+
+      {!hasBeenLoaded && (
+        <div className="text-center text-gray-500 text-sm mt-2">
+          Please load a model first to start {hasChatTemplate ? 'chatting' : 'generating text'}
+        </div>
+      )}
+    </div>
+  )
+}
+
+export default TextGeneration
src/lib/workerManager.ts
CHANGED
@@ -1,33 +1,31 @@
 const workers: Record<string, Worker | null> = {}

 export const getWorker = (pipeline: string) => {
   if (!workers[pipeline]) {
     let workerUrl: string

-    // Construct the public URL for the worker script.
-    // process.env.PUBLIC_URL ensures this works correctly even if the
-    // app is hosted in a sub-directory.
     switch (pipeline) {
       case 'text-classification':
         workerUrl = `${process.env.PUBLIC_URL}/workers/text-classification.js`
         break
       case 'zero-shot-classification':
         workerUrl = `${process.env.PUBLIC_URL}/workers/zero-shot-classification.js`
         break
-
+      case 'text-generation':
+        workerUrl = `${process.env.PUBLIC_URL}/workers/text-generation.js`
+        break
       default:
-
-        return null;
+        return null
     }
     workers[pipeline] = new Worker(workerUrl, { type: 'module' })
   }
   return workers[pipeline]
 }

 export const terminateWorker = (pipeline: string) => {
   const worker = workers[pipeline]
   if (worker) {
     worker.terminate()
     delete workers[pipeline]
   }
 }
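A brief, illustrative usage of the manager (not part of the commit): workers are created lazily per pipeline, cached on repeat calls, and torn down explicitly. The model id is a hypothetical example.

const worker = getWorker('text-generation')    // creates and caches the worker
worker?.postMessage({ type: 'load', model: 'HuggingFaceTB/SmolLM2-135M-Instruct', dtype: 'fp32' })
getWorker('text-generation')                   // returns the same cached Worker instance
getWorker('image-segmentation')                // unsupported pipeline, returns null
terminateWorker('text-generation')             // terminates and evicts it from the cache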
src/types.ts
CHANGED
@@ -9,6 +9,16 @@ export interface ClassificationOutput {
   scores: number[]
 }

+export interface ChatMessage {
+  role: 'system' | 'user' | 'assistant'
+  content: string
+}
+
+export interface GenerationOutput {
+  role: 'assistant'
+  content: string
+}
+
 export type WorkerStatus =
   | 'initiate'
   | 'ready'
@@ -36,6 +46,18 @@ export interface TextClassificationWorkerInput {
   model: string
 }

+export interface TextGenerationWorkerInput {
+  type: 'generate'
+  prompt?: string
+  messages?: ChatMessage[]
+  hasChatTemplate: boolean
+  model: string
+  temperature?: number
+  max_new_tokens?: number
+  top_p?: number
+  top_k?: number
+  do_sample?: boolean
+}

 type q8 = 'q8' | 'int8' | 'bnb8' | 'uint8'
 type q4 = 'q4' | 'bnb4' | 'q4f16'
@@ -44,7 +66,6 @@ type fp32 = 'fp32'

 export type QuantizationType = q8 | q4 | fp16 | fp32

-
 export interface ModelInfo {
   id: string
   name: string
@@ -58,15 +79,18 @@ export interface ModelInfo {
   supportedQuantizations: QuantizationType[]
   baseId?: string
   readme?: string
+  hasChatTemplate: boolean
 }

-
 export interface ModelInfoResponse {
   id: string
   createdAt: string
   config?: {
     architectures: string[]
     model_type: string
+    tokenizer_config?: {
+      chat_template?: string
+    }
   }
   lastModified: string
   pipeline_tag: string