// Snapshot of commit 91cc60b by Vokturz: "improve layout"
import { useState, useRef, useEffect, useCallback } from 'react'
import { Send, Eraser, Loader2, X } from 'lucide-react'
import {
ChatMessage,
TextGenerationWorkerInput,
WorkerMessage
} from '../../types'
import { useModel } from '../../contexts/ModelContext'
import { useTextGeneration } from '../../contexts/TextGenerationContext'
/**
 * Text-generation playground for the currently loaded model.
 *
 * Renders one of two UIs depending on `modelInfo.hasChatTemplate`:
 * - chat mode: a message list backed by `messages` from TextGenerationContext;
 *   each user turn is posted to the worker together with the full history;
 * - completion mode: a single prompt box whose worker result replaces
 *   `generatedText`.
 *
 * Requests are posted to `activeWorker` as `type: 'generate'` messages.
 * Results arrive as worker messages with `status === 'output'`. A `'ready'`
 * or `'error'` status also ends the in-progress state.
 */
function TextGeneration() {
  const { config, messages, setMessages } = useTextGeneration()
  const [currentMessage, setCurrentMessage] = useState<string>('')
  const [prompt, setPrompt] = useState<string>('')
  const [generatedText, setGeneratedText] = useState<string>('')
  const [isGenerating, setIsGenerating] = useState<boolean>(false)
  const {
    activeWorker,
    status,
    modelInfo,
    hasBeenLoaded,
    selectedQuantization
  } = useModel()

  // Derived UI state. It is computed before the handlers so the keyboard
  // handler can apply exactly the same guard as the Send/Generate buttons.
  const busy = status !== 'ready' || isGenerating
  const hasChatTemplate = modelInfo?.hasChatTemplate

  const messagesEndRef = useRef<HTMLDivElement>(null)

  // Keep the newest message / generated text scrolled into view.
  useEffect(() => {
    messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
  }, [messages, generatedText])

  /** Ask the worker to stop the current generation and leave the busy state. */
  const stopGeneration = useCallback(() => {
    if (activeWorker && isGenerating) {
      activeWorker.postMessage({ type: 'stop' })
      setIsGenerating(false)
    }
  }, [activeWorker, isGenerating])

  /**
   * Chat mode: append the trimmed user message to the history and post the
   * whole conversation to the worker. No-op when the input is blank, no model
   * or worker is available, or a generation is already running.
   */
  const handleSendMessage = useCallback(() => {
    if (!currentMessage.trim() || !modelInfo || !activeWorker || isGenerating)
      return
    const userMessage: ChatMessage = {
      role: 'user',
      content: currentMessage.trim()
    }
    const updatedMessages = [...messages, userMessage]
    setMessages(updatedMessages)
    setCurrentMessage('')
    setIsGenerating(true)
    const message: TextGenerationWorkerInput = {
      type: 'generate',
      messages: updatedMessages,
      hasChatTemplate: modelInfo.hasChatTemplate,
      model: modelInfo.id,
      dtype: selectedQuantization ?? 'fp32',
      config
    }
    activeWorker.postMessage(message)
  }, [
    currentMessage,
    messages,
    setMessages,
    modelInfo,
    activeWorker,
    config,
    isGenerating,
    selectedQuantization
  ])

  /**
   * Completion mode: post the trimmed prompt to the worker. Same no-op
   * conditions as `handleSendMessage`.
   */
  const handleGenerateText = useCallback(() => {
    if (!prompt.trim() || !modelInfo || !activeWorker || isGenerating) return
    setIsGenerating(true)
    const message: TextGenerationWorkerInput = {
      type: 'generate',
      prompt: prompt.trim(),
      hasChatTemplate: modelInfo.hasChatTemplate,
      model: modelInfo.id,
      config,
      dtype: selectedQuantization ?? 'fp32'
    }
    activeWorker.postMessage(message)
  }, [
    prompt,
    modelInfo,
    activeWorker,
    config,
    isGenerating,
    selectedQuantization
  ])

  // Route worker results into chat history or the completion output panel.
  useEffect(() => {
    if (!activeWorker) return
    const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
      const { status, output } = e.data
      if (status === 'output' && output) {
        setIsGenerating(false)
        if (modelInfo?.hasChatTemplate) {
          const assistantMessage: ChatMessage = {
            role: 'assistant',
            content: output.content
          }
          setMessages((prev) => [...prev, assistantMessage])
        } else {
          setGeneratedText(output.content)
        }
      } else if (status === 'ready' || status === 'error') {
        setIsGenerating(false)
      }
    }
    activeWorker.addEventListener('message', onMessageReceived)
    return () => activeWorker.removeEventListener('message', onMessageReceived)
  }, [activeWorker, modelInfo?.hasChatTemplate, setMessages])

  /**
   * Enter submits and Shift+Enter inserts a newline.
   *
   * An Enter that confirms an IME composition is left to the IME. Some
   * browsers report that keydown with `isComposing === false` but
   * `keyCode === 229`, so both are checked. Submission is refused while
   * `busy` or before a model is loaded, which matches the buttons' disabled
   * state.
   */
  const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
    if (e.key !== 'Enter' || e.shiftKey) return
    if (e.nativeEvent.isComposing || e.keyCode === 229) return
    e.preventDefault()
    if (busy || !hasBeenLoaded) return
    if (hasChatTemplate) {
      handleSendMessage()
    } else {
      handleGenerateText()
    }
  }

  /** Chat mode keeps only system messages; completion mode clears both boxes. */
  const clearChat = () => {
    if (hasChatTemplate) {
      setMessages((prev) => prev.filter((msg) => msg.role === 'system'))
    } else {
      setPrompt('')
      setGeneratedText('')
    }
  }

  return (
    <div className="flex flex-col min-h-[30dvh] max-h-[calc(100dvh-128px)] w-full p-4">
      <div className="flex items-center justify-between mb-4">
        <h1 className="text-2xl font-bold">
          Text Generation {hasChatTemplate ? '(Chat)' : ''}
        </h1>
        <div className="flex gap-2">
          <button
            onClick={clearChat}
            className="p-2 bg-red-100 hover:bg-red-200 rounded-lg transition-colors"
            title={hasChatTemplate ? 'Clear Chat' : 'Clear Text'}
          >
            <Eraser className="w-4 h-4" />
          </button>
          {isGenerating && (
            <button
              onClick={stopGeneration}
              className="p-2 bg-orange-100 hover:bg-orange-200 rounded-lg transition-colors"
              title="Stop Generation"
            >
              <X className="w-4 h-4" />
            </button>
          )}
        </div>
      </div>
      {hasChatTemplate ? (
        <>
          <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg p-4 mb-4 bg-white">
            <div className="space-y-4">
              {messages
                .filter((msg) => msg.role !== 'system')
                .map((message, index) => (
                  <div
                    key={index}
                    className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'}`}
                  >
                    <div
                      className={`max-w-[80%] p-3 rounded-lg ${
                        message.role === 'user'
                          ? 'bg-blue-500 text-white'
                          : 'bg-gray-100 text-gray-800'
                      }`}
                    >
                      <div className="text-xs font-medium mb-1 opacity-70">
                        {message.role === 'user' ? 'You' : 'Assistant'}
                      </div>
                      <div className="whitespace-pre-wrap">
                        {message.content}
                      </div>
                    </div>
                  </div>
                ))}
              {isGenerating && (
                <div className="flex justify-start">
                  <div className="bg-gray-100 text-gray-800 p-3 rounded-lg">
                    <div className="text-xs font-medium mb-1 opacity-70">
                      Assistant
                    </div>
                    <div className="flex items-center space-x-2">
                      <Loader2 className="w-4 h-4 animate-spin" />
                      <div>Loading...</div>
                    </div>
                  </div>
                </div>
              )}
            </div>
            <div ref={messagesEndRef} />
          </div>
          <div className="flex gap-2">
            <textarea
              value={currentMessage}
              onChange={(e) => setCurrentMessage(e.target.value)}
              onKeyDown={handleKeyDown}
              placeholder="Type your message... (Press Enter to send, Shift+Enter for new line)"
              className="flex-1 p-3 border border-gray-300 rounded-lg resize-none focus:outline-hidden focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
              rows={2}
              disabled={!hasBeenLoaded || isGenerating}
            />
            <button
              onClick={handleSendMessage}
              disabled={!currentMessage.trim() || busy || !hasBeenLoaded}
              className="px-4 py-2 bg-blue-500 hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center justify-center"
            >
              {isGenerating ? (
                <Loader2 className="w-4 h-4 animate-spin" />
              ) : (
                <Send className="w-4 h-4" />
              )}
            </button>
          </div>
        </>
      ) : (
        <>
          <div className="mb-4">
            <label className="block text-sm font-medium text-gray-700 mb-2">
              Enter your prompt:
            </label>
            <textarea
              value={prompt}
              onChange={(e) => setPrompt(e.target.value)}
              onKeyDown={handleKeyDown}
              placeholder="Enter your text prompt here... (Press Enter to generate, Shift+Enter for new line)"
              className="w-full p-3 border border-gray-300 rounded-lg resize-none focus:outline-hidden focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
              rows={4}
              disabled={!hasBeenLoaded || isGenerating}
            />
          </div>
          <div className="mb-4">
            <button
              onClick={handleGenerateText}
              disabled={!prompt.trim() || busy || !hasBeenLoaded}
              className="px-6 py-2 bg-green-500 hover:bg-green-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center gap-2"
            >
              {isGenerating ? (
                <>
                  <Loader2 className="w-4 h-4 animate-spin" />
                  Generating...
                </>
              ) : (
                <>
                  <Send className="w-4 h-4" />
                  Generate Text
                </>
              )}
            </button>
          </div>
          <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg p-4 bg-white">
            <div className="mb-2">
              <label className="block text-sm font-medium text-gray-700">
                Generated Text:
              </label>
            </div>
            {generatedText ? (
              <div className="whitespace-pre-wrap text-gray-800 bg-gray-50 p-3 rounded-sm border">
                {generatedText}
              </div>
            ) : (
              <div className="text-gray-500 italic flex items-center gap-2">
                {isGenerating ? (
                  <>
                    <Loader2 className="w-4 h-4 animate-spin" />
                    Generating text...
                  </>
                ) : (
                  'Generated text will appear here'
                )}
              </div>
            )}
            <div ref={messagesEndRef} />
          </div>
        </>
      )}
      {!hasBeenLoaded && (
        <div className="text-center text-gray-500 text-sm mt-2">
          Please load a model first to start{' '}
          {hasChatTemplate ? 'chatting' : 'generating text'}
        </div>
      )}
    </div>
  )
}
export default TextGeneration