Spaces:
Running
Running
Update models.py
Browse files
models.py
CHANGED
@@ -240,12 +240,11 @@ def train_lstm(data, country, start_date, end_date, lstm_units, epochs, look_bac
|
|
240 |
raise ValueError(f"Erreur LSTM: {str(e)}\nContexte: {error_details}")
|
241 |
|
242 |
def train_bayesian_network(data, country, start_date, end_date,
|
243 |
-
|
244 |
-
"""Réseau Bayésien avec compatibilité pgmpy moderne"""
|
245 |
start_time = time.time()
|
246 |
|
247 |
try:
|
248 |
-
#
|
249 |
country_data = data[data['Pays'] == country]
|
250 |
filtered_data = country_data[
|
251 |
(country_data['Année'] >= str(start_date)) &
|
@@ -253,22 +252,22 @@ def train_bayesian_network(data, country, start_date, end_date,
|
|
253 |
].sort_values('Année')
|
254 |
|
255 |
modified_data = apply_scenarios(filtered_data, taux_directeur_change, pib_change, m2_change)
|
256 |
-
|
257 |
-
# [2] Discrétisation améliorée
|
258 |
df = modified_data.copy()
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
|
|
|
|
269 |
df[dest] = pd.qcut(df[src], q=5, duplicates='drop', labels=False).astype(int)
|
270 |
-
|
271 |
-
#
|
272 |
model = BayesianNetwork([
|
273 |
("M2", "Inflation"),
|
274 |
("PIB", "Inflation"),
|
@@ -276,8 +275,6 @@ def train_bayesian_network(data, country, start_date, end_date,
|
|
276 |
("Balance", "Inflation"),
|
277 |
("Change", "Inflation")
|
278 |
])
|
279 |
-
|
280 |
-
# [4] Entraînement avec vérification
|
281 |
model.fit(df, estimator=MaximumLikelihoodEstimator)
|
282 |
|
283 |
# [5] Correction des CPDs (version compatible pgmpy 0.1.12+)
|
@@ -349,7 +346,7 @@ def train_bayesian_network(data, country, start_date, end_date,
|
|
349 |
|
350 |
except Exception as e:
|
351 |
error_details = {
|
352 |
-
'
|
353 |
-
'
|
354 |
}
|
355 |
raise ValueError(f"Erreur Réseau Bayésien: {str(e)}\nContexte: {error_details}")
|
|
|
240 |
raise ValueError(f"Erreur LSTM: {str(e)}\nContexte: {error_details}")
|
241 |
|
242 |
def train_bayesian_network(data, country, start_date, end_date,
|
243 |
+
taux_directeur_change=0, pib_change=0, m2_change=0):
|
|
|
244 |
start_time = time.time()
|
245 |
|
246 |
try:
|
247 |
+
# 1. Préparation des données
|
248 |
country_data = data[data['Pays'] == country]
|
249 |
filtered_data = country_data[
|
250 |
(country_data['Année'] >= str(start_date)) &
|
|
|
252 |
].sort_values('Année')
|
253 |
|
254 |
modified_data = apply_scenarios(filtered_data, taux_directeur_change, pib_change, m2_change)
|
|
|
|
|
255 |
df = modified_data.copy()
|
256 |
+
|
257 |
+
# 2. Discrétisation avec mapping clair
|
258 |
+
discretized_cols = {
|
259 |
+
"Taux d'inflation (%)": "Inflation",
|
260 |
+
"Masse monétaire (M2)": "M2",
|
261 |
+
"Croissance PIB (%)": "PIB",
|
262 |
+
"Taux directeur": "TauxDirecteur",
|
263 |
+
"Balance commerciale": "Balance",
|
264 |
+
"Taux de change FCFA/USD": "Change"
|
265 |
+
}
|
266 |
+
|
267 |
+
for src, dest in discretized_cols.items():
|
268 |
df[dest] = pd.qcut(df[src], q=5, duplicates='drop', labels=False).astype(int)
|
269 |
+
|
270 |
+
# 3. Construction et entraînement du modèle
|
271 |
model = BayesianNetwork([
|
272 |
("M2", "Inflation"),
|
273 |
("PIB", "Inflation"),
|
|
|
275 |
("Balance", "Inflation"),
|
276 |
("Change", "Inflation")
|
277 |
])
|
|
|
|
|
278 |
model.fit(df, estimator=MaximumLikelihoodEstimator)
|
279 |
|
280 |
# [5] Correction des CPDs (version compatible pgmpy 0.1.12+)
|
|
|
346 |
|
347 |
except Exception as e:
|
348 |
error_details = {
|
349 |
+
'columns': list(modified_data.columns) if 'modified_data' in locals() else None,
|
350 |
+
'discretized': list(discretized_cols.values()) if 'discretized_cols' in locals() else None
|
351 |
}
|
352 |
raise ValueError(f"Erreur Réseau Bayésien: {str(e)}\nContexte: {error_details}")
|