Arvador237 commited on
Commit
359e113
·
verified ·
1 Parent(s): 61eb2f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -74
app.py CHANGED
@@ -6,7 +6,7 @@ import seaborn as sns
6
  from datetime import datetime
7
  import plotly.express as px
8
  import plotly.graph_objects as go
9
- from data_processing import load_and_preprocess_data
10
  from models import train_arima, train_mlp, train_lstm, train_bayesian_network
11
  from chatbot import InflationChatbot
12
  import os
@@ -62,12 +62,19 @@ st.markdown("""
62
  @st.cache_data
63
  def load_data():
64
  try:
65
- # Essayez de charger depuis le chemin local
66
- data = pd.read_csv('inflation_beac.csv')
67
  except:
68
- # Si échec, utilisez l'URL de base
69
  base_url = "https://raw.githubusercontent.com/username/repo/main/"
70
- data = pd.read_csv(base_url + 'inflation_beac.csv')
 
 
 
 
 
 
 
 
 
71
  return data
72
 
73
  data = load_and_preprocess_data()
@@ -78,24 +85,24 @@ with st.sidebar:
78
  st.markdown("## Paramètres de l'analyse")
79
 
80
  # Sélection du pays
81
- pays_options = data['Pays'].unique()
82
  selected_country = st.selectbox("Sélectionnez un pays", pays_options)
83
 
84
  # Filtrage des dates disponibles pour le pays sélectionné
85
- country_data = data[data['Pays'] == selected_country]
86
- min_date = country_data['Date'].min()
87
- max_date = country_data['Date'].max()
88
 
89
  # Sélection de la plage temporelle
90
  start_date = st.date_input("Date de début",
91
- value=datetime.strptime(min_date, '%Y-%m-%d'),
92
- min_value=datetime.strptime(min_date, '%Y-%m-%d'),
93
- max_value=datetime.strptime(max_date, '%Y-%m-%d'))
94
 
95
  end_date = st.date_input("Date de fin",
96
- value=datetime.strptime(max_date, '%Y-%m-%d'),
97
- min_value=datetime.strptime(min_date, '%Y-%m-%d'),
98
- max_value=datetime.strptime(max_date, '%Y-%m-%d'))
99
 
100
  # Sélection du modèle
101
  model_options = {
@@ -139,53 +146,53 @@ with tab1:
139
 
140
  # Filtrage des données selon les sélections
141
  filtered_data = data[
142
- (data['Pays'] == selected_country) &
143
- (data['Date'] >= str(start_date)) &
144
- (data['Date'] <= str(end_date))
145
  ]
146
 
147
  # Graphique de l'inflation
148
- fig1 = px.line(filtered_data, x='Date', y='Taux d\'inflation (%)',
149
- title=f'Évolution de l\'inflation en {selected_country}',
150
- labels={'Taux d\'inflation (%)': 'Taux d\'inflation (%)'})
151
  fig1.update_layout(height=400)
152
  st.plotly_chart(fig1, use_container_width=True)
153
 
154
  # Graphiques des autres indicateurs
155
  col1, col2 = st.columns(2)
156
  with col1:
157
- fig2 = px.line(filtered_data, x='Date', y='Masse monétaire M2',
158
- title=f'Masse monétaire M2 en {selected_country}')
159
  st.plotly_chart(fig2, use_container_width=True)
160
 
161
- fig3 = px.line(filtered_data, x='Date', y='Taux de croissance du PIB',
162
- title=f'Croissance du PIB en {selected_country}')
163
  st.plotly_chart(fig3, use_container_width=True)
164
 
165
  with col2:
166
- fig4 = px.line(filtered_data, x='Date', y='Taux directeur',
167
- title=f'Taux directeur en {selected_country}')
168
  st.plotly_chart(fig4, use_container_width=True)
169
 
170
- fig5 = px.line(filtered_data, x='Date', y='Balance commerciale',
171
- title=f'Balance commerciale en {selected_country}')
172
  st.plotly_chart(fig5, use_container_width=True)
173
 
174
  # Matrice de corrélation
175
  st.subheader("Matrice de corrélation")
176
- corr_data = filtered_data.drop(columns=['Date', 'Pays'])
177
  corr_matrix = corr_data.corr()
178
 
179
  fig6 = go.Figure(data=go.Heatmap(
180
  z=corr_matrix.values,
181
  x=corr_matrix.columns,
182
  y=corr_matrix.columns,
183
- colorscale='RdBu',
184
  zmin=-1,
185
  zmax=1,
186
  colorbar=dict(title="Coefficient de corrélation")
187
  ))
188
- fig6.update_layout(title='Corrélations entre les indicateurs macroéconomiques',
189
  height=500)
190
  st.plotly_chart(fig6, use_container_width=True)
191
 
@@ -234,24 +241,24 @@ with tab2:
234
 
235
  # Ajout des données réelles
236
  fig_pred.add_trace(go.Scatter(
237
- x=predictions['Date'],
238
- y=predictions['Inflation réelle'],
239
- name='Inflation réelle',
240
- line=dict(color='blue')
241
  ))
242
 
243
  # Ajout des prédictions
244
  fig_pred.add_trace(go.Scatter(
245
- x=predictions['Date'],
246
- y=predictions['Inflation prédite'],
247
- name='Inflation prédite',
248
- line=dict(color='red', dash='dash')
249
  ))
250
 
251
  fig_pred.update_layout(
252
- title=f'Comparaison inflation réelle vs prédite - {selected_model}',
253
- xaxis_title='Date',
254
- yaxis_title='Taux d\'inflation (%)',
255
  height=500
256
  )
257
  st.plotly_chart(fig_pred, use_container_width=True)
@@ -260,7 +267,7 @@ with tab2:
260
  st.subheader("Export des résultats")
261
 
262
  # Format CSV
263
- csv = predictions.to_csv(index=False).encode('utf-8')
264
  st.download_button(
265
  label="Télécharger les prédictions (CSV)",
266
  data=csv,
@@ -270,10 +277,10 @@ with tab2:
270
 
271
  # Format Excel
272
  output = BytesIO()
273
- with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
274
- predictions.to_excel(writer, sheet_name='Prédictions', index=False)
275
- metrics_df = pd.DataFrame.from_dict(metrics, orient='index', columns=['Valeur'])
276
- metrics_df.to_excel(writer, sheet_name='Métriques')
277
  excel_data = output.getvalue()
278
 
279
  st.download_button(
@@ -289,21 +296,21 @@ with tab3:
289
  # Comparaison entre pays
290
  st.subheader("Comparaison entre pays")
291
  selected_countries = st.multiselect("Sélectionnez les pays à comparer",
292
- data['Pays'].unique(),
293
  default=[selected_country])
294
 
295
  if selected_countries:
296
- compare_data = data[data['Pays'].isin(selected_countries)]
297
 
298
  # Graphique comparatif
299
- fig_compare = px.line(compare_data, x='Date', y='Taux d\'inflation (%)',
300
- color='Pays', title='Comparaison des taux d\'inflation')
301
  st.plotly_chart(fig_compare, use_container_width=True)
302
 
303
- # Dernières valeurs
304
  st.subheader("Dernières valeurs disponibles")
305
- latest_data = compare_data.sort_values('Date').groupby('Pays').last().reset_index()
306
- st.dataframe(latest_data.set_index('Pays').drop(columns=['Date']).style.background_gradient(cmap='Blues'))
307
 
308
  # Comparaison entre modèles
309
  st.subheader("Comparaison entre modèles")
@@ -335,38 +342,38 @@ with tab3:
335
  )
336
 
337
  comparison_results.append({
338
- 'Modèle': model_name,
339
- 'MAE': metrics['mae'],
340
- 'RMSE': metrics['rmse'],
341
- '': metrics['r2'],
342
- 'Temps d\'entraînement (s)': metrics['training_time']
343
  })
344
 
345
  comparison_df = pd.DataFrame(comparison_results)
346
 
347
  # Affichage des résultats
348
- st.dataframe(comparison_df.set_index('Modèle').style.background_gradient(cmap='Blues'))
349
 
350
  # Graphique de comparaison
351
  fig_models = go.Figure()
352
 
353
  fig_models.add_trace(go.Bar(
354
- x=comparison_df['Modèle'],
355
- y=comparison_df['MAE'],
356
- name='MAE',
357
- marker_color='indianred'
358
  ))
359
 
360
  fig_models.add_trace(go.Bar(
361
- x=comparison_df['Modèle'],
362
- y=comparison_df['RMSE'],
363
- name='RMSE',
364
- marker_color='lightsalmon'
365
  ))
366
 
367
  fig_models.update_layout(
368
- title='Comparaison des performances des modèles',
369
- barmode='group',
370
  height=500
371
  )
372
 
@@ -376,9 +383,9 @@ with tab4:
376
  st.header("Assistant intelligent")
377
 
378
  # Initialisation du chatbot
379
- if 'chatbot' not in st.session_state:
380
  st.session_state.chatbot = InflationChatbot()
381
- if 'chat_history' not in st.session_state:
382
  st.session_state.chat_history = []
383
 
384
  # Affichage de l'historique de chat
@@ -400,7 +407,7 @@ with tab4:
400
  prompt,
401
  country=selected_country,
402
  model=selected_model,
403
- metrics=metrics if 'metrics' in locals() else None
404
  )
405
  st.markdown(response)
406
 
 
6
  from datetime import datetime
7
  import plotly.express as px
8
  import plotly.graph_objects as go
9
+ from data_processing import load_and_preprocess_data, prepare_timeseries_data, apply_scenarios
10
  from models import train_arima, train_mlp, train_lstm, train_bayesian_network
11
  from chatbot import InflationChatbot
12
  import os
 
62
  @st.cache_data
63
  def load_data():
64
  try:
65
+ data = pd.read_csv("inflation_beac_complet_2010_2025.csv")
 
66
  except:
 
67
  base_url = "https://raw.githubusercontent.com/username/repo/main/"
68
+ data = pd.read_csv(base_url + "inflation_beac_complet_2010_2025.csv")
69
+
70
+ # Renommer les colonnes
71
+ data = data.rename(columns={
72
+ "Année": "Date",
73
+ "Taux d'inflation (%)": "Taux d'inflation",
74
+ "Masse monétaire (M2)": "Masse monétaire M2",
75
+ "Croissance PIB (%)": "Taux de croissance du PIB",
76
+ "Taux de change FCFA/USD": "Taux de change"
77
+ })
78
  return data
79
 
80
  data = load_and_preprocess_data()
 
85
  st.markdown("## Paramètres de l'analyse")
86
 
87
  # Sélection du pays
88
+ pays_options = data["Pays"].unique()
89
  selected_country = st.selectbox("Sélectionnez un pays", pays_options)
90
 
91
  # Filtrage des dates disponibles pour le pays sélectionné
92
+ country_data = data[data["Pays"] == selected_country]
93
+ min_date = country_data["Date"].min()
94
+ max_date = country_data["Date"].max()
95
 
96
  # Sélection de la plage temporelle
97
  start_date = st.date_input("Date de début",
98
+ value=datetime.strptime(min_date, "%Y-%m-%d"),
99
+ min_value=datetime.strptime(min_date, "%Y-%m-%d"),
100
+ max_value=datetime.strptime(max_date, "%Y-%m-%d"))
101
 
102
  end_date = st.date_input("Date de fin",
103
+ value=datetime.strptime(max_date, "%Y-%m-%d"),
104
+ min_value=datetime.strptime(min_date, "%Y-%m-%d"),
105
+ max_value=datetime.strptime(max_date, "%Y-%m-%d"))
106
 
107
  # Sélection du modèle
108
  model_options = {
 
146
 
147
  # Filtrage des données selon les sélections
148
  filtered_data = data[
149
+ (data["Pays"] == selected_country) &
150
+ (data["Date"] >= str(start_date)) &
151
+ (data["Date"] <= str(end_date))
152
  ]
153
 
154
  # Graphique de l'inflation
155
+ fig1 = px.line(filtered_data, x="Date", y="Taux d'inflation",
156
+ title=f"Évolution de l'inflation en {selected_country}",
157
+ labels={"Taux d'inflation": "Taux d'inflation (%)"})
158
  fig1.update_layout(height=400)
159
  st.plotly_chart(fig1, use_container_width=True)
160
 
161
  # Graphiques des autres indicateurs
162
  col1, col2 = st.columns(2)
163
  with col1:
164
+ fig2 = px.line(filtered_data, x="Date", y="Masse monétaire M2",
165
+ title=f"Masse monétaire M2 en {selected_country}")
166
  st.plotly_chart(fig2, use_container_width=True)
167
 
168
+ fig3 = px.line(filtered_data, x="Date", y="Taux de croissance du PIB",
169
+ title=f"Croissance du PIB en {selected_country}")
170
  st.plotly_chart(fig3, use_container_width=True)
171
 
172
  with col2:
173
+ fig4 = px.line(filtered_data, x="Date", y="Taux directeur",
174
+ title=f"Taux directeur en {selected_country}")
175
  st.plotly_chart(fig4, use_container_width=True)
176
 
177
+ fig5 = px.line(filtered_data, x="Date", y="Balance commerciale",
178
+ title=f"Balance commerciale en {selected_country}")
179
  st.plotly_chart(fig5, use_container_width=True)
180
 
181
  # Matrice de corrélation
182
  st.subheader("Matrice de corrélation")
183
+ corr_data = filtered_data.drop(columns=["Date", "Pays"])
184
  corr_matrix = corr_data.corr()
185
 
186
  fig6 = go.Figure(data=go.Heatmap(
187
  z=corr_matrix.values,
188
  x=corr_matrix.columns,
189
  y=corr_matrix.columns,
190
+ colorscale="RdBu",
191
  zmin=-1,
192
  zmax=1,
193
  colorbar=dict(title="Coefficient de corrélation")
194
  ))
195
+ fig6.update_layout(title="Corrélations entre les indicateurs macroéconomiques",
196
  height=500)
197
  st.plotly_chart(fig6, use_container_width=True)
198
 
 
241
 
242
  # Ajout des données réelles
243
  fig_pred.add_trace(go.Scatter(
244
+ x=predictions["Date"],
245
+ y=predictions["Inflation réelle"],
246
+ name="Inflation réelle",
247
+ line=dict(color="blue")
248
  ))
249
 
250
  # Ajout des prédictions
251
  fig_pred.add_trace(go.Scatter(
252
+ x=predictions["Date"],
253
+ y=predictions["Inflation prédite"],
254
+ name="Inflation prédite",
255
+ line=dict(color="red", dash="dash")
256
  ))
257
 
258
  fig_pred.update_layout(
259
+ title=f"Comparaison inflation réelle vs prédite - {selected_model}",
260
+ xaxis_title="Date",
261
+ yaxis_title="Taux d'inflation (%)",
262
  height=500
263
  )
264
  st.plotly_chart(fig_pred, use_container_width=True)
 
267
  st.subheader("Export des résultats")
268
 
269
  # Format CSV
270
+ csv = predictions.to_csv(index=False).encode("utf-8")
271
  st.download_button(
272
  label="Télécharger les prédictions (CSV)",
273
  data=csv,
 
277
 
278
  # Format Excel
279
  output = BytesIO()
280
+ with pd.ExcelWriter(output, engine="xlsxwriter") as writer:
281
+ predictions.to_excel(writer, sheet_name="Prédictions", index=False)
282
+ metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=["Valeur"])
283
+ metrics_df.to_excel(writer, sheet_name="Métriques")
284
  excel_data = output.getvalue()
285
 
286
  st.download_button(
 
296
  # Comparaison entre pays
297
  st.subheader("Comparaison entre pays")
298
  selected_countries = st.multiselect("Sélectionnez les pays à comparer",
299
+ data["Pays"].unique(),
300
  default=[selected_country])
301
 
302
  if selected_countries:
303
+ compare_data = data[data["Pays"].isin(selected_countries)]
304
 
305
  # Graphique comparatif
306
+ fig_compare = px.line(compare_data, x="Date", y="Taux d'inflation",
307
+ color="Pays", title="Comparaison des taux d'inflation")
308
  st.plotly_chart(fig_compare, use_container_width=True)
309
 
310
+ # Dernières valeurs disponibles
311
  st.subheader("Dernières valeurs disponibles")
312
+ latest_data = compare_data.sort_values("Date").groupby("Pays").last().reset_index()
313
+ st.dataframe(latest_data.set_index("Pays").drop(columns=["Date"]).style.background_gradient(cmap="Blues"))
314
 
315
  # Comparaison entre modèles
316
  st.subheader("Comparaison entre modèles")
 
342
  )
343
 
344
  comparison_results.append({
345
+ "Modèle": model_name,
346
+ "MAE": metrics["mae"],
347
+ "RMSE": metrics["rmse"],
348
+ "": metrics["r2"],
349
+ "Temps d'entraînement (s)": metrics["training_time"]
350
  })
351
 
352
  comparison_df = pd.DataFrame(comparison_results)
353
 
354
  # Affichage des résultats
355
+ st.dataframe(comparison_df.set_index("Modèle").style.background_gradient(cmap="Blues"))
356
 
357
  # Graphique de comparaison
358
  fig_models = go.Figure()
359
 
360
  fig_models.add_trace(go.Bar(
361
+ x=comparison_df["Modèle"],
362
+ y=comparison_df["MAE"],
363
+ name="MAE",
364
+ marker_color="indianred"
365
  ))
366
 
367
  fig_models.add_trace(go.Bar(
368
+ x=comparison_df["Modèle"],
369
+ y=comparison_df["RMSE"],
370
+ name="RMSE",
371
+ marker_color="lightsalmon"
372
  ))
373
 
374
  fig_models.update_layout(
375
+ title="Comparaison des performances des modèles",
376
+ barmode="group",
377
  height=500
378
  )
379
 
 
383
  st.header("Assistant intelligent")
384
 
385
  # Initialisation du chatbot
386
+ if "chatbot" not in st.session_state:
387
  st.session_state.chatbot = InflationChatbot()
388
+ if "chat_history" not in st.session_state:
389
  st.session_state.chat_history = []
390
 
391
  # Affichage de l'historique de chat
 
407
  prompt,
408
  country=selected_country,
409
  model=selected_model,
410
+ metrics=metrics if "metrics" in locals() else None
411
  )
412
  st.markdown(response)
413