Vokturz commited on
Commit
673d22a
·
1 Parent(s): 97cab0c

Refactor modelInfo handling across components for improved null safety

Browse files
src/App.tsx CHANGED
@@ -3,7 +3,6 @@ import PipelineSelector from './components/PipelineSelector'
3
  import ZeroShotClassification from './components/ZeroShotClassification'
4
  import TextClassification from './components/TextClassification'
5
  import Header from './Header'
6
- import Footer from './Footer'
7
  import { useModel } from './contexts/ModelContext'
8
  import { getModelsByPipeline } from './lib/huggingface'
9
  import ModelSelector from './components/ModelSelector'
@@ -11,9 +10,10 @@ import ModelInfo from './components/ModelInfo'
11
  import ModelReadme from './components/ModelReadme'
12
 
13
  function App() {
14
- const { pipeline, setPipeline, setModels, modelInfo } = useModel()
15
 
16
  useEffect(() => {
 
17
  const fetchModels = async () => {
18
  const fetchedModels = await getModelsByPipeline(pipeline)
19
  setModels(fetchedModels)
 
3
  import ZeroShotClassification from './components/ZeroShotClassification'
4
  import TextClassification from './components/TextClassification'
5
  import Header from './Header'
 
6
  import { useModel } from './contexts/ModelContext'
7
  import { getModelsByPipeline } from './lib/huggingface'
8
  import ModelSelector from './components/ModelSelector'
 
10
  import ModelReadme from './components/ModelReadme'
11
 
12
  function App() {
13
+ const { pipeline, setPipeline, setModels, setModelInfo, modelInfo } = useModel()
14
 
15
  useEffect(() => {
16
+ setModelInfo(null)
17
  const fetchModels = async () => {
18
  const fetchedModels = await getModelsByPipeline(pipeline)
19
  setModels(fetchedModels)
src/components/ModelInfo.tsx CHANGED
@@ -64,7 +64,7 @@ const ModelInfo = () => {
64
  </div>
65
  )
66
 
67
- if (!modelInfo.name) {
68
  return <ModelInfoSkeleton />
69
  }
70
 
 
64
  </div>
65
  )
66
 
67
+ if (!modelInfo) {
68
  return <ModelInfoSkeleton />
69
  }
70
 
src/components/ModelLoader.tsx CHANGED
@@ -19,6 +19,8 @@ const ModelLoader = () => {
19
  } = useModel()
20
 
21
  useEffect(() => {
 
 
22
  if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
23
  const quantizations = modelInfo.supportedQuantizations
24
  let defaultQuant: QuantizationType = 'fp32'
@@ -34,12 +36,13 @@ const ModelLoader = () => {
34
  setSelectedQuantization(defaultQuant)
35
  }
36
  }, [
37
- modelInfo.supportedQuantizations,
38
- modelInfo.isCompatible,
39
  setSelectedQuantization
40
  ])
41
 
42
  useEffect(() => {
 
 
43
  const newWorker = getWorker(pipeline)
44
  if (!newWorker) {
45
  return
@@ -70,10 +73,10 @@ const ModelLoader = () => {
70
  newWorker.removeEventListener('message', onMessageReceived)
71
  // terminateWorker(pipeline);
72
  }
73
- }, [pipeline, modelInfo.name, selectedQuantization, setActiveWorker, setStatus, setProgress])
74
 
75
  const loadModel = useCallback(() => {
76
- if (!modelInfo.name || !selectedQuantization) return
77
 
78
  setStatus('loading')
79
  const message = {
@@ -82,12 +85,12 @@ const ModelLoader = () => {
82
  quantization: selectedQuantization
83
  }
84
  activeWorker?.postMessage(message)
85
- }, [modelInfo.name, selectedQuantization, setStatus, activeWorker])
86
 
87
  const ready: boolean = status === 'ready'
88
  const busy: boolean = status === 'loading'
89
 
90
- if (!modelInfo.isCompatible || modelInfo.supportedQuantizations.length === 0) {
91
  return null
92
  }
93
 
 
19
  } = useModel()
20
 
21
  useEffect(() => {
22
+ if (!modelInfo) return
23
+
24
  if (modelInfo.isCompatible && modelInfo.supportedQuantizations.length > 0) {
25
  const quantizations = modelInfo.supportedQuantizations
26
  let defaultQuant: QuantizationType = 'fp32'
 
36
  setSelectedQuantization(defaultQuant)
37
  }
38
  }, [
39
+ modelInfo,
 
40
  setSelectedQuantization
41
  ])
42
 
43
  useEffect(() => {
44
+ if (!modelInfo) return
45
+
46
  const newWorker = getWorker(pipeline)
47
  if (!newWorker) {
48
  return
 
73
  newWorker.removeEventListener('message', onMessageReceived)
74
  // terminateWorker(pipeline);
75
  }
76
+ }, [pipeline, modelInfo, selectedQuantization, setActiveWorker, setStatus, setProgress])
77
 
78
  const loadModel = useCallback(() => {
79
+ if (!modelInfo || !selectedQuantization) return
80
 
81
  setStatus('loading')
82
  const message = {
 
85
  quantization: selectedQuantization
86
  }
87
  activeWorker?.postMessage(message)
88
+ }, [modelInfo, selectedQuantization, setStatus, activeWorker])
89
 
90
  const ready: boolean = status === 'ready'
91
  const busy: boolean = status === 'loading'
92
 
93
+ if (!modelInfo?.isCompatible || modelInfo.supportedQuantizations.length === 0) {
94
  return null
95
  }
96
 
src/components/ModelSelector.tsx CHANGED
@@ -119,7 +119,7 @@ const ModelSelector: React.FC = () => {
119
  }
120
 
121
  const selectedModel =
122
- models.find((model) => model.id === modelInfo.id) || models[0]
123
 
124
  return (
125
  <div className="relative">
@@ -132,7 +132,7 @@ const ModelSelector: React.FC = () => {
132
  <div className="flex items-center justify-between w-full">
133
  <div className="flex flex-col flex-1 min-w-0">
134
  <span className="truncate font-medium">
135
- {modelInfo.id || 'Select a model'}
136
  </span>
137
  </div>
138
 
 
119
  }
120
 
121
  const selectedModel =
122
+ models.find((model) => model.id === modelInfo?.id) || models[0]
123
 
124
  return (
125
  <div className="relative">
 
132
  <div className="flex items-center justify-between w-full">
133
  <div className="flex flex-col flex-1 min-w-0">
134
  <span className="truncate font-medium">
135
+ {modelInfo?.id || 'Select a model'}
136
  </span>
137
  </div>
138
 
src/components/TextClassification.tsx CHANGED
@@ -58,6 +58,7 @@ function TextClassification() {
58
  }, [setStatus])
59
 
60
  const classify = useCallback(() => {
 
61
  setStatus('loading')
62
  setResults([]) // Clear previous results
63
  const message: TextClassificationWorkerInput = {
@@ -66,7 +67,7 @@ function TextClassification() {
66
  model: modelInfo.id
67
  }
68
  workerRef.current?.postMessage(message)
69
- }, [text, modelInfo.id, setStatus])
70
 
71
  const busy: boolean = status !== 'ready'
72
 
 
58
  }, [setStatus])
59
 
60
  const classify = useCallback(() => {
61
+ if (!modelInfo) return
62
  setStatus('loading')
63
  setResults([]) // Clear previous results
64
  const message: TextClassificationWorkerInput = {
 
67
  model: modelInfo.id
68
  }
69
  workerRef.current?.postMessage(message)
70
+ }, [text, modelInfo, setStatus])
71
 
72
  const busy: boolean = status !== 'ready'
73
 
src/components/ZeroShotClassification.tsx CHANGED
@@ -104,6 +104,8 @@ function ZeroShotClassification() {
104
  }, [sections])
105
 
106
  const classify = useCallback(() => {
 
 
107
  setStatus('loading')
108
  const message: ZeroShotWorkerInput = {
109
  text,
@@ -113,7 +115,7 @@ function ZeroShotClassification() {
113
  model: modelInfo.name
114
  }
115
  worker.current?.postMessage(message)
116
- }, [text, sections, modelInfo.name])
117
 
118
  const busy: boolean = status !== 'ready'
119
 
 
104
  }, [sections])
105
 
106
  const classify = useCallback(() => {
107
+ if (!modelInfo) return
108
+
109
  setStatus('loading')
110
  const message: ZeroShotWorkerInput = {
111
  text,
 
115
  model: modelInfo.name
116
  }
117
  worker.current?.postMessage(message)
118
+ }, [text, sections, modelInfo])
119
 
120
  const busy: boolean = status !== 'ready'
121
 
src/contexts/ModelContext.tsx CHANGED
@@ -16,8 +16,8 @@ interface ModelContextType {
16
  setStatus: (status: WorkerStatus) => void
17
  progress: number
18
  setProgress: (progress: number) => void
19
- modelInfo: ModelInfo
20
- setModelInfo: (model: ModelInfo) => void
21
  pipeline: string
22
  setPipeline: (pipeline: string) => void
23
  models: ModelInfoResponse[]
@@ -33,7 +33,7 @@ const ModelContext = createContext<ModelContextType | undefined>(undefined)
33
  export function ModelProvider({ children }: { children: React.ReactNode }) {
34
  const [progress, setProgress] = useState<number>(0)
35
  const [status, setStatus] = useState<WorkerStatus>('initiate')
36
- const [modelInfo, setModelInfo] = useState<ModelInfo>({} as ModelInfo)
37
  const [models, setModels] = useState<ModelInfoResponse[]>(
38
  [] as ModelInfoResponse[]
39
  )
@@ -45,7 +45,7 @@ export function ModelProvider({ children }: { children: React.ReactNode }) {
45
  // set progress to 0 when model is changed
46
  useEffect(() => {
47
  setProgress(0)
48
- }, [modelInfo.name])
49
 
50
  return (
51
  <ModelContext.Provider
 
16
  setStatus: (status: WorkerStatus) => void
17
  progress: number
18
  setProgress: (progress: number) => void
19
+ modelInfo: ModelInfo | null
20
+ setModelInfo: (model: ModelInfo | null) => void
21
  pipeline: string
22
  setPipeline: (pipeline: string) => void
23
  models: ModelInfoResponse[]
 
33
  export function ModelProvider({ children }: { children: React.ReactNode }) {
34
  const [progress, setProgress] = useState<number>(0)
35
  const [status, setStatus] = useState<WorkerStatus>('initiate')
36
+ const [modelInfo, setModelInfo] = useState<ModelInfo | null>(null)
37
  const [models, setModels] = useState<ModelInfoResponse[]>(
38
  [] as ModelInfoResponse[]
39
  )
 
45
  // set progress to 0 when model is changed
46
  useEffect(() => {
47
  setProgress(0)
48
+ }, [modelInfo?.name])
49
 
50
  return (
51
  <ModelContext.Provider