Spaces:
Running
Running
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from datetime import datetime | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from data_processing import load_and_preprocess_data, prepare_timeseries_data, apply_scenarios | |
from models import train_arima, train_mlp, train_lstm, train_bayesian_network | |
from chatbot import InflationChatbot | |
import os | |
import pickle | |
from io import BytesIO | |
# Configuration de la page | |
st.set_page_config( | |
page_title="Prédiction de l'Inflation - Zone BEAC", | |
page_icon="📈", | |
layout="wide", | |
initial_sidebar_state="expanded" | |
) | |
# Style CSS personnalisé | |
st.markdown(""" | |
<style> | |
/* ========== FOND ET STRUCTURE GÉNÉRALE ========== */ | |
.main { | |
background-color: #f5f7fa; | |
font-family: 'Helvetica Neue', Arial, sans-serif; | |
} | |
/* ========== EN-TÊTE ========== */ | |
.header { | |
background: linear-gradient(135deg, #003366 0%, #002244 100%); | |
color: white; | |
padding: 1.5rem; | |
border-radius: 0 0 15px 15px; | |
margin-bottom: 2rem; | |
box-shadow: 0 4px 12px rgba(0, 51, 102, 0.15); | |
} | |
.header h1 { | |
font-weight: 700; | |
margin-bottom: 0.5rem; | |
font-size: 1.8rem; | |
} | |
.header p { | |
opacity: 0.9; | |
font-size: 1rem; | |
} | |
/* ========== SIDEBAR ========== */ | |
[data-testid="stSidebar"] { | |
background: linear-gradient(180deg, #003366 0%, #002244 100%) !important; | |
padding-top: 2rem; | |
} | |
.sidebar .sidebar-content { | |
color: white; | |
} | |
/* ========== CARTES DE MÉTRIQUES ========== */ | |
.stMetric { | |
background: white; | |
border-radius: 10px; | |
padding: 1.5rem; | |
box-shadow: 0 2px 8px rgba(0, 51, 102, 0.1); | |
border-left: 4px solid #003366; | |
margin-bottom: 1.5rem; | |
} | |
.stMetric > div > div { | |
font-size: 1.8rem !important; | |
font-weight: 600; | |
color: #003366 !important; | |
} | |
.stMetric > div > label { | |
font-size: 0.9rem !important; | |
color: #555 !important; | |
} | |
/* ========== ONGLETS ========== */ | |
[data-baseweb="tab-list"] { | |
gap: 5px; | |
padding: 0 1rem; | |
} | |
[data-baseweb="tab"] { | |
background: #e9ecef !important; | |
border-radius: 8px !important; | |
padding: 10px 20px !important; | |
margin: 0 2px !important; | |
font-weight: 500; | |
transition: all 0.3s ease; | |
} | |
[aria-selected="true"] { | |
background: #003366 !important; | |
color: white !important; | |
box-shadow: 0 2px 5px rgba(0, 51, 102, 0.2); | |
} | |
/* ========== GRAPHIQUES ========== */ | |
.stPlotlyChart { | |
border-radius: 12px; | |
box-shadow: 0 4px 12px rgba(0, 51, 102, 0.1); | |
background: white; | |
padding: 1rem; | |
} | |
/* ========== BOUTONS ========== */ | |
.stButton > button { | |
background-color: #003366; | |
color: white; | |
border-radius: 6px; | |
border: none; | |
padding: 0.7rem 1.5rem; | |
font-weight: 500; | |
transition: all 0.3s ease; | |
} | |
.stButton > button:hover { | |
background-color: #002244; | |
transform: translateY(-1px); | |
box-shadow: 0 2px 8px rgba(0, 51, 102, 0.2); | |
} | |
/* ========== TABLEAUX ========== */ | |
.stDataFrame { | |
border-radius: 10px; | |
box-shadow: 0 2px 8px rgba(0, 51, 102, 0.1); | |
} | |
/* ========== EXPANDERS ========== */ | |
.stExpander { | |
background: white; | |
border-radius: 10px; | |
box-shadow: 0 2px 8px rgba(0, 51, 102, 0.1); | |
margin-bottom: 1.5rem; | |
} | |
.stExpander > details > summary { | |
background: #003366 !important; | |
color: white !important; | |
padding: 1rem !important; | |
border-radius: 8px !important; | |
font-weight: 600; | |
} | |
/* ========== FOOTER ========== */ | |
.footer { | |
background: #003366; | |
color: white; | |
padding: 1rem; | |
text-align: center; | |
margin-top: 3rem; | |
border-radius: 8px 8px 0 0; | |
font-size: 0.85rem; | |
} | |
/* ========== COULEURS THÉMATIQUES ========== */ | |
:root { | |
--beac-primary: #003366; | |
--beac-secondary: #D4AF37; | |
--beac-accent: #8BBEFF; | |
--beac-success: #28a745; | |
--beac-danger: #dc3545; | |
} | |
/* ========== ANIMATIONS ========== */ | |
@keyframes fadeIn { | |
from { opacity: 0; transform: translateY(10px); } | |
to { opacity: 1; transform: translateY(0); } | |
} | |
.stMetric, .stPlotlyChart, .stDataFrame { | |
animation: fadeIn 0.5s ease-out; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# Titre de l'application | |
st.title("📊 Prédiction de l'Inflation dans la Zone BEAC") | |
st.markdown(""" | |
**Application interactive pour la prédiction des taux d'inflation** | |
Cameroun • Gabon • Tchad • Congo • Guinée équatoriale • Centrafrique | |
""") | |
# Chargement des données | |
def load_data(): | |
try: | |
data = pd.read_csv("inflation_beac_complet_2010_2025.csv") | |
data['Année'] = pd.to_datetime(data['Année'], format='%Y').dt.strftime('%Y-%m-%d') | |
return data | |
except Exception as e: | |
st.error(f"Erreur de chargement: {str(e)}") | |
st.stop() | |
# Chargement et prétraitement des données | |
data = load_and_preprocess_data() | |
# Sidebar - Configuration | |
with st.sidebar: | |
st.image("beac_logo.png", width=150) | |
st.markdown("## Paramètres de l'analyse") | |
# Sélection du pays | |
pays_options = data["Pays"].unique() | |
selected_country = st.selectbox("Sélectionnez un pays", pays_options) | |
# Filtrage des dates disponibles pour le pays sélectionné | |
country_data = data[data["Pays"] == selected_country] | |
min_date = country_data["Année"].min() | |
max_date = country_data["Année"].max() | |
# Sélection de la plage temporelle | |
start_date = st.date_input("Année de début", | |
value=datetime.strptime(min_date, "%Y-%m-%d"), | |
min_value=datetime.strptime(min_date, "%Y-%m-%d"), | |
max_value=datetime.strptime(max_date, "%Y-%m-%d")) | |
end_date = st.date_input("Année de fin", | |
value=datetime.strptime(max_date, "%Y-%m-%d"), | |
min_value=datetime.strptime(min_date, "%Y-%m-%d"), | |
max_value=datetime.strptime(max_date, "%Y-%m-%d")) | |
# Sélection du modèle | |
model_options = { | |
"ARIMA": "Modèle statistique pour séries temporelles", | |
"MLP": "Réseau de neurones à perceptrons multicouches", | |
"LSTM": "Réseau de neurones récurrents à mémoire longue", | |
"Réseau Bayésien": "Modèle probabiliste graphique" | |
} | |
selected_model = st.selectbox("Modèle de prédiction", list(model_options.keys()), | |
format_func=lambda x: f"{x} - {model_options[x]}") | |
# Paramètres avancés | |
with st.expander("Paramètres avancés"): | |
st.markdown("### Scénarios personnalisés") | |
taux_directeur_change = st.slider("Variation du taux directeur (%)", -5.0, 5.0, 0.0, 0.1) | |
pib_change = st.slider("Variation du PIB (%)", -5.0, 5.0, 0.0, 0.1) | |
m2_change = st.slider("Variation de la masse monétaire M2 (%)", -10.0, 10.0, 0.0, 0.1) | |
if selected_model == "ARIMA": | |
p = st.slider("Paramètre p (AR)", 0, 5, 1) | |
d = st.slider("Paramètre d (I)", 0, 2, 1) | |
q = st.slider("Paramètre q (MA)", 0, 5, 1) | |
elif selected_model == "MLP": | |
hidden_layers = st.slider("Nombre de couches cachées", 1, 5, 2) | |
neurons = st.slider("Neurones par couche", 10, 100, 50) | |
epochs = st.slider("Nombre d'epochs", 1, 500, 100) | |
elif selected_model == "LSTM": | |
# Calcul dynamique du look_back maximum | |
country_data = data[data["Pays"] == selected_country] | |
max_look_back = max(1, len(country_data) // 2) | |
lstm_units = st.slider("Unités LSTM", 10, 100, 50) | |
epochs = st.slider("Nombre d'epochs", 10, 200, 50) | |
look_back = st.slider("Période de look-back", | |
min_value=1, | |
max_value=max_look_back, | |
value=min(4, max_look_back)) | |
st.markdown("---") | |
st.markdown("Développé par MONTI VINCENT LOIC") | |
st.markdown("© 2023 - Tous droits réservés") | |
# Onglets principaux | |
tab1, tab2, tab3, tab4 = st.tabs(["📈 Visualisation", "🤖 Prédiction", "📊 Dashboard", "💬 Assistant"]) | |
with tab1: | |
st.header("Analyse des données historiques") | |
# Filtrage des données selon les sélections | |
filtered_data = data[ | |
(data["Pays"] == selected_country) & | |
(data["Année"] >= str(start_date)) & | |
(data["Année"] <= str(end_date)) | |
] | |
# Graphique de l'inflation | |
fig1 = px.line(filtered_data, x="Année", y="Taux d'inflation (%)", | |
title=f"Évolution de l'inflation en {selected_country}", | |
labels={"Taux d'inflation (%)": "Taux d'inflation (%)"}) | |
fig1.update_layout(height=400) | |
st.plotly_chart(fig1, use_container_width=True) | |
# Graphiques des autres indicateurs | |
col1, col2 = st.columns(2) | |
with col1: | |
fig2 = px.line(filtered_data, x="Année", y="Masse monétaire (M2)", | |
title=f"Masse monétaire M2 en {selected_country}") | |
st.plotly_chart(fig2, use_container_width=True) | |
fig3 = px.line(filtered_data, x="Année", y="Croissance PIB (%)", | |
title=f"Croissance du PIB en {selected_country}") | |
st.plotly_chart(fig3, use_container_width=True) | |
with col2: | |
fig4 = px.line(filtered_data, x="Année", y="Taux directeur", | |
title=f"Taux directeur en {selected_country}") | |
st.plotly_chart(fig4, use_container_width=True) | |
fig5 = px.line(filtered_data, x="Année", y="Balance commerciale", | |
title=f"Balance commerciale en {selected_country}") | |
st.plotly_chart(fig5, use_container_width=True) | |
# Matrice de corrélation | |
st.subheader("Matrice de corrélation") | |
corr_data = filtered_data.drop(columns=["Année", "Pays"]) | |
corr_matrix = corr_data.corr() | |
fig6 = go.Figure(data=go.Heatmap( | |
z=corr_matrix.values, | |
x=corr_matrix.columns, | |
y=corr_matrix.columns, | |
colorscale="RdBu", | |
zmin=-1, | |
zmax=1, | |
colorbar=dict(title="Coefficient de corrélation") | |
)) | |
fig6.update_layout(title="Corrélations entre les indicateurs macroéconomiques", | |
height=500) | |
st.plotly_chart(fig6, use_container_width=True) | |
with tab2: | |
st.header("Prédiction de l'inflation") | |
# Entraînement du modèle sélectionné | |
with st.spinner(f"Entraînement du modèle {selected_model} en cours..."): | |
if selected_model == "ARIMA": | |
model, predictions, metrics = train_arima( | |
data, selected_country, start_date, end_date, | |
p, d, q, taux_directeur_change, pib_change, m2_change | |
) | |
elif selected_model == "MLP": | |
model, predictions, metrics = train_mlp( | |
data, selected_country, start_date, end_date, | |
hidden_layers, neurons, epochs, | |
taux_directeur_change, pib_change, m2_change | |
) | |
elif selected_model == "LSTM": | |
model, predictions, metrics = train_lstm( | |
data, selected_country, start_date, end_date, | |
lstm_units, epochs, look_back, | |
taux_directeur_change, pib_change, m2_change | |
) | |
elif selected_model == "Réseau Bayésien": | |
model, predictions, metrics = train_bayesian_network( | |
data, selected_country, start_date, end_date, | |
taux_directeur_change, pib_change, m2_change | |
) | |
# Affichage des résultats si succès | |
st.success("Modèle entraîné avec succès!") | |
# Métriques de performance | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
st.metric(label="MAE (Erreur absolue moyenne)", value=f"{metrics['mae']:.2f}") | |
with col2: | |
st.metric(label="RMSE (Racine de l'erreur quadratique moyenne)", value=f"{metrics['rmse']:.2f}") | |
# Graphique de comparaison des prédictions | |
fig_pred = go.Figure() | |
# Ajout des données réelles | |
fig_pred.add_trace(go.Scatter( | |
x=predictions["Année"], | |
y=predictions["Inflation réelle"], | |
name="Inflation réelle", | |
line=dict(color="blue") | |
)) | |
# Ajout des prédictions | |
fig_pred.add_trace(go.Scatter( | |
x=predictions["Année"], | |
y=predictions["Inflation prédite"], | |
name="Inflation prédite", | |
line=dict(color="red", dash="dash") | |
)) | |
fig_pred.update_layout( | |
title=f"Comparaison inflation réelle vs prédite - {selected_model}", | |
xaxis_title="Année", | |
yaxis_title="Taux d'inflation (%)", | |
height=500 | |
) | |
st.plotly_chart(fig_pred, use_container_width=True) | |
# Téléchargement des résultats | |
st.subheader("Export des résultats") | |
# Format CSV | |
csv = predictions.to_csv(index=False).encode("utf-8") | |
st.download_button( | |
label="Télécharger les prédictions (CSV)", | |
data=csv, | |
file_name=f"predictions_inflation_{selected_country}_{selected_model}.csv", | |
mime="text/csv" | |
) | |
# Format Excel | |
output = BytesIO() | |
with pd.ExcelWriter(output, engine="xlsxwriter") as writer: | |
predictions.to_excel(writer, sheet_name="Prédictions", index=False) | |
metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=["Valeur"]) | |
metrics_df.to_excel(writer, sheet_name="Métriques") | |
excel_data = output.getvalue() | |
st.download_button( | |
label="Télécharger les résultats (Excel)", | |
data=excel_data, | |
file_name=f"resultats_inflation_{selected_country}_{selected_model}.xlsx", | |
mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | |
) | |
with tab3: | |
st.header("Tableau de bord comparatif") | |
# Comparaison entre pays | |
st.subheader("Comparaison entre pays") | |
selected_countries = st.multiselect("Sélectionnez les pays à comparer", | |
data["Pays"].unique(), | |
default=[selected_country]) | |
if selected_countries: | |
compare_data = data[data["Pays"].isin(selected_countries)] | |
# Graphique comparatif - Correction ici | |
fig_compare = px.line(compare_data, x="Année", y="Taux d'inflation (%)", | |
color="Pays", title="Comparaison des taux d'inflation") | |
st.plotly_chart(fig_compare, use_container_width=True) | |
# Dernières valeurs disponibles | |
st.subheader("Dernières valeurs disponibles") | |
latest_data = compare_data.sort_values("Année").groupby("Pays").last().reset_index() | |
st.dataframe(latest_data.set_index("Pays").style.background_gradient(cmap="Blues")) | |
# Comparaison entre modèles | |
st.subheader("Comparaison entre modèles") | |
if st.button("Lancer la comparaison des modèles"): | |
with st.spinner("Comparaison des modèles en cours..."): | |
models_to_compare = ["ARIMA", "MLP", "LSTM", "Réseau Bayésien"] | |
comparison_results = [] | |
for model_name in models_to_compare: | |
if model_name == "ARIMA": | |
_, _, metrics = train_arima( | |
data, selected_country, start_date, end_date, | |
1, 1, 1, 0, 0, 0 | |
) | |
elif model_name == "MLP": | |
_, _, metrics = train_mlp( | |
data, selected_country, start_date, end_date, | |
2, 50, 50, 0, 0, 0 | |
) | |
elif model_name == "LSTM": | |
_, _, metrics = train_lstm( | |
data, selected_country, start_date, end_date, | |
50, 50, 12, 0, 0, 0 | |
) | |
elif model_name == "Réseau Bayésien": | |
_, _, metrics = train_bayesian_network( | |
data, selected_country, start_date, end_date, 0, 0, 0 | |
) | |
comparison_results.append({ | |
"Modèle": model_name, | |
"MAE": metrics["mae"], | |
"RMSE": metrics["rmse"], | |
"R²": metrics["r2"], | |
"Temps d'entraînement (s)": metrics.get("training_time", "N/A") | |
}) | |
comparison_df = pd.DataFrame(comparison_results) | |
# Affichage des résultats | |
st.dataframe(comparison_df.set_index("Modèle").style.background_gradient(cmap="Blues")) | |
# Graphique de comparaison | |
fig_models = go.Figure() | |
fig_models.add_trace(go.Bar( | |
x=comparison_df["Modèle"], | |
y=comparison_df["MAE"], | |
name="MAE", | |
marker_color="indianred" | |
)) | |
fig_models.add_trace(go.Bar( | |
x=comparison_df["Modèle"], | |
y=comparison_df["RMSE"], | |
name="RMSE", | |
marker_color="lightsalmon" | |
)) | |
fig_models.update_layout( | |
title="Comparaison des performances des modèles", | |
barmode="group", | |
height=500 | |
) | |
st.plotly_chart(fig_models, use_container_width=True) | |
with tab4: | |
st.header("Assistant intelligent") | |
# Clés API (⚠️ à ne jamais exposer publiquement) | |
COHERE_API_KEY = "moZJbgxiW9cW8Wqo0ecce0pa84uf3eT6F2oL1whB" | |
ELEVEN_API_KEY = "sk_6c5472c80964f88fcdc9d9c189db749143ae1a0d2c8f26f3" | |
# Instanciation unique | |
if "chatbot" not in st.session_state: | |
st.session_state.chatbot = InflationChatbot(COHERE_API_KEY, ELEVEN_API_KEY) | |
if "chat_history" not in st.session_state: | |
st.session_state.chat_history = [] | |
# Affichage historique | |
for message in st.session_state.chat_history: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# Entrée utilisateur | |
if prompt := st.chat_input("Posez votre question sur l'inflation dans la zone BEAC..."): | |
with st.chat_message("user"): | |
st.markdown(prompt) | |
st.session_state.chat_history.append({"role": "user", "content": prompt}) | |
with st.spinner("Réflexion en cours..."): | |
result = st.session_state.chatbot.ask(prompt, st.session_state.chat_history) | |
if "error" in result: | |
st.error(f"Erreur : {result['error']}") | |
else: | |
with st.chat_message("assistant"): | |
st.markdown(result["reply"]) | |
st.markdown(f"🔊 **Résumé vocal :** _{result['summary']}_") | |
if result["audio"]: | |
st.audio(result["audio"], format="audio/mp3") | |
st.session_state.chat_history.append({"role": "assistant", "content": result["reply"]}) | |
# Pied de page | |
st.markdown("---") | |
st.markdown(""" | |
**Application de prédiction de l'inflation** | |
Données: BEAC • Modèles: ARIMA, MLP, LSTM, Réseaux Bayésiens | |
Version 1.0 • | |
""") |