Arvador237 commited on
Commit
c6f18e6
·
verified ·
1 Parent(s): 8f5dee6

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +18 -21
models.py CHANGED
@@ -240,12 +240,11 @@ def train_lstm(data, country, start_date, end_date, lstm_units, epochs, look_bac
240
  raise ValueError(f"Erreur LSTM: {str(e)}\nContexte: {error_details}")
241
 
242
  def train_bayesian_network(data, country, start_date, end_date,
243
- taux_directeur_change=0, pib_change=0, m2_change=0):
244
- """Réseau Bayésien avec compatibilité pgmpy moderne"""
245
  start_time = time.time()
246
 
247
  try:
248
- # [1] Préparation des données (identique)
249
  country_data = data[data['Pays'] == country]
250
  filtered_data = country_data[
251
  (country_data['Année'] >= str(start_date)) &
@@ -253,22 +252,22 @@ def train_bayesian_network(data, country, start_date, end_date,
253
  ].sort_values('Année')
254
 
255
  modified_data = apply_scenarios(filtered_data, taux_directeur_change, pib_change, m2_change)
256
-
257
- # [2] Discrétisation améliorée
258
  df = modified_data.copy()
259
- variables = [
260
- ("Taux d'inflation (%)", "Inflation"),
261
- ("Masse monétaire (M2)", "M2"),
262
- ("Croissance PIB (%)", "PIB"),
263
- ("Taux directeur", "TauxDirecteur"),
264
- ("Balance commerciale", "Balance"),
265
- ("Taux de change FCFA/USD", "Change")
266
- ]
267
-
268
- for src, dest in variables:
 
 
269
  df[dest] = pd.qcut(df[src], q=5, duplicates='drop', labels=False).astype(int)
270
-
271
- # [3] Construction du réseau
272
  model = BayesianNetwork([
273
  ("M2", "Inflation"),
274
  ("PIB", "Inflation"),
@@ -276,8 +275,6 @@ def train_bayesian_network(data, country, start_date, end_date,
276
  ("Balance", "Inflation"),
277
  ("Change", "Inflation")
278
  ])
279
-
280
- # [4] Entraînement avec vérification
281
  model.fit(df, estimator=MaximumLikelihoodEstimator)
282
 
283
  # [5] Correction des CPDs (version compatible pgmpy 0.1.12+)
@@ -349,7 +346,7 @@ def train_bayesian_network(data, country, start_date, end_date,
349
 
350
  except Exception as e:
351
  error_details = {
352
- 'shape': modified_data.shape if 'modified_data' in locals() else None,
353
- 'discretization': {v: len(np.unique(df[v])) for v in variables} if 'df' in locals() else None
354
  }
355
  raise ValueError(f"Erreur Réseau Bayésien: {str(e)}\nContexte: {error_details}")
 
240
  raise ValueError(f"Erreur LSTM: {str(e)}\nContexte: {error_details}")
241
 
242
  def train_bayesian_network(data, country, start_date, end_date,
243
+ taux_directeur_change=0, pib_change=0, m2_change=0):
 
244
  start_time = time.time()
245
 
246
  try:
247
+ # 1. Préparation des données
248
  country_data = data[data['Pays'] == country]
249
  filtered_data = country_data[
250
  (country_data['Année'] >= str(start_date)) &
 
252
  ].sort_values('Année')
253
 
254
  modified_data = apply_scenarios(filtered_data, taux_directeur_change, pib_change, m2_change)
 
 
255
  df = modified_data.copy()
256
+
257
+ # 2. Discrétisation avec mapping clair
258
+ discretized_cols = {
259
+ "Taux d'inflation (%)": "Inflation",
260
+ "Masse monétaire (M2)": "M2",
261
+ "Croissance PIB (%)": "PIB",
262
+ "Taux directeur": "TauxDirecteur",
263
+ "Balance commerciale": "Balance",
264
+ "Taux de change FCFA/USD": "Change"
265
+ }
266
+
267
+ for src, dest in discretized_cols.items():
268
  df[dest] = pd.qcut(df[src], q=5, duplicates='drop', labels=False).astype(int)
269
+
270
+ # 3. Construction et entraînement du modèle
271
  model = BayesianNetwork([
272
  ("M2", "Inflation"),
273
  ("PIB", "Inflation"),
 
275
  ("Balance", "Inflation"),
276
  ("Change", "Inflation")
277
  ])
 
 
278
  model.fit(df, estimator=MaximumLikelihoodEstimator)
279
 
280
  # [5] Correction des CPDs (version compatible pgmpy 0.1.12+)
 
346
 
347
  except Exception as e:
348
  error_details = {
349
+ 'columns': list(modified_data.columns) if 'modified_data' in locals() else None,
350
+ 'discretized': list(discretized_cols.values()) if 'discretized_cols' in locals() else None
351
  }
352
  raise ValueError(f"Erreur Réseau Bayésien: {str(e)}\nContexte: {error_details}")