File size: 6,720 Bytes
0250375 31283f8 ad5cef3 63dbafb 0250375 ad5cef3 31283f8 ad5cef3 4d810fa 63dbafb ad5cef3 bd915ca 2f35054 4d810fa ad5cef3 673d22a 4d810fa ad5cef3 4d810fa ad5cef3 673d22a ad5cef3 4d810fa 63dbafb 4d810fa 63dbafb 4d810fa ad5cef3 bd915ca 4d810fa ad5cef3 4d810fa 0250375 63dbafb 31283f8 63dbafb 0250375 31283f8 63dbafb ad5cef3 25647ae ad5cef3 4d810fa 63dbafb 4d810fa ad5cef3 31283f8 ad5cef3 673d22a ad5cef3 79eafc9 ad5cef3 4d810fa ad5cef3 4d810fa ad5cef3 79eafc9 4d810fa f3b30b4 4d810fa 5541427 4d810fa ad5cef3 63dbafb ad5cef3 4d810fa ad5cef3 31283f8 ad5cef3 63dbafb 4d810fa 63dbafb ad5cef3 31283f8 0250375 31283f8 0250375 ad5cef3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import { useCallback, useEffect, useRef, useState } from 'react'
import { ChevronDown, Loader2, X } from 'lucide-react'
import { QuantizationType, WorkerMessage } from '../types'
import { useModel } from '../contexts/ModelContext'
import { getWorker, terminateWorker } from '../lib/workerManager'
import { Alert, AlertDescription } from './ui/alert'
/** Quantizations preferred as the initial selection, most preferred first. */
const PREFERRED_QUANTIZATIONS = ['int8', 'q8', 'q4'] as const

/** Alert auto-dismiss delays, in milliseconds. */
const ALERT_DISMISS_MS = {
  /** Error reported before the model finished loading. */
  loadError: 3000,
  /** Error reported after the model loaded; the user is told to refresh. */
  runtimeError: 5000,
  /** Any alert still visible once loading progress reaches 100%. */
  loadComplete: 2000
} as const

/**
 * Picks the initial quantization for a model: the first entry of
 * PREFERRED_QUANTIZATIONS the model supports, falling back to 'fp32'.
 */
const pickDefaultQuantization = (
  supported: readonly string[]
): QuantizationType =>
  PREFERRED_QUANTIZATIONS.find((quant) => supported.includes(quant)) ?? 'fp32'

/**
 * Quantization picker and "Load Model" button for the model in ModelContext.
 *
 * - When a compatible model is selected, chooses a default quantization.
 * - Attaches to the pipeline's worker (from `getWorker(pipeline)`) and mirrors
 *   its `ready` / `loading` / `error` messages into the context's status,
 *   progress and error text.
 * - Posts a `load` message to the active worker when the button is clicked.
 *
 * Renders nothing when there is no model or the model is not compatible.
 */
const ModelLoader = () => {
  const [showAlert, setShowAlert] = useState(false)
  const [alertMessage, setAlertMessage] = useState<React.ReactNode>('')
  // Single pending auto-dismiss timer. Tracking it lets a newer alert cancel an
  // older timer (which would otherwise hide the new alert early) and lets us
  // cancel it on unmount.
  const alertTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null)
  const {
    modelInfo,
    selectedQuantization,
    setSelectedQuantization,
    status,
    progress,
    setStatus,
    setProgress,
    activeWorker,
    setActiveWorker,
    pipeline,
    hasBeenLoaded,
    setHasBeenLoaded,
    setErrorText
  } = useModel()

  /** Hides and clears the alert after `delayMs`, replacing any pending dismissal. */
  const scheduleAlertDismiss = useCallback((delayMs: number) => {
    if (alertTimerRef.current !== null) clearTimeout(alertTimerRef.current)
    alertTimerRef.current = setTimeout(() => {
      alertTimerRef.current = null
      setShowAlert(false)
      setAlertMessage('')
    }, delayMs)
  }, [])

  // Cancel a pending dismissal on unmount so no state update fires afterwards.
  useEffect(() => {
    return () => {
      if (alertTimerRef.current !== null) clearTimeout(alertTimerRef.current)
    }
  }, [])

  // A different quantization means a different model load.
  useEffect(() => {
    setHasBeenLoaded(false)
  }, [selectedQuantization, setHasBeenLoaded])

  // On model change: pick a default quantization (compatible models only) and
  // mark the model as not loaded.
  useEffect(() => {
    if (!modelInfo) return
    if (modelInfo.isCompatible) {
      setSelectedQuantization(
        pickDefaultQuantization(modelInfo.supportedQuantizations)
      )
    }
    setHasBeenLoaded(false)
  }, [modelInfo, setSelectedQuantization, setHasBeenLoaded])

  // Subscribe to the pipeline's worker and translate its messages into state.
  useEffect(() => {
    if (!modelInfo) return
    const worker = getWorker(pipeline)
    if (!worker) return

    if (!hasBeenLoaded) {
      // Fresh (re)load: reset status/progress/error and adopt this worker.
      setErrorText('')
      setStatus('initiate')
      setActiveWorker(worker)
      setProgress(0)
    }

    const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
      const { status: messageStatus, output } = e.data

      if (messageStatus === 'ready') {
        setStatus('ready')
        if (output) console.log(output)
        setHasBeenLoaded(true)
      } else if (messageStatus === 'loading' && output && !hasBeenLoaded) {
        setStatus('loading')
        // Only progress events for files whose name starts with 'onnx' drive
        // the progress indicator.
        if (
          output.progress &&
          typeof output.file === 'string' &&
          output.file.startsWith('onnx')
        ) {
          setProgress(output.progress)
        }
      } else if (messageStatus === 'error') {
        setStatus('error')
        const error = e.data.output
        console.error(error)
        // NOTE(review): the payload is assumed to be a string; coerce so that
        // `.split` cannot throw on anything else.
        const errorMessage = typeof error === 'string' ? error : String(error)
        // Prefer the text after the WASM marker. Fall back to the whole message
        // when the marker is absent (previously this produced `undefined`).
        const errText = errorMessage.split(' WASM error: ')[1] ?? errorMessage
        setErrorText(errText)
        setShowAlert(true)
        if (!hasBeenLoaded) {
          setAlertMessage(
            errorMessage.split('.')[0] + '. See console for details.'
          )
          scheduleAlertDismiss(ALERT_DISMISS_MS.loadError)
        } else {
          setAlertMessage(`${errText}. Refresh the page and try again.`)
          scheduleAlertDismiss(ALERT_DISMISS_MS.runtimeError)
        }
      }
    }

    worker.addEventListener('message', onMessageReceived)
    return () => {
      worker.removeEventListener('message', onMessageReceived)
      // The worker itself is intentionally not terminated here.
      // terminateWorker(pipeline)
    }
  }, [
    pipeline,
    modelInfo,
    // Not read in the body, but intentionally a dependency: changing the
    // quantization re-runs this effect and resets status/progress (e.g. out
    // of 'error'), which re-enables the Load button.
    selectedQuantization,
    setActiveWorker,
    setStatus,
    setProgress,
    hasBeenLoaded,
    setHasBeenLoaded,
    setErrorText,
    scheduleAlertDismiss
  ])

  // Once loading completes, dismiss any alert that is still visible.
  useEffect(() => {
    if (progress === 100) scheduleAlertDismiss(ALERT_DISMISS_MS.loadComplete)
  }, [progress, scheduleAlertDismiss])

  /** Asks the active worker to load the selected model at the selected quantization. */
  const loadModel = useCallback(() => {
    if (!modelInfo || !selectedQuantization) return
    activeWorker?.postMessage({
      type: 'load',
      model: modelInfo.name,
      dtype: selectedQuantization,
      // text-to-speech only
      isStyleTTS2:
        modelInfo.isStyleTTS2 || modelInfo.name.includes('kitten-tts') || false
    })
  }, [modelInfo, selectedQuantization, activeWorker])

  if (!modelInfo?.isCompatible) {
    return null
  }

  return (
    <div className="space-y-3">
      <hr className="border-gray-200" />
      <div className="flex items-center justify-between space-x-4">
        <div className="flex items-center space-x-2">
          {modelInfo.supportedQuantizations.length > 0 ? (
            <>
              <span className="text-xs text-gray-600 font-medium">Quant:</span>
              <div className="relative">
                <select
                  value={selectedQuantization || ''}
                  onChange={(e) =>
                    setSelectedQuantization(e.target.value as QuantizationType)
                  }
                  className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-hidden focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
                >
                  {modelInfo.supportedQuantizations.map((quant) => (
                    <option key={quant} value={quant}>
                      {quant}
                    </option>
                  ))}
                </select>
                <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
              </div>
            </>
          ) : (
            <span className="text-xs text-gray-600 font-medium whitespace-break-spaces">
              No quantization available. Using fp32
            </span>
          )}
        </div>
        {selectedQuantization && (
          <div className="flex justify-center">
            <button
              className={`w-32 py-2 px-4 ${status !== 'error' ? 'bg-green-500 hover:bg-green-600 cursor-pointer' : 'bg-red-500 hover:bg-red-600'} rounded-sm text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2`}
              disabled={
                hasBeenLoaded || status === 'loading' || status === 'error'
              }
              onClick={loadModel}
            >
              {status === 'loading' && !hasBeenLoaded ? (
                <>
                  <Loader2 className="animate-spin h-4 w-4" />
                  <span>{progress.toFixed(0)}%</span>
                </>
              ) : status !== 'error' ? (
                <span>{!hasBeenLoaded ? 'Load Model' : 'Model Ready'}</span>
              ) : (
                <span>Error</span>
              )}
            </button>
          </div>
        )}
      </div>
      {showAlert && (
        <div className="fixed bottom-0 right-0 m-2">
          {/* String messages are errors; other nodes are informational. */}
          <Alert
            variant={typeof alertMessage === 'string' ? 'destructive' : 'default'}
          >
            <AlertDescription>{alertMessage}</AlertDescription>
          </Alert>
        </div>
      )}
    </div>
  )
}
export default ModelLoader
|