File size: 5,466 Bytes
ad5cef3 4d810fa ad5cef3 bd915ca 2f35054 4d810fa ad5cef3 673d22a 4d810fa ad5cef3 4d810fa ad5cef3 bd915ca ad5cef3 673d22a ad5cef3 4d810fa ad5cef3 bd915ca 4d810fa ad5cef3 4d810fa ad5cef3 4d810fa ad5cef3 673d22a ad5cef3 4d810fa ad5cef3 4d810fa ad5cef3 4d810fa ad5cef3 4d810fa ad5cef3 bd915ca ad5cef3 4d810fa ad5cef3 4d810fa ad5cef3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import { useEffect, useCallback } from 'react'
import { ChevronDown, Loader } from 'lucide-react'
import { QuantizationType, WorkerMessage } from '../types'
import { useModel } from '../contexts/ModelContext'
import { getWorker } from '../lib/workerManager'
/**
 * Quantization picker and "Load Model" button for the model held in the
 * model context.
 *
 * The component:
 * - picks a default quantization whenever `modelInfo` changes,
 * - subscribes to messages from the worker returned by `getWorker(pipeline)`
 *   and mirrors them into context state (status, progress, results),
 * - posts a `load` message to the active worker when the button is clicked.
 *
 * Renders nothing when there is no model or the model is not compatible.
 */
const ModelLoader = () => {
  const {
    modelInfo,
    selectedQuantization,
    setSelectedQuantization,
    status,
    progress,
    setStatus,
    setProgress,
    activeWorker,
    setActiveWorker,
    pipeline,
    setResults,
    hasBeenLoaded,
    setHasBeenLoaded
  } = useModel()

  // Changing the quantization invalidates the previous load, which makes the
  // load button available again.
  useEffect(() => {
    setHasBeenLoaded(false)
  }, [selectedQuantization, setHasBeenLoaded])

  // When the model changes, choose a default quantization and mark the model
  // as not loaded. Preference order: int8, then q8, then q4, otherwise fp32.
  useEffect(() => {
    if (!modelInfo) return

    if (modelInfo.isCompatible) {
      const quantizations = modelInfo.supportedQuantizations
      let defaultQuant: QuantizationType = 'fp32'

      if (quantizations.includes('int8')) {
        defaultQuant = 'int8'
      } else if (quantizations.includes('q8')) {
        defaultQuant = 'q8'
      } else if (quantizations.includes('q4')) {
        defaultQuant = 'q4'
      }

      setSelectedQuantization(defaultQuant)
    }

    setHasBeenLoaded(false)
  }, [modelInfo, setSelectedQuantization, setHasBeenLoaded])

  // Attach a message listener to the worker for the current pipeline and
  // translate worker messages into context state updates.
  useEffect(() => {
    if (!modelInfo) return

    const newWorker = getWorker(pipeline)
    if (!newWorker) {
      return
    }

    if (!hasBeenLoaded) {
      setStatus('initiate')
      setActiveWorker(newWorker)
    }

    const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
      const { status, output } = e.data

      if (status === 'ready') {
        setStatus('ready')
        if (output) console.log(output)
        setHasBeenLoaded(true)
      } else if (status === 'loading' && output && !hasBeenLoaded) {
        setStatus('loading')
        // Only progress reports for files whose name starts with 'onnx'
        // update the displayed percentage.
        if (
          output.progress &&
          typeof output.file === 'string' &&
          output.file.startsWith('onnx')
        ) {
          setProgress(output.progress)
        }
      } else if (status === 'output') {
        setStatus('output')
        // Skip a missing payload instead of appending `undefined` to results.
        if (output != null) {
          setResults((prev: any[]) => [...prev, output])
        }
      } else if (status === 'error') {
        setStatus('error')
        console.error(output)
      }
    }

    newWorker.addEventListener('message', onMessageReceived)

    return () => {
      // Only this listener is detached; the worker itself is not terminated.
      newWorker.removeEventListener('message', onMessageReceived)
    }
  }, [
    pipeline,
    modelInfo,
    selectedQuantization,
    setActiveWorker,
    setStatus,
    setProgress,
    setResults,
    hasBeenLoaded,
    setHasBeenLoaded
  ])

  /**
   * Asks the active worker to load the current model with the selected
   * quantization. Does nothing when the model, quantization or worker is
   * missing.
   */
  const loadModel = useCallback(() => {
    if (!modelInfo || !selectedQuantization || !activeWorker) return

    // Switch to 'loading' immediately so the button is disabled before the
    // worker's first message arrives (prevents duplicate load requests), and
    // reset progress so a previous load's percentage is not shown.
    setProgress(0)
    setStatus('loading')

    activeWorker.postMessage({
      type: 'load',
      model: modelInfo.name,
      dtype: selectedQuantization
    })
  }, [modelInfo, selectedQuantization, activeWorker, setProgress, setStatus])

  if (!modelInfo?.isCompatible) {
    return null
  }

  return (
    <div className="space-y-3">
      <hr className="border-gray-200" />
      <div className="flex items-center justify-between space-x-4">
        <div className="flex items-center space-x-2">
          {modelInfo.supportedQuantizations.length > 1 ? (
            <>
              <span className="text-xs text-gray-600 font-medium">
                Quantization:
              </span>
              <div className="relative">
                <select
                  value={selectedQuantization || ''}
                  onChange={(e) =>
                    setSelectedQuantization(e.target.value as QuantizationType)
                  }
                  className="appearance-none bg-white border border-gray-300 rounded-md px-3 py-1 pr-8 text-xs text-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
                >
                  <option value="">Select quantization</option>
                  {modelInfo.supportedQuantizations.map((quant) => (
                    <option key={quant} value={quant}>
                      {quant}
                    </option>
                  ))}
                </select>
                <ChevronDown className="absolute right-2 top-1/2 transform -translate-y-1/2 w-3 h-3 text-gray-400 pointer-events-none" />
              </div>
            </>
          ) : (
            <span className="text-xs text-gray-600 font-medium whitespace-break-spaces">
              No quantization available. Using fp32
            </span>
          )}
        </div>

        {selectedQuantization && (
          <div className="flex justify-center">
            <button
              className="w-32 py-2 px-4 bg-green-500 hover:bg-green-600 rounded text-white font-medium disabled:opacity-50 disabled:cursor-not-allowed transition-colors text-sm inline-flex items-center text-center justify-center space-x-2"
              disabled={hasBeenLoaded || status === 'loading'}
              onClick={loadModel}
            >
              {status === 'loading' && !hasBeenLoaded ? (
                <>
                  <Loader className="animate-spin h-4 w-4" />
                  <span>{progress.toFixed(0)}%</span>
                </>
              ) : (
                <span>{!hasBeenLoaded ? 'Load Model' : 'Model Ready'}</span>
              )}
            </button>
          </div>
        )}
      </div>
    </div>
  )
}
export default ModelLoader
|