Vokturz committed
Commit 2f35054 · 1 Parent(s): d0679b9

feat: Add text generation functionality

public/workers/text-generation.js ADDED
@@ -0,0 +1,139 @@
+/* eslint-disable no-restricted-globals */
+import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3'
+
+class MyTextGenerationPipeline {
+  static task = 'text-generation'
+  static instance = null
+  static currentGeneration = null
+
+  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
+    this.instance = pipeline(this.task, model, {
+      dtype,
+      device: 'webgpu',
+      progress_callback
+    })
+    return this.instance
+  }
+
+  static stopGeneration() {
+    if (this.currentGeneration) {
+      this.currentGeneration.abort()
+      this.currentGeneration = null
+    }
+  }
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (event) => {
+  try {
+    const {
+      type,
+      model,
+      dtype,
+      messages,
+      prompt,
+      hasChatTemplate,
+      temperature,
+      max_new_tokens,
+      top_p,
+      top_k,
+      do_sample,
+      stop_words
+    } = event.data
+
+    if (type === 'stop') {
+      MyTextGenerationPipeline.stopGeneration()
+      self.postMessage({ status: 'ready' })
+      return
+    }
+
+    if (!model) {
+      self.postMessage({
+        status: 'error',
+        output: 'No model provided'
+      })
+      return
+    }
+
+    // Retrieve the pipeline. This will download the model if not already cached.
+    const generator = await MyTextGenerationPipeline.getInstance(
+      model,
+      dtype,
+      (x) => {
+        self.postMessage({ status: 'loading', output: x })
+      }
+    )
+
+    if (type === 'load') {
+      self.postMessage({
+        status: 'ready',
+        output: `Model ${model}, dtype ${dtype} loaded`
+      })
+      return
+    }
+
+    if (type === 'generate') {
+      let inputText = ''
+
+      if (hasChatTemplate && messages && messages.length > 0) {
+        inputText = messages
+      } else if (!hasChatTemplate && prompt) {
+        inputText = prompt
+      } else {
+        self.postMessage({ status: 'ready' })
+        return
+      }
+
+      const options = {
+        max_new_tokens: max_new_tokens || 100,
+        temperature: temperature || 0.7,
+        do_sample: do_sample !== false,
+        ...(top_p && { top_p }),
+        ...(top_k && { top_k }),
+        ...(stop_words && stop_words.length > 0 && { stop_words })
+      }
+
+      // Create an AbortController for this generation
+      const abortController = new AbortController()
+      MyTextGenerationPipeline.currentGeneration = abortController
+
+      try {
+        const output = await generator(inputText, {
+          ...options,
+          signal: abortController.signal
+        })
+
+        if (hasChatTemplate) {
+          // For chat mode, extract only the assistant's response
+          self.postMessage({
+            status: 'output',
+            output: output[0].generated_text.slice(-1)[0]
+          })
+        } else {
+          self.postMessage({
+            status: 'output',
+            output: {
+              role: 'assistant',
+              content: output[0].generated_text
+            }
+          })
+        }
+
+        self.postMessage({ status: 'ready' })
+      } catch (error) {
+        if (error.name === 'AbortError') {
+          self.postMessage({ status: 'ready' })
+        } else {
+          throw error
+        }
+      } finally {
+        MyTextGenerationPipeline.currentGeneration = null
+      }
+    }
+  } catch (error) {
+    self.postMessage({
+      status: 'error',
+      output: error.message || 'An error occurred during text generation'
+    })
+  }
+})
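
For orientation, a minimal sketch of how a main thread could drive this worker. The message shapes mirror the handler above; the worker path and model id are illustrative assumptions (in this app the worker actually comes from getWorker in src/lib/workerManager.ts):

// Hypothetical caller, not part of this commit
const worker = new Worker('/workers/text-generation.js', { type: 'module' })

worker.addEventListener('message', (e) => {
  const { status, output } = e.data
  if (status === 'loading') console.log('download progress:', output)
  if (status === 'output') console.log('assistant said:', output) // { role: 'assistant', content: ... }
  if (status === 'error') console.error(output)
})

// Warm up the pipeline first, then generate
worker.postMessage({ type: 'load', model: 'HuggingFaceTB/SmolLM2-135M-Instruct', dtype: 'q4' })
worker.postMessage({
  type: 'generate',
  model: 'HuggingFaceTB/SmolLM2-135M-Instruct',
  hasChatTemplate: true,
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' }
  ],
  temperature: 0.7,
  max_new_tokens: 100
})

Note that a 'ready' message doubles as the idle signal after a load, a finished generation, or a stop.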
src/App.tsx CHANGED
@@ -8,6 +8,7 @@ import { getModelsByPipeline } from './lib/huggingface'
 import ModelSelector from './components/ModelSelector'
 import ModelInfo from './components/ModelInfo'
 import ModelReadme from './components/ModelReadme'
+import TextGeneration from './components/TextGeneration'
 
 function App() {
   const { pipeline, setPipeline, setModels, setModelInfo, modelInfo, setIsFetching} = useModel()
@@ -76,6 +77,7 @@ function App() {
           <ZeroShotClassification />
         )}
         {pipeline === 'text-classification' && <TextClassification />}
+        {pipeline === 'text-generation' && <TextGeneration />}
       </div>
     </main>
   </div>
src/components/ModelLoader.tsx CHANGED
@@ -23,7 +23,7 @@ const ModelLoader = () => {
 
   useEffect(() => {
     setHasBeenLoaded(false)
-  }, [selectedQuantization])
+  }, [selectedQuantization, setHasBeenLoaded])
 
   useEffect(() => {
     if (!modelInfo) return
src/components/ModelSelector.tsx CHANGED
@@ -107,11 +107,9 @@ function ModelSelector() {
         incompatibilityReason: modelInfoResponse.incompatibilityReason,
         supportedQuantizations: modelInfoResponse.supportedQuantizations,
         baseId: modelInfoResponse.baseId,
-        readme: modelInfoResponse.readme
+        readme: modelInfoResponse.readme,
+        hasChatTemplate: Boolean(modelInfoResponse.config?.tokenizer_config?.chat_template)
       }
-
-      console.log('Fetched model info:', modelInfoResponse)
-
       setModelInfo(modelInfo)
       setIsCustomModel(isCustom)
       setIsFetching(false)
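
The new hasChatTemplate flag keys off whether the fetched config carries a tokenizer chat template. A rough illustration of the shape being tested (field values are made up, only the nesting matters):

// Chat-tuned models ship a Jinja template in tokenizer_config; its presence flips the flag
const exampleResponse = {
  config: {
    architectures: ['LlamaForCausalLM'],
    model_type: 'llama',
    tokenizer_config: {
      chat_template: '{% for message in messages %}...{% endfor %}'
    }
  }
}
Boolean(exampleResponse.config?.tokenizer_config?.chat_template) // true -> chat UI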
src/components/TextGeneration.tsx ADDED
@@ -0,0 +1,494 @@
+import { useState, useRef, useEffect, useCallback } from 'react'
+import { Send, Settings, Trash2, Loader2, X } from 'lucide-react'
+import { Dialog, Transition, Switch } from '@headlessui/react'
+import { Fragment } from 'react'
+import {
+  ChatMessage,
+  TextGenerationWorkerInput,
+  WorkerMessage
+} from '../types'
+import { useModel } from '../contexts/ModelContext'
+
+function TextGeneration() {
+  const [messages, setMessages] = useState<ChatMessage[]>([
+    { role: 'system', content: 'You are a helpful assistant.' }
+  ])
+  const [currentMessage, setCurrentMessage] = useState<string>('')
+  const [showSettings, setShowSettings] = useState<boolean>(false)
+
+  // For simple text generation
+  const [prompt, setPrompt] = useState<string>('')
+  const [generatedText, setGeneratedText] = useState<string>('')
+
+  // Generation parameters
+  const [temperature, setTemperature] = useState<number>(0.7)
+  const [maxTokens, setMaxTokens] = useState<number>(100)
+  const [topP, setTopP] = useState<number>(0.9)
+  const [topK, setTopK] = useState<number>(50)
+  const [doSample, setDoSample] = useState<boolean>(true)
+
+  // Generation state
+  const [isGenerating, setIsGenerating] = useState<boolean>(false)
+
+  const { activeWorker, status, modelInfo, hasBeenLoaded } = useModel()
+  const messagesEndRef = useRef<HTMLDivElement>(null)
+
+  const scrollToBottom = () => {
+    messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
+  }
+
+  useEffect(() => {
+    scrollToBottom()
+  }, [messages, generatedText])
+
+  const stopGeneration = useCallback(() => {
+    if (activeWorker && isGenerating) {
+      activeWorker.postMessage({ type: 'stop' })
+      setIsGenerating(false)
+    }
+  }, [activeWorker, isGenerating])
+
+  const handleSendMessage = useCallback(() => {
+    if (!currentMessage.trim() || !modelInfo || !activeWorker || isGenerating) {
+      return
+    }
+
+    const userMessage: ChatMessage = {
+      role: 'user',
+      content: currentMessage.trim()
+    }
+
+    const updatedMessages = [...messages, userMessage]
+    setMessages(updatedMessages)
+    setCurrentMessage('')
+    setIsGenerating(true)
+
+    const message: TextGenerationWorkerInput = {
+      type: 'generate',
+      messages: updatedMessages,
+      hasChatTemplate: modelInfo.hasChatTemplate,
+      model: modelInfo.id,
+      temperature,
+      max_new_tokens: maxTokens,
+      top_p: topP,
+      top_k: topK,
+      do_sample: doSample,
+    }
+
+    activeWorker.postMessage(message)
+  }, [currentMessage, messages, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
+
+  const handleGenerateText = useCallback(() => {
+    if (!prompt.trim() || !modelInfo || !activeWorker || isGenerating) {
+      return
+    }
+
+    setIsGenerating(true)
+
+    const message: TextGenerationWorkerInput = {
+      type: 'generate',
+      prompt: prompt.trim(),
+      hasChatTemplate: modelInfo.hasChatTemplate,
+      model: modelInfo.id,
+      temperature,
+      max_new_tokens: maxTokens,
+      top_p: topP,
+      top_k: topK,
+      do_sample: doSample
+    }
+
+    activeWorker.postMessage(message)
+  }, [prompt, modelInfo, activeWorker, temperature, maxTokens, topP, topK, doSample, isGenerating])
+
+  useEffect(() => {
+    if (!activeWorker) return
+
+    const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
+      const { status, output } = e.data
+
+      if (status === 'output' && output) {
+        setIsGenerating(false)
+        if (modelInfo?.hasChatTemplate) {
+          // Chat mode
+          const assistantMessage: ChatMessage = {
+            role: 'assistant',
+            content: output.content
+          }
+          setMessages(prev => [...prev, assistantMessage])
+        } else {
+          // Simple text generation mode
+          setGeneratedText(output.content)
+        }
+      } else if (status === 'ready') {
+        setIsGenerating(false)
+      } else if (status === 'error') {
+        setIsGenerating(false)
+      }
+    }
+
+    activeWorker.addEventListener('message', onMessageReceived)
+
+    return () => {
+      activeWorker.removeEventListener('message', onMessageReceived)
+    }
+  }, [activeWorker, modelInfo?.hasChatTemplate])
+
+  const handleKeyPress = (e: React.KeyboardEvent) => {
+    if (e.key === 'Enter' && !e.shiftKey) {
+      e.preventDefault()
+      if (modelInfo?.hasChatTemplate) {
+        handleSendMessage()
+      } else {
+        handleGenerateText()
+      }
+    }
+  }
+
+  const clearChat = () => {
+    if (modelInfo?.hasChatTemplate) {
+      setMessages([{ role: 'system', content: 'You are a helpful assistant.' }])
+    } else {
+      setPrompt('')
+      setGeneratedText('')
+    }
+  }
+
+  const updateSystemMessage = (content: string) => {
+    setMessages(prev => [
+      { role: 'system', content },
+      ...prev.filter(msg => msg.role !== 'system')
+    ])
+  }
+
+  const busy = status !== 'ready' || isGenerating
+  const hasChatTemplate = modelInfo?.hasChatTemplate
+
+  return (
+    <div className="flex flex-col h-[70vh] max-h-[100vh] w-full p-4">
+      <div className="flex items-center justify-between mb-4">
+        <h1 className="text-2xl font-bold">
+          {hasChatTemplate ? 'Chat with AI' : 'Text Generation'}
+        </h1>
+        <div className="flex gap-2">
+          <button
+            onClick={() => setShowSettings(true)}
+            className="p-2 bg-gray-100 hover:bg-gray-200 rounded-lg transition-colors"
+            title="Settings"
+          >
+            <Settings className="w-4 h-4" />
+          </button>
+          <button
+            onClick={clearChat}
+            className="p-2 bg-red-100 hover:bg-red-200 rounded-lg transition-colors"
+            title={hasChatTemplate ? "Clear Chat" : "Clear Text"}
+          >
+            <Trash2 className="w-4 h-4" />
+          </button>
+          {isGenerating && (
+            <button
+              onClick={stopGeneration}
+              className="p-2 bg-orange-100 hover:bg-orange-200 rounded-lg transition-colors"
+              title="Stop Generation"
+            >
+              <X className="w-4 h-4" />
+            </button>
+          )}
+        </div>
+      </div>
+
+      {/* Settings Dialog using Headless UI */}
+      <Transition appear show={showSettings} as={Fragment}>
+        <Dialog as="div" className="relative z-10" onClose={() => setShowSettings(false)}>
+          <Transition.Child
+            as={Fragment}
+            enter="ease-out duration-300"
+            enterFrom="opacity-0"
+            enterTo="opacity-100"
+            leave="ease-in duration-200"
+            leaveFrom="opacity-100"
+            leaveTo="opacity-0"
+          >
+            <div className="fixed inset-0 bg-black bg-opacity-25" />
+          </Transition.Child>
+
+          <div className="fixed inset-0 overflow-y-auto">
+            <div className="flex min-h-full items-center justify-center p-4 text-center">
+              <Transition.Child
+                as={Fragment}
+                enter="ease-out duration-300"
+                enterFrom="opacity-0 scale-95"
+                enterTo="opacity-100 scale-100"
+                leave="ease-in duration-200"
+                leaveFrom="opacity-100 scale-100"
+                leaveTo="opacity-0 scale-95"
+              >
+                <Dialog.Panel className="w-full max-w-2xl transform overflow-hidden rounded-2xl bg-white p-6 text-left align-middle shadow-xl transition-all">
+                  <Dialog.Title
+                    as="h3"
+                    className="text-lg font-medium leading-6 text-gray-900 mb-4"
+                  >
+                    Generation Settings
+                  </Dialog.Title>
+
+                  <div className="space-y-6">
+                    {/* Generation Parameters */}
+                    <div>
+                      <h4 className="font-semibold text-gray-800 mb-3">Parameters</h4>
+                      <div className="grid grid-cols-2 md:grid-cols-3 gap-4">
+                        <div>
+                          <label className="block text-sm font-medium text-gray-700 mb-1">
+                            Temperature: {temperature}
+                          </label>
+                          <input
+                            type="range"
+                            min="0.1"
+                            max="2.0"
+                            step="0.1"
+                            value={temperature}
+                            onChange={(e) => setTemperature(parseFloat(e.target.value))}
+                            className="w-full"
+                          />
+                        </div>
+
+                        <div>
+                          <label className="block text-sm font-medium text-gray-700 mb-1">
+                            Max Tokens: {maxTokens}
+                          </label>
+                          <input
+                            type="range"
+                            min="10"
+                            max="500"
+                            step="10"
+                            value={maxTokens}
+                            onChange={(e) => setMaxTokens(parseInt(e.target.value))}
+                            className="w-full"
+                          />
+                        </div>
+
+                        <div>
+                          <label className="block text-sm font-medium text-gray-700 mb-1">
+                            Top P: {topP}
+                          </label>
+                          <input
+                            type="range"
+                            min="0.1"
+                            max="1.0"
+                            step="0.1"
+                            value={topP}
+                            onChange={(e) => setTopP(parseFloat(e.target.value))}
+                            className="w-full"
+                          />
+                        </div>
+
+                        <div>
+                          <label className="block text-sm font-medium text-gray-700 mb-1">
+                            Top K: {topK}
+                          </label>
+                          <input
+                            type="range"
+                            min="1"
+                            max="100"
+                            step="1"
+                            value={topK}
+                            onChange={(e) => setTopK(parseInt(e.target.value))}
+                            className="w-full"
+                          />
+                        </div>
+
+                        <div className="flex items-center">
+                          <Switch
+                            checked={doSample}
+                            onChange={setDoSample}
+                            className={`${
+                              doSample ? 'bg-blue-600' : 'bg-gray-200'
+                            } relative inline-flex h-6 w-11 items-center rounded-full`}
+                          >
+                            <span className="sr-only">Enable sampling</span>
+                            <span
+                              className={`${
+                                doSample ? 'translate-x-6' : 'translate-x-1'
+                              } inline-block h-4 w-4 transform rounded-full bg-white transition`}
+                            />
+                          </Switch>
+                          <label className="ml-2 text-sm font-medium text-gray-700">
+                            Do Sample
+                          </label>
+                        </div>
+                      </div>
+                    </div>
+
+                    {/* System Message for Chat */}
+                    {hasChatTemplate && (
+                      <div>
+                        <h4 className="font-semibold text-gray-800 mb-3">System Message</h4>
+                        <textarea
+                          value={messages.find(m => m.role === 'system')?.content || ''}
+                          onChange={(e) => updateSystemMessage(e.target.value)}
+                          className="w-full p-2 border border-gray-300 rounded-md text-sm"
+                          rows={3}
+                          placeholder="Enter system message..."
+                        />
+                      </div>
+                    )}
+                  </div>
+
+                  <div className="mt-6 flex justify-end">
+                    <button
+                      type="button"
+                      className="inline-flex justify-center rounded-md border border-transparent bg-blue-100 px-4 py-2 text-sm font-medium text-blue-900 hover:bg-blue-200 focus:outline-none focus-visible:ring-2 focus-visible:ring-blue-500 focus-visible:ring-offset-2"
+                      onClick={() => setShowSettings(false)}
+                    >
+                      Close
+                    </button>
+                  </div>
+                </Dialog.Panel>
+              </Transition.Child>
+            </div>
+          </div>
+        </Dialog>
+      </Transition>
+
+      {hasChatTemplate ? (
+        // Chat Layout
+        <>
+          {/* Chat Messages */}
+          <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg p-4 mb-4 bg-white">
+            <div className="space-y-4">
+              {messages.filter(msg => msg.role !== 'system').map((message, index) => (
+                <div
+                  key={index}
+                  className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'}`}
+                >
+                  <div
+                    className={`max-w-[80%] p-3 rounded-lg ${
+                      message.role === 'user'
+                        ? 'bg-blue-500 text-white'
+                        : 'bg-gray-100 text-gray-800'
+                    }`}
+                  >
+                    <div className="text-xs font-medium mb-1 opacity-70">
+                      {message.role === 'user' ? 'You' : 'Assistant'}
+                    </div>
+                    <div className="whitespace-pre-wrap">{message.content}</div>
+                  </div>
+                </div>
+              ))}
+              {isGenerating && (
+                <div className="flex justify-start">
+                  <div className="bg-gray-100 text-gray-800 p-3 rounded-lg">
+                    <div className="text-xs font-medium mb-1 opacity-70">Assistant</div>
+                    <div className="flex items-center space-x-2">
+                      <Loader2 className="w-4 h-4 animate-spin" />
+                      <div>Loading...</div>
+                    </div>
+                  </div>
+                </div>
+              )}
+            </div>
+            <div ref={messagesEndRef} />
+          </div>
+
+          {/* Chat Input Area */}
+          <div className="flex gap-2">
+            <textarea
+              value={currentMessage}
+              onChange={(e) => setCurrentMessage(e.target.value)}
+              onKeyPress={handleKeyPress}
+              placeholder="Type your message... (Press Enter to send, Shift+Enter for new line)"
+              className="flex-1 p-3 border border-gray-300 rounded-lg resize-none focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
+              rows={2}
+              disabled={!hasBeenLoaded || isGenerating}
+            />
+            <button
+              onClick={handleSendMessage}
+              disabled={!currentMessage.trim() || busy || !hasBeenLoaded}
+              className="px-4 py-2 bg-blue-500 hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center justify-center"
+            >
+              {isGenerating ? (
+                <Loader2 className="w-4 h-4 animate-spin" />
+              ) : (
+                <Send className="w-4 h-4" />
+              )}
+            </button>
+          </div>
+        </>
+      ) : (
+        // Simple Text Generation Layout
+        <>
+          {/* Prompt Input */}
+          <div className="mb-4">
+            <label className="block text-sm font-medium text-gray-700 mb-2">
+              Enter your prompt:
+            </label>
+            <textarea
+              value={prompt}
+              onChange={(e) => setPrompt(e.target.value)}
+              onKeyPress={handleKeyPress}
+              placeholder="Enter your text prompt here... (Press Enter to generate, Shift+Enter for new line)"
+              className="w-full p-3 border border-gray-300 rounded-lg resize-none focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
+              rows={4}
+              disabled={!hasBeenLoaded || isGenerating}
+            />
+          </div>
+
+          {/* Generate Button */}
+          <div className="mb-4">
+            <button
+              onClick={handleGenerateText}
+              disabled={!prompt.trim() || busy || !hasBeenLoaded}
+              className="px-6 py-2 bg-green-500 hover:bg-green-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center gap-2"
+            >
+              {isGenerating ? (
+                <>
+                  <Loader2 className="w-4 h-4 animate-spin" />
+                  Generating...
+                </>
+              ) : (
+                <>
+                  <Send className="w-4 h-4" />
+                  Generate Text
+                </>
+              )}
+            </button>
+          </div>
+
+          {/* Generated Text Output */}
+          <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg p-4 bg-white">
+            <div className="mb-2">
+              <label className="block text-sm font-medium text-gray-700">
+                Generated Text:
+              </label>
+            </div>
+            {generatedText ? (
+              <div className="whitespace-pre-wrap text-gray-800 bg-gray-50 p-3 rounded border">
+                {generatedText}
+              </div>
+            ) : (
+              <div className="text-gray-500 italic flex items-center gap-2">
+                {isGenerating ? (
+                  <>
+                    <Loader2 className="w-4 h-4 animate-spin" />
+                    Generating text...
+                  </>
+                ) : (
+                  'Generated text will appear here'
+                )}
+              </div>
+            )}
+            <div ref={messagesEndRef} />
+          </div>
+        </>
+      )}
+
+      {!hasBeenLoaded && (
+        <div className="text-center text-gray-500 text-sm mt-2">
+          Please load a model first to start {hasChatTemplate ? 'chatting' : 'generating text'}
+        </div>
+      )}
+    </div>
+  )
+}
+
+export default TextGeneration
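
For reference, a sketch of why both UI modes can share one 'output' handler: the worker normalizes the pipeline result into a single { role, content } envelope. Under a chat template the pipeline's generated_text is the full message list and the worker forwards only the last entry; for a plain prompt generated_text is a string the worker wraps itself. Shapes inferred from the worker code above; illustrative values:

// Chat-templated model (illustrative pipeline result)
const chatResult = [{ generated_text: [
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'Hello!' },
  { role: 'assistant', content: 'Hi! How can I help?' }
] }]
// worker posts chatResult[0].generated_text.slice(-1)[0]

// Plain model (illustrative pipeline result)
const plainResult = [{ generated_text: 'Once upon a time, ...' }]
// worker posts { role: 'assistant', content: plainResult[0].generated_text }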
src/lib/workerManager.ts CHANGED
@@ -1,33 +1,31 @@
-const workers: Record<string, Worker | null> = {};
+const workers: Record<string, Worker | null> = {}
 
 export const getWorker = (pipeline: string) => {
   if (!workers[pipeline]) {
-    let workerUrl: string;
+    let workerUrl: string
 
-    // Construct the public URL for the worker script.
-    // process.env.PUBLIC_URL ensures this works correctly even if the
-    // app is hosted in a sub-directory.
     switch (pipeline) {
       case 'text-classification':
-        workerUrl = `${process.env.PUBLIC_URL}/workers/text-classification.js`;
-        break;
+        workerUrl = `${process.env.PUBLIC_URL}/workers/text-classification.js`
+        break
       case 'zero-shot-classification':
-        workerUrl = `${process.env.PUBLIC_URL}/workers/zero-shot-classification.js`;
-        break;
-      // Add other pipeline types here
+        workerUrl = `${process.env.PUBLIC_URL}/workers/zero-shot-classification.js`
+        break
+      case 'text-generation':
+        workerUrl = `${process.env.PUBLIC_URL}/workers/text-generation.js`
+        break
       default:
-        // Return null or throw an error if the pipeline is unknown
-        return null;
+        return null
     }
-    workers[pipeline] = new Worker(workerUrl, { type: 'module' });
+    workers[pipeline] = new Worker(workerUrl, { type: 'module' })
   }
-  return workers[pipeline];
-};
+  return workers[pipeline]
+}
 
 export const terminateWorker = (pipeline: string) => {
-  const worker = workers[pipeline];
+  const worker = workers[pipeline]
   if (worker) {
-    worker.terminate();
-    delete workers[pipeline];
+    worker.terminate()
+    delete workers[pipeline]
   }
-};
+}
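
A quick sketch of the intended call pattern (pipeline names must match the switch cases above, anything else yields null; the model id is illustrative):

const worker = getWorker('text-generation') // lazily created, then cached and reused
if (worker) {
  worker.postMessage({ type: 'load', model: 'HuggingFaceTB/SmolLM2-135M-Instruct', dtype: 'fp32' })
}

// e.g. when the user switches pipelines:
terminateWorker('text-generation') // terminates and evicts the cached worker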
src/types.ts CHANGED
@@ -9,6 +9,16 @@ export interface ClassificationOutput {
   scores: number[]
 }
 
+export interface ChatMessage {
+  role: 'system' | 'user' | 'assistant'
+  content: string
+}
+
+export interface GenerationOutput {
+  role: 'assistant'
+  content: string
+}
+
 export type WorkerStatus =
   | 'initiate'
   | 'ready'
@@ -36,6 +46,18 @@ export interface TextClassificationWorkerInput {
   model: string
 }
 
+export interface TextGenerationWorkerInput {
+  type: 'generate'
+  prompt?: string
+  messages?: ChatMessage[]
+  hasChatTemplate: boolean
+  model: string
+  temperature?: number
+  max_new_tokens?: number
+  top_p?: number
+  top_k?: number
+  do_sample?: boolean
+}
 
 type q8 = 'q8' | 'int8' | 'bnb8' | 'uint8'
 type q4 = 'q4' | 'bnb4' | 'q4f16'
@@ -44,7 +66,6 @@ type fp32 = 'fp32'
 
 export type QuantizationType = q8 | q4 | fp16 | fp32
 
-
 export interface ModelInfo {
   id: string
   name: string
@@ -58,15 +79,18 @@ export interface ModelInfo {
   supportedQuantizations: QuantizationType[]
   baseId?: string
   readme?: string
+  hasChatTemplate: boolean
 }
 
-
 export interface ModelInfoResponse {
   id: string
   createdAt: string
   config?: {
     architectures: string[]
     model_type: string
+    tokenizer_config?: {
+      chat_template?: string
+    }
   }
   lastModified: string
   pipeline_tag: string
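
To close, two example payloads that typecheck against the new TextGenerationWorkerInput, one per mode (model ids are illustrative):

// Chat mode: send the running message list; the model's chat template formats it
const chatInput: TextGenerationWorkerInput = {
  type: 'generate',
  messages: [{ role: 'user', content: 'Hello!' }],
  hasChatTemplate: true,
  model: 'HuggingFaceTB/SmolLM2-135M-Instruct',
  temperature: 0.7,
  max_new_tokens: 100
}

// Plain mode: send a raw prompt string for completion-style generation
const promptInput: TextGenerationWorkerInput = {
  type: 'generate',
  prompt: 'Once upon a time',
  hasChatTemplate: false,
  model: 'Xenova/gpt2',
  do_sample: true,
  top_k: 50
}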