Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ import seaborn as sns
|
|
6 |
from datetime import datetime
|
7 |
import plotly.express as px
|
8 |
import plotly.graph_objects as go
|
9 |
-
from data_processing import load_and_preprocess_data
|
10 |
from models import train_arima, train_mlp, train_lstm, train_bayesian_network
|
11 |
from chatbot import InflationChatbot
|
12 |
import os
|
@@ -62,12 +62,19 @@ st.markdown("""
|
|
62 |
@st.cache_data
|
63 |
def load_data():
|
64 |
try:
|
65 |
-
|
66 |
-
data = pd.read_csv('inflation_beac.csv')
|
67 |
except:
|
68 |
-
# Si échec, utilisez l'URL de base
|
69 |
base_url = "https://raw.githubusercontent.com/username/repo/main/"
|
70 |
-
data = pd.read_csv(base_url +
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
return data
|
72 |
|
73 |
data = load_and_preprocess_data()
|
@@ -78,24 +85,24 @@ with st.sidebar:
|
|
78 |
st.markdown("## Paramètres de l'analyse")
|
79 |
|
80 |
# Sélection du pays
|
81 |
-
pays_options = data[
|
82 |
selected_country = st.selectbox("Sélectionnez un pays", pays_options)
|
83 |
|
84 |
# Filtrage des dates disponibles pour le pays sélectionné
|
85 |
-
country_data = data[data[
|
86 |
-
min_date = country_data[
|
87 |
-
max_date = country_data[
|
88 |
|
89 |
# Sélection de la plage temporelle
|
90 |
start_date = st.date_input("Date de début",
|
91 |
-
value=datetime.strptime(min_date,
|
92 |
-
min_value=datetime.strptime(min_date,
|
93 |
-
max_value=datetime.strptime(max_date,
|
94 |
|
95 |
end_date = st.date_input("Date de fin",
|
96 |
-
value=datetime.strptime(max_date,
|
97 |
-
min_value=datetime.strptime(min_date,
|
98 |
-
max_value=datetime.strptime(max_date,
|
99 |
|
100 |
# Sélection du modèle
|
101 |
model_options = {
|
@@ -139,53 +146,53 @@ with tab1:
|
|
139 |
|
140 |
# Filtrage des données selon les sélections
|
141 |
filtered_data = data[
|
142 |
-
(data[
|
143 |
-
(data[
|
144 |
-
(data[
|
145 |
]
|
146 |
|
147 |
# Graphique de l'inflation
|
148 |
-
fig1 = px.line(filtered_data, x=
|
149 |
-
title=f
|
150 |
-
labels={
|
151 |
fig1.update_layout(height=400)
|
152 |
st.plotly_chart(fig1, use_container_width=True)
|
153 |
|
154 |
# Graphiques des autres indicateurs
|
155 |
col1, col2 = st.columns(2)
|
156 |
with col1:
|
157 |
-
fig2 = px.line(filtered_data, x=
|
158 |
-
title=f
|
159 |
st.plotly_chart(fig2, use_container_width=True)
|
160 |
|
161 |
-
fig3 = px.line(filtered_data, x=
|
162 |
-
title=f
|
163 |
st.plotly_chart(fig3, use_container_width=True)
|
164 |
|
165 |
with col2:
|
166 |
-
fig4 = px.line(filtered_data, x=
|
167 |
-
title=f
|
168 |
st.plotly_chart(fig4, use_container_width=True)
|
169 |
|
170 |
-
fig5 = px.line(filtered_data, x=
|
171 |
-
title=f
|
172 |
st.plotly_chart(fig5, use_container_width=True)
|
173 |
|
174 |
# Matrice de corrélation
|
175 |
st.subheader("Matrice de corrélation")
|
176 |
-
corr_data = filtered_data.drop(columns=[
|
177 |
corr_matrix = corr_data.corr()
|
178 |
|
179 |
fig6 = go.Figure(data=go.Heatmap(
|
180 |
z=corr_matrix.values,
|
181 |
x=corr_matrix.columns,
|
182 |
y=corr_matrix.columns,
|
183 |
-
colorscale=
|
184 |
zmin=-1,
|
185 |
zmax=1,
|
186 |
colorbar=dict(title="Coefficient de corrélation")
|
187 |
))
|
188 |
-
fig6.update_layout(title=
|
189 |
height=500)
|
190 |
st.plotly_chart(fig6, use_container_width=True)
|
191 |
|
@@ -234,24 +241,24 @@ with tab2:
|
|
234 |
|
235 |
# Ajout des données réelles
|
236 |
fig_pred.add_trace(go.Scatter(
|
237 |
-
x=predictions[
|
238 |
-
y=predictions[
|
239 |
-
name=
|
240 |
-
line=dict(color=
|
241 |
))
|
242 |
|
243 |
# Ajout des prédictions
|
244 |
fig_pred.add_trace(go.Scatter(
|
245 |
-
x=predictions[
|
246 |
-
y=predictions[
|
247 |
-
name=
|
248 |
-
line=dict(color=
|
249 |
))
|
250 |
|
251 |
fig_pred.update_layout(
|
252 |
-
title=f
|
253 |
-
xaxis_title=
|
254 |
-
yaxis_title=
|
255 |
height=500
|
256 |
)
|
257 |
st.plotly_chart(fig_pred, use_container_width=True)
|
@@ -260,7 +267,7 @@ with tab2:
|
|
260 |
st.subheader("Export des résultats")
|
261 |
|
262 |
# Format CSV
|
263 |
-
csv = predictions.to_csv(index=False).encode(
|
264 |
st.download_button(
|
265 |
label="Télécharger les prédictions (CSV)",
|
266 |
data=csv,
|
@@ -270,10 +277,10 @@ with tab2:
|
|
270 |
|
271 |
# Format Excel
|
272 |
output = BytesIO()
|
273 |
-
with pd.ExcelWriter(output, engine=
|
274 |
-
predictions.to_excel(writer, sheet_name=
|
275 |
-
metrics_df = pd.DataFrame.from_dict(metrics, orient=
|
276 |
-
metrics_df.to_excel(writer, sheet_name=
|
277 |
excel_data = output.getvalue()
|
278 |
|
279 |
st.download_button(
|
@@ -289,21 +296,21 @@ with tab3:
|
|
289 |
# Comparaison entre pays
|
290 |
st.subheader("Comparaison entre pays")
|
291 |
selected_countries = st.multiselect("Sélectionnez les pays à comparer",
|
292 |
-
data[
|
293 |
default=[selected_country])
|
294 |
|
295 |
if selected_countries:
|
296 |
-
compare_data = data[data[
|
297 |
|
298 |
# Graphique comparatif
|
299 |
-
fig_compare = px.line(compare_data, x=
|
300 |
-
color=
|
301 |
st.plotly_chart(fig_compare, use_container_width=True)
|
302 |
|
303 |
-
# Dernières valeurs
|
304 |
st.subheader("Dernières valeurs disponibles")
|
305 |
-
latest_data = compare_data.sort_values(
|
306 |
-
st.dataframe(latest_data.set_index(
|
307 |
|
308 |
# Comparaison entre modèles
|
309 |
st.subheader("Comparaison entre modèles")
|
@@ -335,38 +342,38 @@ with tab3:
|
|
335 |
)
|
336 |
|
337 |
comparison_results.append({
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
})
|
344 |
|
345 |
comparison_df = pd.DataFrame(comparison_results)
|
346 |
|
347 |
# Affichage des résultats
|
348 |
-
st.dataframe(comparison_df.set_index(
|
349 |
|
350 |
# Graphique de comparaison
|
351 |
fig_models = go.Figure()
|
352 |
|
353 |
fig_models.add_trace(go.Bar(
|
354 |
-
x=comparison_df[
|
355 |
-
y=comparison_df[
|
356 |
-
name=
|
357 |
-
marker_color=
|
358 |
))
|
359 |
|
360 |
fig_models.add_trace(go.Bar(
|
361 |
-
x=comparison_df[
|
362 |
-
y=comparison_df[
|
363 |
-
name=
|
364 |
-
marker_color=
|
365 |
))
|
366 |
|
367 |
fig_models.update_layout(
|
368 |
-
title=
|
369 |
-
barmode=
|
370 |
height=500
|
371 |
)
|
372 |
|
@@ -376,9 +383,9 @@ with tab4:
|
|
376 |
st.header("Assistant intelligent")
|
377 |
|
378 |
# Initialisation du chatbot
|
379 |
-
if
|
380 |
st.session_state.chatbot = InflationChatbot()
|
381 |
-
if
|
382 |
st.session_state.chat_history = []
|
383 |
|
384 |
# Affichage de l'historique de chat
|
@@ -400,7 +407,7 @@ with tab4:
|
|
400 |
prompt,
|
401 |
country=selected_country,
|
402 |
model=selected_model,
|
403 |
-
metrics=metrics if
|
404 |
)
|
405 |
st.markdown(response)
|
406 |
|
|
|
6 |
from datetime import datetime
|
7 |
import plotly.express as px
|
8 |
import plotly.graph_objects as go
|
9 |
+
from data_processing import load_and_preprocess_data, prepare_timeseries_data, apply_scenarios
|
10 |
from models import train_arima, train_mlp, train_lstm, train_bayesian_network
|
11 |
from chatbot import InflationChatbot
|
12 |
import os
|
|
|
62 |
@st.cache_data
|
63 |
def load_data():
|
64 |
try:
|
65 |
+
data = pd.read_csv("inflation_beac_complet_2010_2025.csv")
|
|
|
66 |
except:
|
|
|
67 |
base_url = "https://raw.githubusercontent.com/username/repo/main/"
|
68 |
+
data = pd.read_csv(base_url + "inflation_beac_complet_2010_2025.csv")
|
69 |
+
|
70 |
+
# Renommer les colonnes
|
71 |
+
data = data.rename(columns={
|
72 |
+
"Année": "Date",
|
73 |
+
"Taux d'inflation (%)": "Taux d'inflation",
|
74 |
+
"Masse monétaire (M2)": "Masse monétaire M2",
|
75 |
+
"Croissance PIB (%)": "Taux de croissance du PIB",
|
76 |
+
"Taux de change FCFA/USD": "Taux de change"
|
77 |
+
})
|
78 |
return data
|
79 |
|
80 |
data = load_and_preprocess_data()
|
|
|
85 |
st.markdown("## Paramètres de l'analyse")
|
86 |
|
87 |
# Sélection du pays
|
88 |
+
pays_options = data["Pays"].unique()
|
89 |
selected_country = st.selectbox("Sélectionnez un pays", pays_options)
|
90 |
|
91 |
# Filtrage des dates disponibles pour le pays sélectionné
|
92 |
+
country_data = data[data["Pays"] == selected_country]
|
93 |
+
min_date = country_data["Date"].min()
|
94 |
+
max_date = country_data["Date"].max()
|
95 |
|
96 |
# Sélection de la plage temporelle
|
97 |
start_date = st.date_input("Date de début",
|
98 |
+
value=datetime.strptime(min_date, "%Y-%m-%d"),
|
99 |
+
min_value=datetime.strptime(min_date, "%Y-%m-%d"),
|
100 |
+
max_value=datetime.strptime(max_date, "%Y-%m-%d"))
|
101 |
|
102 |
end_date = st.date_input("Date de fin",
|
103 |
+
value=datetime.strptime(max_date, "%Y-%m-%d"),
|
104 |
+
min_value=datetime.strptime(min_date, "%Y-%m-%d"),
|
105 |
+
max_value=datetime.strptime(max_date, "%Y-%m-%d"))
|
106 |
|
107 |
# Sélection du modèle
|
108 |
model_options = {
|
|
|
146 |
|
147 |
# Filtrage des données selon les sélections
|
148 |
filtered_data = data[
|
149 |
+
(data["Pays"] == selected_country) &
|
150 |
+
(data["Date"] >= str(start_date)) &
|
151 |
+
(data["Date"] <= str(end_date))
|
152 |
]
|
153 |
|
154 |
# Graphique de l'inflation
|
155 |
+
fig1 = px.line(filtered_data, x="Date", y="Taux d'inflation",
|
156 |
+
title=f"Évolution de l'inflation en {selected_country}",
|
157 |
+
labels={"Taux d'inflation": "Taux d'inflation (%)"})
|
158 |
fig1.update_layout(height=400)
|
159 |
st.plotly_chart(fig1, use_container_width=True)
|
160 |
|
161 |
# Graphiques des autres indicateurs
|
162 |
col1, col2 = st.columns(2)
|
163 |
with col1:
|
164 |
+
fig2 = px.line(filtered_data, x="Date", y="Masse monétaire M2",
|
165 |
+
title=f"Masse monétaire M2 en {selected_country}")
|
166 |
st.plotly_chart(fig2, use_container_width=True)
|
167 |
|
168 |
+
fig3 = px.line(filtered_data, x="Date", y="Taux de croissance du PIB",
|
169 |
+
title=f"Croissance du PIB en {selected_country}")
|
170 |
st.plotly_chart(fig3, use_container_width=True)
|
171 |
|
172 |
with col2:
|
173 |
+
fig4 = px.line(filtered_data, x="Date", y="Taux directeur",
|
174 |
+
title=f"Taux directeur en {selected_country}")
|
175 |
st.plotly_chart(fig4, use_container_width=True)
|
176 |
|
177 |
+
fig5 = px.line(filtered_data, x="Date", y="Balance commerciale",
|
178 |
+
title=f"Balance commerciale en {selected_country}")
|
179 |
st.plotly_chart(fig5, use_container_width=True)
|
180 |
|
181 |
# Matrice de corrélation
|
182 |
st.subheader("Matrice de corrélation")
|
183 |
+
corr_data = filtered_data.drop(columns=["Date", "Pays"])
|
184 |
corr_matrix = corr_data.corr()
|
185 |
|
186 |
fig6 = go.Figure(data=go.Heatmap(
|
187 |
z=corr_matrix.values,
|
188 |
x=corr_matrix.columns,
|
189 |
y=corr_matrix.columns,
|
190 |
+
colorscale="RdBu",
|
191 |
zmin=-1,
|
192 |
zmax=1,
|
193 |
colorbar=dict(title="Coefficient de corrélation")
|
194 |
))
|
195 |
+
fig6.update_layout(title="Corrélations entre les indicateurs macroéconomiques",
|
196 |
height=500)
|
197 |
st.plotly_chart(fig6, use_container_width=True)
|
198 |
|
|
|
241 |
|
242 |
# Ajout des données réelles
|
243 |
fig_pred.add_trace(go.Scatter(
|
244 |
+
x=predictions["Date"],
|
245 |
+
y=predictions["Inflation réelle"],
|
246 |
+
name="Inflation réelle",
|
247 |
+
line=dict(color="blue")
|
248 |
))
|
249 |
|
250 |
# Ajout des prédictions
|
251 |
fig_pred.add_trace(go.Scatter(
|
252 |
+
x=predictions["Date"],
|
253 |
+
y=predictions["Inflation prédite"],
|
254 |
+
name="Inflation prédite",
|
255 |
+
line=dict(color="red", dash="dash")
|
256 |
))
|
257 |
|
258 |
fig_pred.update_layout(
|
259 |
+
title=f"Comparaison inflation réelle vs prédite - {selected_model}",
|
260 |
+
xaxis_title="Date",
|
261 |
+
yaxis_title="Taux d'inflation (%)",
|
262 |
height=500
|
263 |
)
|
264 |
st.plotly_chart(fig_pred, use_container_width=True)
|
|
|
267 |
st.subheader("Export des résultats")
|
268 |
|
269 |
# Format CSV
|
270 |
+
csv = predictions.to_csv(index=False).encode("utf-8")
|
271 |
st.download_button(
|
272 |
label="Télécharger les prédictions (CSV)",
|
273 |
data=csv,
|
|
|
277 |
|
278 |
# Format Excel
|
279 |
output = BytesIO()
|
280 |
+
with pd.ExcelWriter(output, engine="xlsxwriter") as writer:
|
281 |
+
predictions.to_excel(writer, sheet_name="Prédictions", index=False)
|
282 |
+
metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=["Valeur"])
|
283 |
+
metrics_df.to_excel(writer, sheet_name="Métriques")
|
284 |
excel_data = output.getvalue()
|
285 |
|
286 |
st.download_button(
|
|
|
296 |
# Comparaison entre pays
|
297 |
st.subheader("Comparaison entre pays")
|
298 |
selected_countries = st.multiselect("Sélectionnez les pays à comparer",
|
299 |
+
data["Pays"].unique(),
|
300 |
default=[selected_country])
|
301 |
|
302 |
if selected_countries:
|
303 |
+
compare_data = data[data["Pays"].isin(selected_countries)]
|
304 |
|
305 |
# Graphique comparatif
|
306 |
+
fig_compare = px.line(compare_data, x="Date", y="Taux d'inflation",
|
307 |
+
color="Pays", title="Comparaison des taux d'inflation")
|
308 |
st.plotly_chart(fig_compare, use_container_width=True)
|
309 |
|
310 |
+
# Dernières valeurs disponibles
|
311 |
st.subheader("Dernières valeurs disponibles")
|
312 |
+
latest_data = compare_data.sort_values("Date").groupby("Pays").last().reset_index()
|
313 |
+
st.dataframe(latest_data.set_index("Pays").drop(columns=["Date"]).style.background_gradient(cmap="Blues"))
|
314 |
|
315 |
# Comparaison entre modèles
|
316 |
st.subheader("Comparaison entre modèles")
|
|
|
342 |
)
|
343 |
|
344 |
comparison_results.append({
|
345 |
+
"Modèle": model_name,
|
346 |
+
"MAE": metrics["mae"],
|
347 |
+
"RMSE": metrics["rmse"],
|
348 |
+
"R²": metrics["r2"],
|
349 |
+
"Temps d'entraînement (s)": metrics["training_time"]
|
350 |
})
|
351 |
|
352 |
comparison_df = pd.DataFrame(comparison_results)
|
353 |
|
354 |
# Affichage des résultats
|
355 |
+
st.dataframe(comparison_df.set_index("Modèle").style.background_gradient(cmap="Blues"))
|
356 |
|
357 |
# Graphique de comparaison
|
358 |
fig_models = go.Figure()
|
359 |
|
360 |
fig_models.add_trace(go.Bar(
|
361 |
+
x=comparison_df["Modèle"],
|
362 |
+
y=comparison_df["MAE"],
|
363 |
+
name="MAE",
|
364 |
+
marker_color="indianred"
|
365 |
))
|
366 |
|
367 |
fig_models.add_trace(go.Bar(
|
368 |
+
x=comparison_df["Modèle"],
|
369 |
+
y=comparison_df["RMSE"],
|
370 |
+
name="RMSE",
|
371 |
+
marker_color="lightsalmon"
|
372 |
))
|
373 |
|
374 |
fig_models.update_layout(
|
375 |
+
title="Comparaison des performances des modèles",
|
376 |
+
barmode="group",
|
377 |
height=500
|
378 |
)
|
379 |
|
|
|
383 |
st.header("Assistant intelligent")
|
384 |
|
385 |
# Initialisation du chatbot
|
386 |
+
if "chatbot" not in st.session_state:
|
387 |
st.session_state.chatbot = InflationChatbot()
|
388 |
+
if "chat_history" not in st.session_state:
|
389 |
st.session_state.chat_history = []
|
390 |
|
391 |
# Affichage de l'historique de chat
|
|
|
407 |
prompt,
|
408 |
country=selected_country,
|
409 |
model=selected_model,
|
410 |
+
metrics=metrics if "metrics" in locals() else None
|
411 |
)
|
412 |
st.markdown(response)
|
413 |
|