File size: 5,693 Bytes
673b7cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# app.py : Application Streamlit avec Modèle Classique et ANN
import streamlit as st
import pandas as pd
import numpy as np
import joblib
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
st.set_page_config(page_title="App Cardiaque", layout="centered")
# Chargement des modèles et scalers
model = joblib.load("meilleur_modele.pkl")
scaler = joblib.load("scaler.pkl")
ann_model = load_model("ann_model.h5")
ann_scaler = joblib.load("ann_model_scaler.pkl")['scaler']
def home():
st.title("🫀 Application de Prédiction Cardiaque")
st.markdown(
"""
Cette application prédit le **risque de décès** chez les patients atteints d'insuffisance cardiaque
à partir de données cliniques.
**Fonctionnalités :**
- Prédiction individuelle (modèle classique)
- Prédiction avec modèle ANN (réseau de neurones)
- Chargement de fichiers CSV pour des prédictions multiples
- Historique de session
- Visualisation des probabilités et importance des variables
"""
)
def collect_input():
age = st.slider("Âge", 18, 100, 50)
anaemia = st.selectbox("Anémie", [0, 1], format_func=lambda x: "Oui" if x else "Non")
creatinine_phosphokinase = st.number_input("Créatinine phosphokinase", 10, 10000, 200)
diabetes = st.selectbox("Diabète", [0, 1], format_func=lambda x: "Oui" if x else "Non")
ejection_fraction = st.slider("Fraction d'éjection (%)", 10, 80, 35)
high_blood_pressure = st.selectbox("Hypertension", [0, 1], format_func=lambda x: "Oui" if x else "Non")
platelets = st.number_input("Plaquettes (en µL)", 25000.0, 850000.0, 300000.0)
serum_creatinine = st.number_input("Créatinine sérique", 0.1, 10.0, 1.0)
serum_sodium = st.slider("Sodium sérique", 110, 150, 135)
sex = st.selectbox("Sexe", [0, 1], format_func=lambda x: "Homme" if x else "Femme")
smoking = st.selectbox("Fumeur", [0, 1], format_func=lambda x: "Oui" if x else "Non")
time = st.slider("Temps depuis admission (jours)", 0, 300, 100)
return pd.DataFrame([{
'age': age, 'anaemia': anaemia, 'creatinine_phosphokinase': creatinine_phosphokinase,
'diabetes': diabetes, 'ejection_fraction': ejection_fraction, 'high_blood_pressure': high_blood_pressure,
'platelets': platelets, 'serum_creatinine': serum_creatinine, 'serum_sodium': serum_sodium,
'sex': sex, 'smoking': smoking, 'time': time
}])
def prediction_page():
st.title("🔍 Prédiction Individuelle (Modèle Classique)")
input_df = collect_input()
st.write("### Données entrées :", input_df)
if st.button("Prédire avec modèle classique"):
X_scaled = scaler.transform(input_df)
prediction = model.predict(X_scaled)[0]
proba = model.predict_proba(X_scaled)[0]
st.markdown(f"**Probabilité de décès :** {proba[1]*100:.2f}%")
st.markdown(f"**Probabilité de survie :** {proba[0]*100:.2f}%")
fig, ax = plt.subplots()
ax.bar(["Survie", "Décès"], proba, color=["green", "red"])
st.pyplot(fig)
if prediction == 1:
st.error("⚠️ Risque élevé de décès détecté.")
else:
st.success("✅ Le patient a de bonnes chances de survie.")
if "history" not in st.session_state:
st.session_state["history"] = []
st.session_state["history"].append({
"Survie (%)": round(proba[0]*100, 2),
"Décès (%)": round(proba[1]*100, 2),
"Décès prédit": "Oui" if prediction == 1 else "Non"
})
def ann_prediction_page():
st.title("🤖 Prédiction avec Réseau de Neurones (ANN)")
input_df = collect_input()
st.write("### Données entrées :", input_df)
if st.button("Prédire avec modèle ANN"):
X_ann_scaled = ann_scaler.transform(input_df)
ann_pred = ann_model.predict(X_ann_scaled)[0][0]
st.metric("Probabilité de décès (ANN)", f"{ann_pred*100:.2f}%")
fig, ax = plt.subplots()
ax.bar(["Survie", "Décès"], [1-ann_pred, ann_pred], color=["green", "red"])
st.pyplot(fig)
if ann_pred >= 0.5:
st.error("⚠️ Risque élevé détecté (ANN)")
else:
st.success("✅ Faible risque détecté (ANN)")
def batch_prediction_page():
st.title("📁 Prédiction depuis un Fichier CSV")
uploaded_file = st.file_uploader("Charger un fichier CSV", type="csv")
if uploaded_file:
try:
data = pd.read_csv(uploaded_file)
X_scaled = scaler.transform(data)
preds = model.predict(X_scaled)
probas = model.predict_proba(X_scaled)
data["Décès prédit"] = preds
data["Probabilité de décès"] = np.round(probas[:, 1] * 100, 2)
st.dataframe(data)
st.download_button("📥 Télécharger résultats", data.to_csv(index=False), "resultats.csv")
except Exception as e:
st.error(f"Erreur lors du traitement : {e}")
def history_page():
st.title("📚 Historique de Prédictions")
if "history" in st.session_state and st.session_state["history"]:
hist_df = pd.DataFrame(st.session_state["history"])
st.dataframe(hist_df)
else:
st.info("Aucune prédiction enregistrée dans cette session.")
pages = {
"🏠 Accueil": home,
"🔍 Prédiction (Classique)": prediction_page,
"🤖 Prédiction Deep Learning": ann_prediction_page,
"📁 Prédiction en lot": batch_prediction_page,
"📚 Historique": history_page
}
choice = st.sidebar.radio("Navigation", list(pages.keys()))
pages[choice]()
|