ntam0001 commited on
Commit
ed024cc
ยท
verified ยท
1 Parent(s): b0e10f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -123
app.py CHANGED
@@ -11,7 +11,6 @@ from plotly.subplots import make_subplots
11
  import os
12
 
13
  # Load model artifacts
14
- @st.cache_resource
15
  def load_model_artifacts():
16
  try:
17
  # Load the trained model
@@ -30,14 +29,24 @@ def load_model_artifacts():
30
  raise Exception(f"Error loading model artifacts: {str(e)}")
31
 
32
  # Initialize model components
33
- model, scaler, metadata = load_model_artifacts()
34
- feature_names = metadata['feature_names']
 
 
 
 
 
 
 
35
 
36
  def predict_student_eligibility(*args):
37
  """
38
  Predict student eligibility based on input features
39
  """
40
  try:
 
 
 
41
  # Create input dictionary from gradio inputs
42
  input_data = {feature_names[i]: args[i] for i in range(len(feature_names))}
43
 
@@ -51,7 +60,7 @@ def predict_student_eligibility(*args):
51
  input_reshaped = input_scaled.reshape(input_scaled.shape[0], input_scaled.shape[1], 1)
52
 
53
  # Make prediction
54
- probability = model.predict(input_reshaped)[0][0]
55
  prediction = "Eligible" if probability > 0.5 else "Not Eligible"
56
  confidence = abs(probability - 0.5) * 2 # Convert to confidence score
57
 
@@ -61,105 +70,139 @@ def predict_student_eligibility(*args):
61
  return prediction, f"{probability:.4f}", f"{confidence:.4f}", fig
62
 
63
  except Exception as e:
64
- return f"Error: {str(e)}", "N/A", "N/A", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def create_prediction_viz(probability, prediction, input_data):
67
  """
68
  Create visualization for prediction results
69
  """
70
- # Create subplots
71
- fig = make_subplots(
72
- rows=2, cols=2,
73
- subplot_titles=('Prediction Probability', 'Confidence Meter', 'Input Features', 'Feature Distribution'),
74
- specs=[[{"type": "indicator"}, {"type": "indicator"}],
75
- [{"type": "bar"}, {"type": "histogram"}]]
76
- )
77
-
78
- # Prediction probability gauge
79
- fig.add_trace(
80
- go.Indicator(
81
- mode="gauge+number+delta",
82
- value=probability,
83
- domain={'x': [0, 1], 'y': [0, 1]},
84
- title={'text': "Eligibility Probability"},
85
- gauge={
86
- 'axis': {'range': [None, 1]},
87
- 'bar': {'color': "darkblue"},
88
- 'steps': [
89
- {'range': [0, 0.5], 'color': "lightgray"},
90
- {'range': [0.5, 1], 'color': "lightgreen"}
91
- ],
92
- 'threshold': {
93
- 'line': {'color': "red", 'width': 4},
94
- 'thickness': 0.75,
95
- 'value': 0.5
 
 
96
  }
97
- }
98
- ),
99
- row=1, col=1
100
- )
101
-
102
- # Confidence meter
103
- confidence = abs(probability - 0.5) * 2
104
- fig.add_trace(
105
- go.Indicator(
106
- mode="gauge+number",
107
- value=confidence,
108
- domain={'x': [0, 1], 'y': [0, 1]},
109
- title={'text': "Prediction Confidence"},
110
- gauge={
111
- 'axis': {'range': [None, 1]},
112
- 'bar': {'color': "orange"},
113
- 'steps': [
114
- {'range': [0, 0.3], 'color': "lightcoral"},
115
- {'range': [0.3, 0.7], 'color': "lightyellow"},
116
- {'range': [0.7, 1], 'color': "lightgreen"}
117
- ]
118
- }
119
- ),
120
- row=1, col=2
121
- )
122
-
123
- # Input features bar chart
124
- features = list(input_data.keys())
125
- values = list(input_data.values())
126
-
127
- fig.add_trace(
128
- go.Bar(x=features, y=values, name="Input Values", marker_color="skyblue"),
129
- row=2, col=1
130
- )
131
-
132
- # Feature distribution (example data)
133
- fig.add_trace(
134
- go.Histogram(x=values, nbinsx=10, name="Distribution", marker_color="lightcoral"),
135
- row=2, col=2
136
- )
137
-
138
- fig.update_layout(
139
- height=800,
140
- showlegend=False,
141
- title_text="Student Eligibility Prediction Dashboard",
142
- title_x=0.5
143
- )
144
-
145
- return fig
 
 
 
 
 
 
 
 
146
 
147
  def create_model_info():
148
  """
149
  Create model information display
150
  """
151
- info_html = f"""
152
- <div style="padding: 20px; background-color: #f0f2f6; border-radius: 10px; margin: 10px 0;">
153
- <h3>๐Ÿค– Model Information</h3>
154
- <ul>
155
- <li><strong>Model Type:</strong> {metadata.get('model_type', 'CNN')}</li>
156
- <li><strong>Test Accuracy:</strong> {metadata['performance_metrics']['test_accuracy']:.4f}</li>
157
- <li><strong>AUC Score:</strong> {metadata['performance_metrics']['auc_score']:.4f}</li>
158
- <li><strong>Creation Date:</strong> {metadata.get('creation_date', 'N/A')}</li>
159
- <li><strong>Features:</strong> {len(feature_names)} input features</li>
160
- </ul>
161
- </div>
162
- """
 
 
 
 
 
 
 
 
163
  return info_html
164
 
165
  def batch_predict(file):
@@ -167,6 +210,12 @@ def batch_predict(file):
167
  Batch prediction from uploaded CSV file
168
  """
169
  try:
 
 
 
 
 
 
170
  # Read the uploaded file
171
  df = pd.read_csv(file.name)
172
 
@@ -199,13 +248,19 @@ def batch_predict(file):
199
  results_df.to_csv(output_file, index=False)
200
 
201
  # Create summary statistics
202
- summary = f"""
203
- Batch Prediction Summary:
204
- - Total predictions: {len(results_df)}
205
- - Eligible: {sum(1 for p in predictions if p == 'Eligible')}
206
- - Not Eligible: {sum(1 for p in predictions if p == 'Not Eligible')}
207
- - Average Probability: {np.mean(probabilities):.4f}
208
- - Average Confidence: {np.mean(np.abs(probabilities - 0.5) * 2):.4f}
 
 
 
 
 
 
209
  """
210
 
211
  return summary, output_file
@@ -229,6 +284,9 @@ with gr.Blocks(
229
  border-radius: 10px;
230
  margin-bottom: 20px;
231
  }
 
 
 
232
  """
233
  ) as demo:
234
 
@@ -242,66 +300,120 @@ with gr.Blocks(
242
 
243
  with gr.Tabs():
244
  # Single Prediction Tab
245
- with gr.TabItem("Single Prediction"):
246
  gr.Markdown("### Enter student information to predict eligibility")
247
 
248
  with gr.Row():
249
  with gr.Column(scale=1):
 
250
  # Create input components dynamically based on features
251
  inputs = []
252
- for feature in feature_names:
253
  inputs.append(
254
  gr.Number(
255
- label=f"{feature}",
256
- value=85, # Default value
257
  minimum=0,
258
  maximum=100,
259
- step=1
 
260
  )
261
  )
262
 
263
- predict_btn = gr.Button("๐Ÿ”ฎ Predict Eligibility", variant="primary", size="lg")
 
 
 
 
 
264
 
265
  with gr.Column(scale=2):
 
266
  with gr.Row():
267
- prediction_output = gr.Textbox(label="Prediction", scale=1)
268
- probability_output = gr.Textbox(label="Probability", scale=1)
269
- confidence_output = gr.Textbox(label="Confidence", scale=1)
270
 
271
- prediction_plot = gr.Plot(label="Prediction Visualization")
272
 
273
  # Model information
274
  gr.HTML(create_model_info())
275
 
276
  # Batch Prediction Tab
277
- with gr.TabItem("Batch Prediction"):
278
  gr.Markdown("### Upload a CSV file for batch predictions")
279
- gr.Markdown(f"**Required columns:** {', '.join(feature_names)}")
 
 
 
 
 
 
 
 
 
 
280
 
281
  with gr.Row():
282
  with gr.Column():
283
  file_input = gr.File(
284
- label="Upload CSV File",
285
  file_types=[".csv"],
286
  type="file"
287
  )
288
- batch_predict_btn = gr.Button("๐Ÿ“Š Process Batch", variant="primary")
 
 
 
 
289
 
290
  with gr.Column():
291
- batch_output = gr.Textbox(label="Batch Results Summary", lines=10)
292
- download_file = gr.File(label="Download Results")
 
 
 
 
293
 
294
  # Model Analytics Tab
295
- with gr.TabItem("Model Analytics"):
296
  gr.Markdown("### Model Performance Metrics")
297
 
298
- # Performance metrics
299
- metrics_df = pd.DataFrame([metadata['performance_metrics']])
300
- gr.Dataframe(metrics_df, label="Performance Metrics")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
- # Feature importance (placeholder - you'd need to calculate this)
303
- gr.Markdown("### Feature Names")
304
- gr.Textbox(value=", ".join(feature_names), label="Model Features", lines=3)
 
 
 
 
 
 
 
305
 
306
  # Event handlers
307
  predict_btn.click(
@@ -318,4 +430,8 @@ with gr.Blocks(
318
 
319
  # Launch the app
320
  if __name__ == "__main__":
321
- demo.launch(share=True)
 
 
 
 
 
11
  import os
12
 
13
  # Load model artifacts
 
14
  def load_model_artifacts():
15
  try:
16
  # Load the trained model
 
29
  raise Exception(f"Error loading model artifacts: {str(e)}")
30
 
31
  # Initialize model components
32
+ try:
33
+ model, scaler, metadata = load_model_artifacts()
34
+ feature_names = metadata['feature_names']
35
+ print(f"โœ… Model loaded successfully with features: {feature_names}")
36
+ except Exception as e:
37
+ print(f"โŒ Error loading model: {e}")
38
+ # Fallback values for testing
39
+ model, scaler, metadata = None, None, {}
40
+ feature_names = ['Feature_1', 'Feature_2', 'Feature_3', 'Feature_4']
41
 
42
  def predict_student_eligibility(*args):
43
  """
44
  Predict student eligibility based on input features
45
  """
46
  try:
47
+ if model is None or scaler is None:
48
+ return "Model not loaded", "N/A", "N/A", create_error_plot()
49
+
50
  # Create input dictionary from gradio inputs
51
  input_data = {feature_names[i]: args[i] for i in range(len(feature_names))}
52
 
 
60
  input_reshaped = input_scaled.reshape(input_scaled.shape[0], input_scaled.shape[1], 1)
61
 
62
  # Make prediction
63
+ probability = float(model.predict(input_reshaped)[0][0])
64
  prediction = "Eligible" if probability > 0.5 else "Not Eligible"
65
  confidence = abs(probability - 0.5) * 2 # Convert to confidence score
66
 
 
70
  return prediction, f"{probability:.4f}", f"{confidence:.4f}", fig
71
 
72
  except Exception as e:
73
+ return f"Error: {str(e)}", "N/A", "N/A", create_error_plot()
74
+
75
+ def create_error_plot():
76
+ """Create a simple error plot"""
77
+ fig = go.Figure()
78
+ fig.add_annotation(
79
+ text="Model not available or error occurred",
80
+ xref="paper", yref="paper",
81
+ x=0.5, y=0.5, xanchor='center', yanchor='middle',
82
+ showarrow=False, font=dict(size=20)
83
+ )
84
+ fig.update_layout(
85
+ xaxis={'visible': False},
86
+ yaxis={'visible': False},
87
+ height=400
88
+ )
89
+ return fig
90
 
91
  def create_prediction_viz(probability, prediction, input_data):
92
  """
93
  Create visualization for prediction results
94
  """
95
+ try:
96
+ # Create subplots
97
+ fig = make_subplots(
98
+ rows=2, cols=2,
99
+ subplot_titles=('Prediction Probability', 'Confidence Meter', 'Input Features', 'Probability Distribution'),
100
+ specs=[[{"type": "indicator"}, {"type": "indicator"}],
101
+ [{"type": "bar"}, {"type": "scatter"}]]
102
+ )
103
+
104
+ # Prediction probability gauge
105
+ fig.add_trace(
106
+ go.Indicator(
107
+ mode="gauge+number",
108
+ value=probability,
109
+ domain={'x': [0, 1], 'y': [0, 1]},
110
+ title={'text': "Eligibility Probability"},
111
+ gauge={
112
+ 'axis': {'range': [None, 1]},
113
+ 'bar': {'color': "darkblue"},
114
+ 'steps': [
115
+ {'range': [0, 0.5], 'color': "lightcoral"},
116
+ {'range': [0.5, 1], 'color': "lightgreen"}
117
+ ],
118
+ 'threshold': {
119
+ 'line': {'color': "red", 'width': 4},
120
+ 'thickness': 0.75,
121
+ 'value': 0.5
122
+ }
123
  }
124
+ ),
125
+ row=1, col=1
126
+ )
127
+
128
+ # Confidence meter
129
+ confidence = abs(probability - 0.5) * 2
130
+ fig.add_trace(
131
+ go.Indicator(
132
+ mode="gauge+number",
133
+ value=confidence,
134
+ domain={'x': [0, 1], 'y': [0, 1]},
135
+ title={'text': "Prediction Confidence"},
136
+ gauge={
137
+ 'axis': {'range': [None, 1]},
138
+ 'bar': {'color': "orange"},
139
+ 'steps': [
140
+ {'range': [0, 0.3], 'color': "lightcoral"},
141
+ {'range': [0.3, 0.7], 'color': "lightyellow"},
142
+ {'range': [0.7, 1], 'color': "lightgreen"}
143
+ ]
144
+ }
145
+ ),
146
+ row=1, col=2
147
+ )
148
+
149
+ # Input features bar chart
150
+ features = list(input_data.keys())
151
+ values = list(input_data.values())
152
+
153
+ fig.add_trace(
154
+ go.Bar(x=features, y=values, name="Input Values", marker_color="skyblue"),
155
+ row=2, col=1
156
+ )
157
+
158
+ # Simple probability visualization
159
+ fig.add_trace(
160
+ go.Scatter(
161
+ x=[0, 1],
162
+ y=[probability, probability],
163
+ mode='lines+markers',
164
+ name="Probability",
165
+ line=dict(color="red", width=3),
166
+ marker=dict(size=10)
167
+ ),
168
+ row=2, col=2
169
+ )
170
+
171
+ fig.update_layout(
172
+ height=800,
173
+ showlegend=False,
174
+ title_text="Student Eligibility Prediction Dashboard",
175
+ title_x=0.5
176
+ )
177
+
178
+ return fig
179
+ except Exception as e:
180
+ return create_error_plot()
181
 
182
  def create_model_info():
183
  """
184
  Create model information display
185
  """
186
+ if metadata:
187
+ info_html = f"""
188
+ <div style="padding: 20px; background-color: #f0f2f6; border-radius: 10px; margin: 10px 0;">
189
+ <h3>๐Ÿค– Model Information</h3>
190
+ <ul>
191
+ <li><strong>Model Type:</strong> {metadata.get('model_type', 'CNN')}</li>
192
+ <li><strong>Test Accuracy:</strong> {metadata.get('performance_metrics', {}).get('test_accuracy', 'N/A')}</li>
193
+ <li><strong>AUC Score:</strong> {metadata.get('performance_metrics', {}).get('auc_score', 'N/A')}</li>
194
+ <li><strong>Creation Date:</strong> {metadata.get('creation_date', 'N/A')}</li>
195
+ <li><strong>Features:</strong> {len(feature_names)} input features</li>
196
+ </ul>
197
+ </div>
198
+ """
199
+ else:
200
+ info_html = """
201
+ <div style="padding: 20px; background-color: #ffebee; border-radius: 10px; margin: 10px 0;">
202
+ <h3>โš ๏ธ Model Information</h3>
203
+ <p>Model artifacts not loaded. Please ensure all required files are uploaded.</p>
204
+ </div>
205
+ """
206
  return info_html
207
 
208
  def batch_predict(file):
 
210
  Batch prediction from uploaded CSV file
211
  """
212
  try:
213
+ if model is None or scaler is None:
214
+ return "Model not loaded. Please check if all model files are uploaded.", None
215
+
216
+ if file is None:
217
+ return "Please upload a CSV file.", None
218
+
219
  # Read the uploaded file
220
  df = pd.read_csv(file.name)
221
 
 
248
  results_df.to_csv(output_file, index=False)
249
 
250
  # Create summary statistics
251
+ eligible_count = sum(1 for p in predictions if p == 'Eligible')
252
+ not_eligible_count = len(predictions) - eligible_count
253
+
254
+ summary = f"""Batch Prediction Summary:
255
+ โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
256
+ ๐Ÿ“Š Total predictions: {len(results_df)}
257
+ โœ… Eligible: {eligible_count} ({eligible_count/len(predictions)*100:.1f}%)
258
+ โŒ Not Eligible: {not_eligible_count} ({not_eligible_count/len(predictions)*100:.1f}%)
259
+ ๐Ÿ“ˆ Average Probability: {np.mean(probabilities):.4f}
260
+ ๐ŸŽฏ Average Confidence: {np.mean(np.abs(probabilities - 0.5) * 2):.4f}
261
+ โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
262
+
263
+ Results saved to: {output_file}
264
  """
265
 
266
  return summary, output_file
 
284
  border-radius: 10px;
285
  margin-bottom: 20px;
286
  }
287
+ .feature-input {
288
+ margin: 5px 0;
289
+ }
290
  """
291
  ) as demo:
292
 
 
300
 
301
  with gr.Tabs():
302
  # Single Prediction Tab
303
+ with gr.TabItem("๐Ÿ”ฎ Single Prediction"):
304
  gr.Markdown("### Enter student information to predict eligibility")
305
 
306
  with gr.Row():
307
  with gr.Column(scale=1):
308
+ gr.Markdown("#### Input Features")
309
  # Create input components dynamically based on features
310
  inputs = []
311
+ for i, feature in enumerate(feature_names):
312
  inputs.append(
313
  gr.Number(
314
+ label=f"๐Ÿ“Š {feature}",
315
+ value=75 + i*5, # Different default values
316
  minimum=0,
317
  maximum=100,
318
+ step=0.1,
319
+ elem_classes=["feature-input"]
320
  )
321
  )
322
 
323
+ predict_btn = gr.Button(
324
+ "๐Ÿ”ฎ Predict Eligibility",
325
+ variant="primary",
326
+ size="lg",
327
+ elem_id="predict-btn"
328
+ )
329
 
330
  with gr.Column(scale=2):
331
+ gr.Markdown("#### Prediction Results")
332
  with gr.Row():
333
+ prediction_output = gr.Textbox(label="๐ŸŽฏ Prediction", scale=1)
334
+ probability_output = gr.Textbox(label="๐Ÿ“Š Probability", scale=1)
335
+ confidence_output = gr.Textbox(label="๐ŸŽฏ Confidence", scale=1)
336
 
337
+ prediction_plot = gr.Plot(label="๐Ÿ“ˆ Prediction Visualization")
338
 
339
  # Model information
340
  gr.HTML(create_model_info())
341
 
342
  # Batch Prediction Tab
343
+ with gr.TabItem("๐Ÿ“Š Batch Prediction"):
344
  gr.Markdown("### Upload a CSV file for batch predictions")
345
+ gr.Markdown(f"**Required columns:** `{', '.join(feature_names)}`")
346
+
347
+ # Sample CSV format
348
+ gr.Markdown("""
349
+ **Example CSV format:**
350
+ ```csv
351
+ Feature_1,Feature_2,Feature_3,Feature_4
352
+ 85,90,75,88
353
+ 92,78,85,91
354
+ ```
355
+ """)
356
 
357
  with gr.Row():
358
  with gr.Column():
359
  file_input = gr.File(
360
+ label="๐Ÿ“ Upload CSV File",
361
  file_types=[".csv"],
362
  type="file"
363
  )
364
+ batch_predict_btn = gr.Button(
365
+ "๐Ÿ“Š Process Batch",
366
+ variant="primary",
367
+ size="lg"
368
+ )
369
 
370
  with gr.Column():
371
+ batch_output = gr.Textbox(
372
+ label="๐Ÿ“‹ Batch Results Summary",
373
+ lines=15,
374
+ max_lines=20
375
+ )
376
+ download_file = gr.File(label="โฌ‡๏ธ Download Results")
377
 
378
  # Model Analytics Tab
379
+ with gr.TabItem("๐Ÿ“ˆ Model Analytics"):
380
  gr.Markdown("### Model Performance Metrics")
381
 
382
+ if metadata and 'performance_metrics' in metadata:
383
+ # Performance metrics
384
+ metrics_data = metadata['performance_metrics']
385
+ metrics_df = pd.DataFrame([{
386
+ 'Metric': k.replace('_', ' ').title(),
387
+ 'Value': f"{v:.4f}" if isinstance(v, float) else str(v)
388
+ } for k, v in metrics_data.items()])
389
+
390
+ gr.Dataframe(
391
+ metrics_df,
392
+ label="๐ŸŽฏ Performance Metrics",
393
+ headers=['Metric', 'Value']
394
+ )
395
+ else:
396
+ gr.Markdown("โš ๏ธ **Performance metrics not available**")
397
+
398
+ # Feature information
399
+ gr.Markdown("### ๐Ÿ“Š Model Features")
400
+ feature_info = pd.DataFrame({
401
+ 'Feature Name': feature_names,
402
+ 'Index': range(len(feature_names)),
403
+ 'Type': ['Numerical'] * len(feature_names)
404
+ })
405
+ gr.Dataframe(feature_info, label="Feature Information")
406
 
407
+ # Model architecture info
408
+ if metadata:
409
+ gr.Markdown("### ๐Ÿ—๏ธ Model Architecture")
410
+ arch_info = f"""
411
+ - **Model Type**: {metadata.get('model_type', 'CNN')}
412
+ - **Input Shape**: {metadata.get('input_shape', 'N/A')}
413
+ - **Total Features**: {len(feature_names)}
414
+ - **Output Classes**: {len(metadata.get('target_classes', {}))}
415
+ """
416
+ gr.Markdown(arch_info)
417
 
418
  # Event handlers
419
  predict_btn.click(
 
430
 
431
  # Launch the app
432
  if __name__ == "__main__":
433
+ demo.launch(
434
+ share=False,
435
+ server_name="0.0.0.0",
436
+ server_port=7860
437
+ )