Add ModelContext for managing model loading state and progress
Browse files- src/App.tsx +48 -0
- src/components/TextClassification.tsx +5 -5
- src/components/ZeroShotClassification.tsx +4 -5
- src/contexts/ModelContext.tsx +39 -0
- src/index.tsx +4 -0
src/App.tsx
CHANGED
@@ -4,9 +4,11 @@ import ZeroShotClassification from './components/ZeroShotClassification';
|
|
4 |
import TextClassification from './components/TextClassification';
|
5 |
import Header from './Header';
|
6 |
import Footer from './Footer';
|
|
|
7 |
|
8 |
function App() {
|
9 |
const [pipeline, setPipeline] = useState('zero-shot-classification');
|
|
|
10 |
|
11 |
return (
|
12 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
@@ -21,6 +23,47 @@ function App() {
|
|
21 |
</h2>
|
22 |
<PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
{/* Pipeline Description */}
|
25 |
<div className="mt-4 p-4 bg-gray-50 rounded-lg">
|
26 |
<div className="flex items-start space-x-3">
|
@@ -42,6 +85,11 @@ function App() {
|
|
42 |
{pipeline === 'zero-shot-classification'
|
43 |
? 'Zero-Shot Classification'
|
44 |
: 'Text-Classification'}
|
|
|
|
|
|
|
|
|
|
|
45 |
</h3>
|
46 |
<p className="text-sm text-gray-600 mt-1">
|
47 |
{pipeline === 'zero-shot-classification'
|
|
|
4 |
import TextClassification from './components/TextClassification';
|
5 |
import Header from './Header';
|
6 |
import Footer from './Footer';
|
7 |
+
import { useModel } from './contexts/ModelContext';
|
8 |
|
9 |
function App() {
|
10 |
const [pipeline, setPipeline] = useState('zero-shot-classification');
|
11 |
+
const { progress, status, model } = useModel();
|
12 |
|
13 |
return (
|
14 |
<div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
|
|
|
23 |
</h2>
|
24 |
<PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
|
25 |
|
26 |
+
{/* Model Loading Progress */}
|
27 |
+
{status === 'progress' && (
|
28 |
+
<div className="mt-4 p-4 bg-blue-50 rounded-lg">
|
29 |
+
<div className="flex items-center space-x-3">
|
30 |
+
<div className="flex-shrink-0">
|
31 |
+
<svg
|
32 |
+
className="animate-spin h-5 w-5 text-blue-500"
|
33 |
+
fill="none"
|
34 |
+
viewBox="0 0 24 24"
|
35 |
+
>
|
36 |
+
<circle
|
37 |
+
className="opacity-25"
|
38 |
+
cx="12"
|
39 |
+
cy="12"
|
40 |
+
r="10"
|
41 |
+
stroke="currentColor"
|
42 |
+
strokeWidth="4"
|
43 |
+
></circle>
|
44 |
+
<path
|
45 |
+
className="opacity-75"
|
46 |
+
fill="currentColor"
|
47 |
+
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
48 |
+
></path>
|
49 |
+
</svg>
|
50 |
+
</div>
|
51 |
+
<div className="flex-1">
|
52 |
+
<p className="text-sm font-medium text-blue-900">
|
53 |
+
Loading Model...
|
54 |
+
</p>
|
55 |
+
<div className="mt-2 bg-blue-200 rounded-full h-2">
|
56 |
+
<div
|
57 |
+
className="bg-blue-500 h-2 rounded-full transition-all duration-300"
|
58 |
+
style={{ width: `${progress.toFixed(2)}%` }}
|
59 |
+
></div>
|
60 |
+
</div>
|
61 |
+
<p className="text-xs text-blue-700 mt-1">{progress.toFixed(2)}%</p>
|
62 |
+
</div>
|
63 |
+
</div>
|
64 |
+
</div>
|
65 |
+
)}
|
66 |
+
|
67 |
{/* Pipeline Description */}
|
68 |
<div className="mt-4 p-4 bg-gray-50 rounded-lg">
|
69 |
<div className="flex items-start space-x-3">
|
|
|
85 |
{pipeline === 'zero-shot-classification'
|
86 |
? 'Zero-Shot Classification'
|
87 |
: 'Text-Classification'}
|
88 |
+
{model && (
|
89 |
+
<span className="ml-2 text-xs text-gray-500 font-normal">
|
90 |
+
({model})
|
91 |
+
</span>
|
92 |
+
)}
|
93 |
</h3>
|
94 |
<p className="text-sm text-gray-600 mt-1">
|
95 |
{pipeline === 'zero-shot-classification'
|
src/components/TextClassification.tsx
CHANGED
@@ -4,6 +4,8 @@ import {
|
|
4 |
TextClassificationWorkerInput,
|
5 |
WorkerMessage
|
6 |
} from '../types';
|
|
|
|
|
7 |
|
8 |
const PLACEHOLDER_TEXTS: string[] = [
|
9 |
'I absolutely love this product! It exceeded all my expectations.',
|
@@ -21,8 +23,8 @@ const PLACEHOLDER_TEXTS: string[] = [
|
|
21 |
function TextClassification() {
|
22 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'));
|
23 |
const [results, setResults] = useState<ClassificationOutput[]>([]);
|
24 |
-
const
|
25 |
-
|
26 |
|
27 |
// Create a reference to the worker object.
|
28 |
const worker = useRef<Worker | null>(null);
|
@@ -54,6 +56,7 @@ function TextClassification() {
|
|
54 |
)
|
55 |
setProgress(e.data.output.progress);
|
56 |
} else if (status === 'output') {
|
|
|
57 |
const result = e.data.output!;
|
58 |
setResults((prevResults) => [...prevResults, result]);
|
59 |
console.log(result);
|
@@ -111,9 +114,6 @@ function TextClassification() {
|
|
111 |
? 'Model loading...'
|
112 |
: 'Processing...'}
|
113 |
</button>
|
114 |
-
{status === 'progress' && (
|
115 |
-
<div className="text-sm font-medium">{progress}%</div>
|
116 |
-
)}
|
117 |
<button
|
118 |
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
119 |
onClick={handleClear}
|
|
|
4 |
TextClassificationWorkerInput,
|
5 |
WorkerMessage
|
6 |
} from '../types';
|
7 |
+
import { useModel } from '../contexts/ModelContext';
|
8 |
+
|
9 |
|
10 |
const PLACEHOLDER_TEXTS: string[] = [
|
11 |
'I absolutely love this product! It exceeded all my expectations.',
|
|
|
23 |
function TextClassification() {
|
24 |
const [text, setText] = useState<string>(PLACEHOLDER_TEXTS.join('\n'));
|
25 |
const [results, setResults] = useState<ClassificationOutput[]>([]);
|
26 |
+
const { setProgress, status, setStatus, setModel } = useModel();
|
27 |
+
setModel('Xenova/bert-base-multilingual-uncased-sentiment')
|
28 |
|
29 |
// Create a reference to the worker object.
|
30 |
const worker = useRef<Worker | null>(null);
|
|
|
56 |
)
|
57 |
setProgress(e.data.output.progress);
|
58 |
} else if (status === 'output') {
|
59 |
+
setStatus('output');
|
60 |
const result = e.data.output!;
|
61 |
setResults((prevResults) => [...prevResults, result]);
|
62 |
console.log(result);
|
|
|
114 |
? 'Model loading...'
|
115 |
: 'Processing...'}
|
116 |
</button>
|
|
|
|
|
|
|
117 |
<button
|
118 |
className="py-2 px-4 bg-gray-500 hover:bg-gray-600 rounded text-white font-medium transition-colors"
|
119 |
onClick={handleClear}
|
src/components/ZeroShotClassification.tsx
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
// src/App.tsx
|
2 |
import { useState, useRef, useEffect, useCallback } from 'react';
|
3 |
import { Section, WorkerMessage, ZeroShotWorkerInput } from '../types';
|
|
|
4 |
|
5 |
const PLACEHOLDER_REVIEWS: string[] = [
|
6 |
// battery/charging problems
|
@@ -44,8 +45,8 @@ function ZeroShotClassification() {
|
|
44 |
PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
|
45 |
);
|
46 |
|
47 |
-
const
|
48 |
-
|
49 |
|
50 |
// Create a reference to the worker object.
|
51 |
const worker = useRef<Worker | null>(null);
|
@@ -77,6 +78,7 @@ function ZeroShotClassification() {
|
|
77 |
)
|
78 |
setProgress(e.data.output.progress);
|
79 |
} else if (status === 'output') {
|
|
|
80 |
const { sequence, labels, scores } = e.data.output!;
|
81 |
|
82 |
// Threshold for classification
|
@@ -175,9 +177,6 @@ function ZeroShotClassification() {
|
|
175 |
? 'Model loading...'
|
176 |
: 'Processing'}
|
177 |
</button>
|
178 |
-
{status === 'progress' && (
|
179 |
-
<div className="text-sm font-medium">{progress}%</div>
|
180 |
-
)}
|
181 |
<div className="flex gap-1">
|
182 |
<button
|
183 |
className="border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium cursor-pointer"
|
|
|
1 |
// src/App.tsx
|
2 |
import { useState, useRef, useEffect, useCallback } from 'react';
|
3 |
import { Section, WorkerMessage, ZeroShotWorkerInput } from '../types';
|
4 |
+
import { useModel } from '../contexts/ModelContext';
|
5 |
|
6 |
const PLACEHOLDER_REVIEWS: string[] = [
|
7 |
// battery/charging problems
|
|
|
45 |
PLACEHOLDER_SECTIONS.map((title) => ({ title, items: [] }))
|
46 |
);
|
47 |
|
48 |
+
const { setProgress, status, setStatus, setModel } = useModel();
|
49 |
+
setModel('MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33')
|
50 |
|
51 |
// Create a reference to the worker object.
|
52 |
const worker = useRef<Worker | null>(null);
|
|
|
78 |
)
|
79 |
setProgress(e.data.output.progress);
|
80 |
} else if (status === 'output') {
|
81 |
+
setStatus('output');
|
82 |
const { sequence, labels, scores } = e.data.output!;
|
83 |
|
84 |
// Threshold for classification
|
|
|
177 |
? 'Model loading...'
|
178 |
: 'Processing'}
|
179 |
</button>
|
|
|
|
|
|
|
180 |
<div className="flex gap-1">
|
181 |
<button
|
182 |
className="border py-1 px-2 bg-green-400 rounded text-white text-sm font-medium cursor-pointer"
|
src/contexts/ModelContext.tsx
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import React, { createContext, useContext, useEffect, useState } from 'react';
|
2 |
+
|
3 |
+
interface ModelContextType {
|
4 |
+
progress: number;
|
5 |
+
status: string;
|
6 |
+
setProgress: (progress: number) => void;
|
7 |
+
setStatus: (status: string) => void;
|
8 |
+
model: string;
|
9 |
+
setModel: (model: string) => void;
|
10 |
+
}
|
11 |
+
|
12 |
+
const ModelContext = createContext<ModelContextType | undefined>(undefined);
|
13 |
+
|
14 |
+
export function ModelProvider({ children }: { children: React.ReactNode }) {
|
15 |
+
const [progress, setProgress] = useState<number>(0);
|
16 |
+
const [status, setStatus] = useState<string>('idle');
|
17 |
+
const [model, setModel] = useState<string>('');
|
18 |
+
|
19 |
+
// set progress to 0 when model is changed
|
20 |
+
useEffect(() => {
|
21 |
+
setProgress(0);
|
22 |
+
}, [model]);
|
23 |
+
|
24 |
+
return (
|
25 |
+
<ModelContext.Provider
|
26 |
+
value={{ progress, setProgress, status, setStatus, model, setModel }}
|
27 |
+
>
|
28 |
+
{children}
|
29 |
+
</ModelContext.Provider>
|
30 |
+
);
|
31 |
+
}
|
32 |
+
|
33 |
+
export function useModel() {
|
34 |
+
const context = useContext(ModelContext);
|
35 |
+
if (context === undefined) {
|
36 |
+
throw new Error('useModel must be used within a ModelProvider');
|
37 |
+
}
|
38 |
+
return context;
|
39 |
+
}
|
src/index.tsx
CHANGED
@@ -3,14 +3,18 @@ import ReactDOM from 'react-dom/client';
|
|
3 |
import './index.css';
|
4 |
import App from './App';
|
5 |
import reportWebVitals from './reportWebVitals';
|
|
|
|
|
6 |
|
7 |
const root = ReactDOM.createRoot(
|
8 |
document.getElementById('root') as HTMLElement
|
9 |
);
|
10 |
root.render(
|
|
|
11 |
<React.StrictMode>
|
12 |
<App />
|
13 |
</React.StrictMode>
|
|
|
14 |
);
|
15 |
|
16 |
// If you want to start measuring performance in your app, pass a function
|
|
|
3 |
import './index.css';
|
4 |
import App from './App';
|
5 |
import reportWebVitals from './reportWebVitals';
|
6 |
+
import { ModelProvider } from './contexts/ModelContext';
|
7 |
+
|
8 |
|
9 |
const root = ReactDOM.createRoot(
|
10 |
document.getElementById('root') as HTMLElement
|
11 |
);
|
12 |
root.render(
|
13 |
+
<ModelProvider>
|
14 |
<React.StrictMode>
|
15 |
<App />
|
16 |
</React.StrictMode>
|
17 |
+
</ModelProvider>
|
18 |
);
|
19 |
|
20 |
// If you want to start measuring performance in your app, pass a function
|