File size: 5,693 Bytes
673b7cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# app.py : Application Streamlit avec Modèle Classique et ANN

import streamlit as st
import pandas as pd
import numpy as np
import joblib
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

st.set_page_config(page_title="App Cardiaque", layout="centered")

# Chargement des modèles et scalers
model = joblib.load("meilleur_modele.pkl")
scaler = joblib.load("scaler.pkl")
ann_model = load_model("ann_model.h5")
ann_scaler = joblib.load("ann_model_scaler.pkl")['scaler']

def home():
    st.title("🫀 Application de Prédiction Cardiaque")
    st.markdown(
        """
        Cette application prédit le **risque de décès** chez les patients atteints d'insuffisance cardiaque
        à partir de données cliniques.

        **Fonctionnalités :**
        - Prédiction individuelle (modèle classique)
        - Prédiction avec modèle ANN (réseau de neurones)
        - Chargement de fichiers CSV pour des prédictions multiples
        - Historique de session
        - Visualisation des probabilités et importance des variables

        """
    )

def collect_input():
    age = st.slider("Âge", 18, 100, 50)
    anaemia = st.selectbox("Anémie", [0, 1], format_func=lambda x: "Oui" if x else "Non")
    creatinine_phosphokinase = st.number_input("Créatinine phosphokinase", 10, 10000, 200)
    diabetes = st.selectbox("Diabète", [0, 1], format_func=lambda x: "Oui" if x else "Non")
    ejection_fraction = st.slider("Fraction d'éjection (%)", 10, 80, 35)
    high_blood_pressure = st.selectbox("Hypertension", [0, 1], format_func=lambda x: "Oui" if x else "Non")
    platelets = st.number_input("Plaquettes (en µL)", 25000.0, 850000.0, 300000.0)
    serum_creatinine = st.number_input("Créatinine sérique", 0.1, 10.0, 1.0)
    serum_sodium = st.slider("Sodium sérique", 110, 150, 135)
    sex = st.selectbox("Sexe", [0, 1], format_func=lambda x: "Homme" if x else "Femme")
    smoking = st.selectbox("Fumeur", [0, 1], format_func=lambda x: "Oui" if x else "Non")
    time = st.slider("Temps depuis admission (jours)", 0, 300, 100)

    return pd.DataFrame([{
        'age': age, 'anaemia': anaemia, 'creatinine_phosphokinase': creatinine_phosphokinase,
        'diabetes': diabetes, 'ejection_fraction': ejection_fraction, 'high_blood_pressure': high_blood_pressure,
        'platelets': platelets, 'serum_creatinine': serum_creatinine, 'serum_sodium': serum_sodium,
        'sex': sex, 'smoking': smoking, 'time': time
    }])

def prediction_page():
    st.title("🔍 Prédiction Individuelle (Modèle Classique)")
    input_df = collect_input()
    st.write("### Données entrées :", input_df)

    if st.button("Prédire avec modèle classique"):
        X_scaled = scaler.transform(input_df)
        prediction = model.predict(X_scaled)[0]
        proba = model.predict_proba(X_scaled)[0]
        st.markdown(f"**Probabilité de décès :** {proba[1]*100:.2f}%")
        st.markdown(f"**Probabilité de survie :** {proba[0]*100:.2f}%")

        fig, ax = plt.subplots()
        ax.bar(["Survie", "Décès"], proba, color=["green", "red"])
        st.pyplot(fig)

        if prediction == 1:
            st.error("⚠️ Risque élevé de décès détecté.")
        else:
            st.success("✅ Le patient a de bonnes chances de survie.")

        if "history" not in st.session_state:
            st.session_state["history"] = []
        st.session_state["history"].append({
            "Survie (%)": round(proba[0]*100, 2),
            "Décès (%)": round(proba[1]*100, 2),
            "Décès prédit": "Oui" if prediction == 1 else "Non"
        })

def ann_prediction_page():
    st.title("🤖 Prédiction avec Réseau de Neurones (ANN)")
    input_df = collect_input()
    st.write("### Données entrées :", input_df)

    if st.button("Prédire avec modèle ANN"):
        X_ann_scaled = ann_scaler.transform(input_df)
        ann_pred = ann_model.predict(X_ann_scaled)[0][0]
        st.metric("Probabilité de décès (ANN)", f"{ann_pred*100:.2f}%")

        fig, ax = plt.subplots()
        ax.bar(["Survie", "Décès"], [1-ann_pred, ann_pred], color=["green", "red"])
        st.pyplot(fig)

        if ann_pred >= 0.5:
            st.error("⚠️ Risque élevé détecté (ANN)")
        else:
            st.success("✅ Faible risque détecté (ANN)")

def batch_prediction_page():
    st.title("📁 Prédiction depuis un Fichier CSV")
    uploaded_file = st.file_uploader("Charger un fichier CSV", type="csv")
    if uploaded_file:
        try:
            data = pd.read_csv(uploaded_file)
            X_scaled = scaler.transform(data)
            preds = model.predict(X_scaled)
            probas = model.predict_proba(X_scaled)
            data["Décès prédit"] = preds
            data["Probabilité de décès"] = np.round(probas[:, 1] * 100, 2)
            st.dataframe(data)
            st.download_button("📥 Télécharger résultats", data.to_csv(index=False), "resultats.csv")
        except Exception as e:
            st.error(f"Erreur lors du traitement : {e}")

def history_page():
    st.title("📚 Historique de Prédictions")
    if "history" in st.session_state and st.session_state["history"]:
        hist_df = pd.DataFrame(st.session_state["history"])
        st.dataframe(hist_df)
    else:
        st.info("Aucune prédiction enregistrée dans cette session.")

pages = {
    "🏠 Accueil": home,
    "🔍 Prédiction (Classique)": prediction_page,
    "🤖 Prédiction Deep Learning": ann_prediction_page,
    "📁 Prédiction en lot": batch_prediction_page,
    "📚 Historique": history_page
}

choice = st.sidebar.radio("Navigation", list(pages.keys()))
pages[choice]()