Vokturz commited on
Commit
fb852fe
·
1 Parent(s): 9283c8b

Add ModelContext for managing model loading state and progress

Browse files
src/App.tsx CHANGED
@@ -4,9 +4,11 @@ import ZeroShotClassification from './components/ZeroShotClassification';
4
  import TextClassification from './components/TextClassification';
5
  import Header from './Header';
6
  import Footer from './Footer';
 
7
 
8
  function App() {
9
  const [pipeline, setPipeline] = useState('zero-shot-classification');
 
10
 
11
  return (
12
  <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
@@ -21,6 +23,47 @@ function App() {
21
  </h2>
22
  <PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  {/* Pipeline Description */}
25
  <div className="mt-4 p-4 bg-gray-50 rounded-lg">
26
  <div className="flex items-start space-x-3">
@@ -42,6 +85,11 @@ function App() {
42
  {pipeline === 'zero-shot-classification'
43
  ? 'Zero-Shot Classification'
44
  : 'Text-Classification'}
 
 
 
 
 
45
  </h3>
46
  <p className="text-sm text-gray-600 mt-1">
47
  {pipeline === 'zero-shot-classification'
 
4
  import TextClassification from './components/TextClassification';
5
  import Header from './Header';
6
  import Footer from './Footer';
7
+ import { useModel } from './contexts/ModelContext';
8
 
9
  function App() {
10
  const [pipeline, setPipeline] = useState('zero-shot-classification');
11
+ const { progress, status, model } = useModel();
12
 
13
  return (
14
  <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
 
23
  </h2>
24
  <PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
25
 
26
+ {/* Model Loading Progress */}
27
+ {status === 'progress' && (
28
+ <div className="mt-4 p-4 bg-blue-50 rounded-lg">
29
+ <div className="flex items-center space-x-3">
30
+ <div className="flex-shrink-0">
31
+ <svg
32
+ className="animate-spin h-5 w-5 text-blue-500"
33
+ fill="none"
34
+ viewBox="0 0 24 24"
35
+ >
36
+ <circle
37
+ className="opacity-25"
38
+ cx="12"
39
+ cy="12"
40
+ r="10"
41
+ stroke="currentColor"
42
+ strokeWidth="4"
43
+ ></circle>
44
+ <path
45
+ className="opacity-75"
46
+ fill="currentColor"
47
+ d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
48
+ ></path>
49
+ </svg>
50
+ </div>
51
+ <div className="flex-1">
52
+ <p className="text-sm font-medium text-blue-900">
53
+ Loading Model...
54
+ </p>
55
+ <div className="mt-2 bg-blue-200 rounded-full h-2">
56
+ <div
57
+ className="bg-blue-500 h-2 rounded-full transition-all duration-300"
58
+ style={{ width: `${progress.toFixed(2)}%` }}
59
+ ></div>
60
+ </div>
61
+ <p className="text-xs text-blue-700 mt-1">{progress.toFixed(2)}%</p>
62
+ </div>
63
+ </div>
64
+ </div>
65
+ )}
66
+
67
  {/* Pipeline Description */}
68
  <div className="mt-4 p-4 bg-gray-50 rounded-lg">
69
  <div className="flex items-start space-x-3">
 
85
  {pipeline === 'zero-shot-classification'
86
  ? 'Zero-Shot Classification'
87
  : 'Text-Classification'}
88
+ {model && (
89
+ <span className="ml-2 text-xs text-gray-500 font-normal">
90
+ ({model})
91
+ </span>
92
+ )}
93
  </h3>
94
  <p className="text-sm text-gray-600 mt-1">
95
  {pipeline === 'zero-shot-classification'
src/components/TextClassification.tsx CHANGED
@@ -4,6 +4,8 @@ import {
4
  TextClassificationWorkerInput,
5
  WorkerMessage
6
  } from '../types';
 
 
7
 
8
  const PLACEHOLDER_TEXTS: string[] = [
9
  'I absolutely love this product! It exceeded all my expectations.',
@@ -21,8 +23,8 @@ const PLACEHOLDER_TEXTS: string[] = [
21
  function TextClassification() {
22
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'));
23
  const [results, setResults] = useState<ClassificationOutput[]>([]);
24
- const [status, setStatus] = useState<string>('idle');
25
- const [progress, setProgress] = useState<number>(0);
26
 
27
  // Create a reference to the worker object.
28
  const worker = useRef<Worker | null>(null);
@@ -54,6 +56,7 @@ function TextClassification() {
54
  )
55
  setProgress(e.data.output.progress);
56
  } else if (status === 'output') {
 
57
  const result = e.data.output!;
58
  setResults((prevResults) => [...prevResults, result]);
59
  console.log(result);
@@ -111,9 +114,6 @@ function TextClassification() {
111
  ? 'Model loading...'
112
  : 'Processing...'}
113
  </button>
114
- {status === 'progress' && (
115
- <div className="text-sm font-medium">{progress}%</div>
116
- )}
117
  <button
118
  className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
119
  onClick={handleClear}
 
4
  TextClassificationWorkerInput,
5
  WorkerMessage
6
  } from '../types';
7
+ import { useModel } from '../contexts/ModelContext';
8
+
9
 
10
  const PLACEHOLDER_TEXTS: string[] = [
11
  'I absolutely love this product! It exceeded all my expectations.',
 
23
  function TextClassification() {
24
  const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'));
25
  const [results, setResults] = useState<ClassificationOutput[]>([]);
26
+ const { setProgress, status, setStatus, setModel } = useModel();
27
+ setModel('Xenova/bert-base-multilingual-uncased-sentiment')
28
 
29
  // Create a reference to the worker object.
30
  const worker = useRef<Worker | null>(null);
 
56
  )
57
  setProgress(e.data.output.progress);
58
  } else if (status === 'output') {
59
+ setStatus('output');
60
  const result = e.data.output!;
61
  setResults((prevResults) => [...prevResults, result]);
62
  console.log(result);
 
114
  ? 'Model loading...'
115
  : 'Processing...'}
116
  </button>
 
 
 
117
  <button
118
  className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
119
  onClick={handleClear}
src/components/ZeroShotClassification.tsx CHANGED
@@ -1,6 +1,7 @@
1
  // src/App.tsx
2
  import { useState, useRef, useEffect, useCallback } from 'react';
3
  import { Section, WorkerMessage, ZeroShotWorkerInput } from '../types';
 
4
 
5
  const PLACEHOLDER_REVIEWS: string[] = [
6
  // battery/charging problems
@@ -44,8 +45,8 @@ function ZeroShotClassification() {
44
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
45
  );
46
 
47
- const [status, setStatus] = useState<string>('idle');
48
- const [progress, setProgress] = useState<number>(0);
49
 
50
  // Create a reference to the worker object.
51
  const worker = useRef<Worker | null>(null);
@@ -77,6 +78,7 @@ function ZeroShotClassification() {
77
  )
78
  setProgress(e.data.output.progress);
79
  } else if (status === 'output') {
 
80
  const { sequence, labels, scores } = e.data.output!;
81
 
82
  // Threshold for classification
@@ -175,9 +177,6 @@ function ZeroShotClassification() {
175
  ? 'Model loading...'
176
  : 'Processing'}
177
  </button>
178
- {status === 'progress' && (
179
- <div className="text-sm font-medium">{progress}%</div>
180
- )}
181
  <div className="flex gap-1">
182
  <button
183
  className="border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium cursor-pointer"
 
1
  // src/App.tsx
2
  import { useState, useRef, useEffect, useCallback } from 'react';
3
  import { Section, WorkerMessage, ZeroShotWorkerInput } from '../types';
4
+ import { useModel } from '../contexts/ModelContext';
5
 
6
  const PLACEHOLDER_REVIEWS: string[] = [
7
  // battery/charging problems
 
45
  PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
46
  );
47
 
48
+ const { setProgress, status, setStatus, setModel } = useModel();
49
+ setModel('MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33')
50
 
51
  // Create a reference to the worker object.
52
  const worker = useRef<Worker | null>(null);
 
78
  )
79
  setProgress(e.data.output.progress);
80
  } else if (status === 'output') {
81
+ setStatus('output');
82
  const { sequence, labels, scores } = e.data.output!;
83
 
84
  // Threshold for classification
 
177
  ? 'Model loading...'
178
  : 'Processing'}
179
  </button>
 
 
 
180
  <div className="flex gap-1">
181
  <button
182
  className="border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium cursor-pointer"
src/contexts/ModelContext.tsx ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { createContext, useContext, useEffect, useState } from 'react';
2
+
3
+ interface ModelContextType {
4
+ progress: number;
5
+ status: string;
6
+ setProgress: (progress: number) => void;
7
+ setStatus: (status: string) => void;
8
+ model: string;
9
+ setModel: (model: string) => void;
10
+ }
11
+
12
+ const ModelContext = createContext<ModelContextType | undefined>(undefined);
13
+
14
+ export function ModelProvider({ children }: { children: React.ReactNode }) {
15
+ const [progress, setProgress] = useState<number>(0);
16
+ const [status, setStatus] = useState<string>('idle');
17
+ const [model, setModel] = useState<string>('');
18
+
19
+ // set progress to 0 when model is changed
20
+ useEffect(() => {
21
+ setProgress(0);
22
+ }, [model]);
23
+
24
+ return (
25
+ <ModelContext.Provider
26
+ value={{ progress, setProgress, status, setStatus, model, setModel }}
27
+ >
28
+ {children}
29
+ </ModelContext.Provider>
30
+ );
31
+ }
32
+
33
+ export function useModel() {
34
+ const context = useContext(ModelContext);
35
+ if (context === undefined) {
36
+ throw new Error('useModel must be used within a ModelProvider');
37
+ }
38
+ return context;
39
+ }
src/index.tsx CHANGED
@@ -3,14 +3,18 @@ import ReactDOM from 'react-dom/client';
3
  import './index.css';
4
  import App from './App';
5
  import reportWebVitals from './reportWebVitals';
 
 
6
 
7
  const root = ReactDOM.createRoot(
8
  document.getElementById('root') as HTMLElement
9
  );
10
  root.render(
 
11
  <React.StrictMode>
12
  <App />
13
  </React.StrictMode>
 
14
  );
15
 
16
  // If you want to start measuring performance in your app, pass a function
 
3
  import './index.css';
4
  import App from './App';
5
  import reportWebVitals from './reportWebVitals';
6
+ import { ModelProvider } from './contexts/ModelContext';
7
+
8
 
9
  const root = ReactDOM.createRoot(
10
  document.getElementById('root') as HTMLElement
11
  );
12
  root.render(
13
+ <ModelProvider>
14
  <React.StrictMode>
15
  <App />
16
  </React.StrictMode>
17
+ </ModelProvider>
18
  );
19
 
20
  // If you want to start measuring performance in your app, pass a function