Refactor modelInfo handling across components for improved null safety
Browse files
src/App.tsx
CHANGED
@@ -3,7 +3,6 @@ import PipelineSelector from './components/PipelineSelector'
|
|
3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
4 |
import TextClassification from './components/TextClassification'
|
5 |
import Header from './Header'
|
6 |
-
import Footer from './Footer'
|
7 |
import { useModel } from './contexts/ModelContext'
|
8 |
import { getModelsByPipeline } from './lib/huggingface'
|
9 |
import ModelSelector from './components/ModelSelector'
|
@@ -11,9 +10,10 @@ import ModelInfo from './components/ModelInfo'
|
|
11 |
import ModelReadme from './components/ModelReadme'
|
12 |
|
13 |
function App() {
|
14 |
-
const { pipeline, setPipeline, setModels, modelInfo } = useModel()
|
15 |
|
16 |
useEffect(() => {
|
|
|
17 |
const fetchModels = async () => {
|
18 |
const fetchedModels = await getModelsByPipeline(pipeline)
|
19 |
setModels(fetchedModels)
|
|
|
3 |
import ZeroShotClassification from './components/ZeroShotClassification'
|
4 |
import TextClassification from './components/TextClassification'
|
5 |
import Header from './Header'
|
|
|
6 |
import { useModel } from './contexts/ModelContext'
|
7 |
import { getModelsByPipeline } from './lib/huggingface'
|
8 |
import ModelSelector from './components/ModelSelector'
|
|
|
10 |
import ModelReadme from './components/ModelReadme'
|
11 |
|
12 |
function App() {
|
13 |
+
const { pipeline, setPipeline, setModels, setModelInfo, modelInfo } = useModel()
|
14 |
|
15 |
useEffect(() => {
|
16 |
+
setModelInfo(null)
|
17 |
const fetchModels = async () => {
|
18 |
const fetchedModels = await getModelsByPipeline(pipeline)
|
19 |
setModels(fetchedModels)
|
src/components/ModelInfo.tsx
CHANGED
@@ -64,7 +64,7 @@ const ModelInfo = () => {
|
|
64 |
</div>
|
65 |
)
|
66 |
|
67 |
-
if (!modelInfo
|
68 |
return <ModelInfoSkeleton />
|
69 |
}
|
70 |
|
|
|
64 |
</div>
|
65 |
)
|
66 |
|
67 |
+
if (!modelInfo) {
|
68 |
return <ModelInfoSkeleton />
|
69 |
}
|
70 |
|
src/components/ModelLoader.tsx
CHANGED
@@ -19,6 +19,8 @@ const ModelLoader = () => {
|
|
19 |
} = useModel()
|
20 |
|
21 |
useEffect(() => {
|
|
|
|
|
22 |
if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
|
23 |
const quantizations = modelInfo.supportedQuantizations
|
24 |
let defaultQuant: QuantizationType = 'fp32'
|
@@ -34,12 +36,13 @@ const ModelLoader = () => {
|
|
34 |
setSelectedQuantization(defaultQuant)
|
35 |
}
|
36 |
}, [
|
37 |
-
modelInfo
|
38 |
-
modelInfo.isCompatible,
|
39 |
setSelectedQuantization
|
40 |
])
|
41 |
|
42 |
useEffect(() => {
|
|
|
|
|
43 |
const newWorker = getWorker(pipeline)
|
44 |
if (!newWorker) {
|
45 |
return
|
@@ -70,10 +73,10 @@ const ModelLoader = () => {
|
|
70 |
newWorker.removeEventListener('message', onMessageReceived)
|
71 |
// terminateWorker(pipeline);
|
72 |
}
|
73 |
-
}, [pipeline, modelInfo
|
74 |
|
75 |
const loadModel = useCallback(() => {
|
76 |
-
if (!modelInfo
|
77 |
|
78 |
setStatus('loading')
|
79 |
const message = {
|
@@ -82,12 +85,12 @@ const ModelLoader = () => {
|
|
82 |
quantization: selectedQuantization
|
83 |
}
|
84 |
activeWorker?.postMessage(message)
|
85 |
-
}, [modelInfo
|
86 |
|
87 |
const ready: boolean = status === 'ready'
|
88 |
const busy: boolean = status === 'loading'
|
89 |
|
90 |
-
if (!modelInfo
|
91 |
return null
|
92 |
}
|
93 |
|
|
|
19 |
} = useModel()
|
20 |
|
21 |
useEffect(() => {
|
22 |
+
if (!modelInfo) return
|
23 |
+
|
24 |
if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
|
25 |
const quantizations = modelInfo.supportedQuantizations
|
26 |
let defaultQuant: QuantizationType = 'fp32'
|
|
|
36 |
setSelectedQuantization(defaultQuant)
|
37 |
}
|
38 |
}, [
|
39 |
+
modelInfo,
|
|
|
40 |
setSelectedQuantization
|
41 |
])
|
42 |
|
43 |
useEffect(() => {
|
44 |
+
if (!modelInfo) return
|
45 |
+
|
46 |
const newWorker = getWorker(pipeline)
|
47 |
if (!newWorker) {
|
48 |
return
|
|
|
73 |
newWorker.removeEventListener('message', onMessageReceived)
|
74 |
// terminateWorker(pipeline);
|
75 |
}
|
76 |
+
}, [pipeline, modelInfo, selectedQuantization, setActiveWorker, setStatus, setProgress])
|
77 |
|
78 |
const loadModel = useCallback(() => {
|
79 |
+
if (!modelInfo || !selectedQuantization) return
|
80 |
|
81 |
setStatus('loading')
|
82 |
const message = {
|
|
|
85 |
quantization: selectedQuantization
|
86 |
}
|
87 |
activeWorker?.postMessage(message)
|
88 |
+
}, [modelInfo, selectedQuantization, setStatus, activeWorker])
|
89 |
|
90 |
const ready: boolean = status === 'ready'
|
91 |
const busy: boolean = status === 'loading'
|
92 |
|
93 |
+
if (!modelInfo?.isCompatible || modelInfo.supportedQuantizations.length === 0) {
|
94 |
return null
|
95 |
}
|
96 |
|
src/components/ModelSelector.tsx
CHANGED
@@ -119,7 +119,7 @@ const ModelSelector: React.FC = () => {
|
|
119 |
}
|
120 |
|
121 |
const selectedModel =
|
122 |
-
models.find((model) => model.id === modelInfo
|
123 |
|
124 |
return (
|
125 |
<div className="relative">
|
@@ -132,7 +132,7 @@ const ModelSelector: React.FC = () => {
|
|
132 |
<div className="flex items-center justify-between w-full">
|
133 |
<div className="flex flex-col flex-1 min-w-0">
|
134 |
<span className="truncate font-medium">
|
135 |
-
{modelInfo
|
136 |
</span>
|
137 |
</div>
|
138 |
|
|
|
119 |
}
|
120 |
|
121 |
const selectedModel =
|
122 |
+
models.find((model) => model.id === modelInfo?.id) || models[0]
|
123 |
|
124 |
return (
|
125 |
<div className="relative">
|
|
|
132 |
<div className="flex items-center justify-between w-full">
|
133 |
<div className="flex flex-col flex-1 min-w-0">
|
134 |
<span className="truncate font-medium">
|
135 |
+
{modelInfo?.id || 'Select a model'}
|
136 |
</span>
|
137 |
</div>
|
138 |
|
src/components/TextClassification.tsx
CHANGED
@@ -58,6 +58,7 @@ function TextClassification() {
|
|
58 |
}, [setStatus])
|
59 |
|
60 |
const classify = useCallback(() => {
|
|
|
61 |
setStatus('loading')
|
62 |
setResults([]) // Clear previous results
|
63 |
const message: TextClassificationWorkerInput = {
|
@@ -66,7 +67,7 @@ function TextClassification() {
|
|
66 |
model: modelInfo.id
|
67 |
}
|
68 |
workerRef.current?.postMessage(message)
|
69 |
-
}, [text, modelInfo
|
70 |
|
71 |
const busy: boolean = status !== 'ready'
|
72 |
|
|
|
58 |
}, [setStatus])
|
59 |
|
60 |
const classify = useCallback(() => {
|
61 |
+
if (!modelInfo) return
|
62 |
setStatus('loading')
|
63 |
setResults([]) // Clear previous results
|
64 |
const message: TextClassificationWorkerInput = {
|
|
|
67 |
model: modelInfo.id
|
68 |
}
|
69 |
workerRef.current?.postMessage(message)
|
70 |
+
}, [text, modelInfo, setStatus])
|
71 |
|
72 |
const busy: boolean = status !== 'ready'
|
73 |
|
src/components/ZeroShotClassification.tsx
CHANGED
@@ -104,6 +104,8 @@ function ZeroShotClassification() {
|
|
104 |
}, [sections])
|
105 |
|
106 |
const classify = useCallback(() => {
|
|
|
|
|
107 |
setStatus('loading')
|
108 |
const message: ZeroShotWorkerInput = {
|
109 |
text,
|
@@ -113,7 +115,7 @@ function ZeroShotClassification() {
|
|
113 |
model: modelInfo.name
|
114 |
}
|
115 |
worker.current?.postMessage(message)
|
116 |
-
}, [text, sections, modelInfo
|
117 |
|
118 |
const busy: boolean = status !== 'ready'
|
119 |
|
|
|
104 |
}, [sections])
|
105 |
|
106 |
const classify = useCallback(() => {
|
107 |
+
if (!modelInfo) return
|
108 |
+
|
109 |
setStatus('loading')
|
110 |
const message: ZeroShotWorkerInput = {
|
111 |
text,
|
|
|
115 |
model: modelInfo.name
|
116 |
}
|
117 |
worker.current?.postMessage(message)
|
118 |
+
}, [text, sections, modelInfo])
|
119 |
|
120 |
const busy: boolean = status !== 'ready'
|
121 |
|
src/contexts/ModelContext.tsx
CHANGED
@@ -16,8 +16,8 @@ interface ModelContextType {
|
|
16 |
setStatus: (status: WorkerStatus) => void
|
17 |
progress: number
|
18 |
setProgress: (progress: number) => void
|
19 |
-
modelInfo: ModelInfo
|
20 |
-
setModelInfo: (model: ModelInfo) => void
|
21 |
pipeline: string
|
22 |
setPipeline: (pipeline: string) => void
|
23 |
models: ModelInfoResponse[]
|
@@ -33,7 +33,7 @@ const ModelContext = createContext<ModelContextType | undefined>(undefined)
|
|
33 |
export function ModelProvider({ children }: { children: React.ReactNode }) {
|
34 |
const [progress, setProgress] = useState<number>(0)
|
35 |
const [status, setStatus] = useState<WorkerStatus>('initiate')
|
36 |
-
const [modelInfo, setModelInfo] = useState<ModelInfo>(
|
37 |
const [models, setModels] = useState<ModelInfoResponse[]>(
|
38 |
[] as ModelInfoResponse[]
|
39 |
)
|
@@ -45,7 +45,7 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
|
|
45 |
// set progress to 0 when model is changed
|
46 |
useEffect(() => {
|
47 |
setProgress(0)
|
48 |
-
}, [modelInfo
|
49 |
|
50 |
return (
|
51 |
<ModelContext.Provider
|
|
|
16 |
setStatus: (status: WorkerStatus) => void
|
17 |
progress: number
|
18 |
setProgress: (progress: number) => void
|
19 |
+
modelInfo: ModelInfo | null
|
20 |
+
setModelInfo: (model: ModelInfo | null) => void
|
21 |
pipeline: string
|
22 |
setPipeline: (pipeline: string) => void
|
23 |
models: ModelInfoResponse[]
|
|
|
33 |
export function ModelProvider({ children }: { children: React.ReactNode }) {
|
34 |
const [progress, setProgress] = useState<number>(0)
|
35 |
const [status, setStatus] = useState<WorkerStatus>('initiate')
|
36 |
+
const [modelInfo, setModelInfo] = useState<ModelInfo | null>(null)
|
37 |
const [models, setModels] = useState<ModelInfoResponse[]>(
|
38 |
[] as ModelInfoResponse[]
|
39 |
)
|
|
|
45 |
// set progress to 0 when model is changed
|
46 |
useEffect(() => {
|
47 |
setProgress(0)
|
48 |
+
}, [modelInfo?.name])
|
49 |
|
50 |
return (
|
51 |
<ModelContext.Provider
|