import { useEffect, useCallback, useRef, useState } from 'react'
import { ChevronDown, Loader2, X } from 'lucide-react'
import { QuantizationType, WorkerMessage } from '../types'
import { useModel } from '../contexts/ModelContext'
import { getWorker, terminateWorker } from '../lib/workerManager'
import { Alert, AlertDescription } from './ui/alert'

/** Quantizations tried, in this order, when choosing a default; 'fp32' is the fallback. */
const PREFERRED_QUANTIZATIONS: readonly QuantizationType[] = ['int8', 'q8', 'q4']

/** Separator (space, line break, "WASM error: ") between the summary and the WASM detail of a worker error. */
const WASM_ERROR_SEPARATOR = ' \nWASM error: '

/** How long (ms) an error alert stays visible before / after the model has been loaded once. */
const ALERT_MS_BEFORE_LOAD = 3000
const ALERT_MS_AFTER_LOAD = 5000
/** Delay (ms) before the alert is hidden once progress reaches 100. */
const ALERT_MS_AFTER_COMPLETE = 2000

/**
 * Turns whatever the worker sent as an error payload into a plain string,
 * so the string operations below can never throw on a non-string payload.
 */
const toErrorMessage = (raw: unknown): string => {
  if (typeof raw === 'string') return raw
  if (raw instanceof Error) return raw.message
  if (typeof raw === 'object' && raw !== null && 'message' in raw) {
    const { message } = raw as { message: unknown }
    if (typeof message === 'string') return message
  }
  return String(raw)
}

/**
 * Quantization picker and load control for the model held in the model context.
 *
 * - Picks a default quantization whenever `modelInfo` changes.
 * - Subscribes to the pipeline worker's messages and mirrors them into the
 *   context (`status`, `progress`, `errorText`, `hasBeenLoaded`).
 * - `loadModel` posts a `'load'` message to the active worker.
 *
 * Renders nothing when there is no compatible model.
 */
const ModelLoader = () => {
  const [showAlert, setShowAlert] = useState(false)
  const [alertMessage, setAlertMessage] = useState('')
  // Pending error-alert dismissal timer, kept so it can be replaced or cancelled.
  const alertTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null)
  const {
    modelInfo,
    selectedQuantization,
    setSelectedQuantization,
    status,
    progress,
    setStatus,
    setProgress,
    activeWorker,
    setActiveWorker,
    pipeline,
    hasBeenLoaded,
    setHasBeenLoaded,
    setErrorText
  } = useModel()

  /** Hides the alert and clears its text. */
  const hideAlert = useCallback(() => {
    setShowAlert(false)
    setAlertMessage('')
  }, [])

  /**
   * Hides the alert after `delayMs`. Any dismissal already pending is cancelled
   * first, so an older timer cannot hide a newer alert ahead of time.
   */
  const scheduleAlertDismissal = useCallback(
    (delayMs: number) => {
      if (alertTimerRef.current !== null) clearTimeout(alertTimerRef.current)
      alertTimerRef.current = setTimeout(() => {
        alertTimerRef.current = null
        hideAlert()
      }, delayMs)
    },
    [hideAlert]
  )

  // Cancel a pending dismissal on unmount so no state update runs afterwards.
  useEffect(
    () => () => {
      if (alertTimerRef.current !== null) clearTimeout(alertTimerRef.current)
    },
    []
  )

  // A different quantization means the model has to be loaded again.
  useEffect(() => {
    setHasBeenLoaded(false)
  }, [selectedQuantization, setHasBeenLoaded])

  // Choose the default quantization for a newly selected, compatible model.
  useEffect(() => {
    if (!modelInfo) return
    if (modelInfo.isCompatible) {
      const quantizations = modelInfo.supportedQuantizations
      const defaultQuant: QuantizationType =
        PREFERRED_QUANTIZATIONS.find((quant) => quantizations.includes(quant)) ?? 'fp32'
      setSelectedQuantization(defaultQuant)
    }
    setHasBeenLoaded(false)
  }, [modelInfo, setSelectedQuantization, setHasBeenLoaded])

  // Attach to the pipeline's worker and mirror its messages into the context.
  useEffect(() => {
    if (!modelInfo) return
    const newWorker = getWorker(pipeline)
    if (!newWorker) {
      return
    }

    if (!hasBeenLoaded) {
      setErrorText('')
      setStatus('initiate')
      setActiveWorker(newWorker)
      setProgress(0)
    }

    const onMessageReceived = (e: MessageEvent) => {
      // Renamed locally so it does not shadow the context's `status`.
      const { status: workerStatus, output } = e.data ?? {}

      if (workerStatus === 'ready') {
        setStatus('ready')
        if (output) console.log(output)
        setHasBeenLoaded(true)
      } else if (workerStatus === 'loading' && output && !hasBeenLoaded) {
        setStatus('loading')
        // Only progress events for files whose name starts with 'onnx' move the progress value.
        if (
          output.progress &&
          typeof output.file === 'string' &&
          output.file.startsWith('onnx')
        ) {
          setProgress(output.progress)
        }
      } else if (workerStatus === 'error') {
        setStatus('error')
        console.error(output)

        const message = toErrorMessage(output)
        // Use the WASM detail when present; otherwise fall back to the whole
        // message instead of showing "undefined".
        const errText = message.split(WASM_ERROR_SEPARATOR)[1] ?? message
        setErrorText(errText)
        setShowAlert(true)

        if (!hasBeenLoaded) {
          setAlertMessage(`${message.split('.')[0]}. See console for details.`)
          scheduleAlertDismissal(ALERT_MS_BEFORE_LOAD)
        } else {
          setAlertMessage(`${errText}. Refresh the page and try again.`)
          scheduleAlertDismissal(ALERT_MS_AFTER_LOAD)
        }
      }
    }

    newWorker.addEventListener('message', onMessageReceived)

    return () => {
      newWorker.removeEventListener('message', onMessageReceived)
      // terminateWorker(pipeline)
    }
    // `selectedQuantization` is not read above; it is kept in the list as in the
    // original — presumably so a quantization change re-runs this effect.
  }, [
    pipeline,
    modelInfo,
    selectedQuantization,
    setActiveWorker,
    setStatus,
    setProgress,
    hasBeenLoaded,
    setHasBeenLoaded,
    setErrorText,
    scheduleAlertDismissal
  ])

  // Hide any alert shortly after progress reaches 100; the timer is cancelled
  // if progress changes again or the component unmounts first.
  useEffect(() => {
    if (progress !== 100) return
    const timer = setTimeout(hideAlert, ALERT_MS_AFTER_COMPLETE)
    return () => clearTimeout(timer)
  }, [progress, hideAlert])

  /** Asks the active worker to load the current model with the selected quantization. */
  const loadModel = useCallback(() => {
    if (!modelInfo || !selectedQuantization) return

    const message = {
      type: 'load',
      model: modelInfo.name,
      dtype: selectedQuantization,
      isStyleTTS2: modelInfo.isStyleTTS2 || modelInfo.name.includes('kitten-tts') // text-to-speech only
    }
    activeWorker?.postMessage(message)
  }, [modelInfo, selectedQuantization, activeWorker])

  if (!modelInfo?.isCompatible) {
    return null
  }

  // NOTE(review): the element tags of the render tree are not present in this
  // copy of the file; the remaining text below is kept exactly as found.
  return (

{modelInfo.supportedQuantizations.length >= 1 ? ( <> Quant:
) : ( No quantization available. Using fp32 )}
{selectedQuantization && (
)}
{showAlert && (
{alertMessage}
)}
  )
}

export default ModelLoader