File size: 7,255 Bytes
0f07bde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# =============================================================================
# BACKEND API POUR COACH PÉDAGOGIQUE IA (VERSION FINALE)
# =============================================================================
# Ce script utilise FastAPI et implémente la méthode de téléchargement dynamique
# pour le modèle LLM, évitant ainsi d'avoir à le stocker dans le dépôt Git.
# =============================================================================

# --- Imports ---
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from huggingface_hub import hf_hub_download
import logging

# Configuration du logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Initialisation de l'API ---
app = FastAPI()

# --- Modèles Pydantic pour la validation des données ---
class QuestionRequest(BaseModel):
    question: str

class AnswerResponse(BaseModel):
    answer: str

# --- Chargement des modèles au démarrage de l'API ---
# On utilise un "singleton" pour s'assurer que les modèles ne sont chargés qu'une seule fois.
class ModelSingleton:
    llm = None
    vectorstore = None
    embeddings = None

    def load_models(self):
        if self.llm is None:
            try:
                # --- Étape 1 : Configuration des chemins vers les artefacts LOCAUX ---
                # Ces dossiers (embeddings, faiss) DOIVENT être dans votre dépôt Git.
                base_dir = os.path.dirname(__file__)
                faiss_index_path = os.path.join(base_dir, "faiss_index_wize")
                embedding_model_path = os.path.join(base_dir, "embedding_model_saved")

                logger.info("Chargement du modèle d'embeddings local...")
                self.embeddings = HuggingFaceEmbeddings(
                    model_name=embedding_model_path,
                    model_kwargs={'device': 'cpu'} # Sur un Space gratuit, c'est CPU uniquement
                )
                logger.info("Modèle d'embeddings chargé.")

                logger.info("Chargement de la base de connaissances FAISS locale...")
                self.vectorstore = FAISS.load_local(
                    faiss_index_path,
                    self.embeddings,
                    allow_dangerous_deserialization=True
                )
                logger.info("Base de connaissances FAISS chargée.")

                # --- Étape 2 : Téléchargement dynamique du gros modèle GGUF ---
                # Le fichier n'est PAS dans le dépôt, il est téléchargé depuis le Hub.
                model_repo_id = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
                model_filename = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"
                
                logger.info(f"Téléchargement du modèle LLM '{model_filename}' depuis le Hub... (peut être long)")
                
                model_path = hf_hub_download(
                    repo_id=model_repo_id,
                    filename=model_filename
                )
                logger.info(f"Modèle téléchargé dans : {model_path}")
                
                # --- Étape 3 : Chargement du LLM depuis le fichier téléchargé ---
                logger.info("Chargement du modèle LLM en mémoire (peut échouer par manque de RAM)...")
                self.llm = Llama(
                    model_path=model_path,
                    n_gpu_layers=0,  # 0 car nous sommes sur un CPU
                    n_ctx=4096,
                    verbose=False,
                    chat_format="llama-3"
                )
                logger.info("✅ Modèle LLM chargé avec succès.")

            except Exception as e:
                logger.error(f"❌ Erreur critique lors du chargement des modèles: {e}")
                # Si le chargement échoue, on lève une exception pour que l'API ne démarre pas incorrectement.
                raise RuntimeError(f"Impossible de charger les modèles: {e}")

# Instancier et charger les modèles au démarrage de l'application
# L'événement "startup" de FastAPI est le meilleur endroit pour faire ça.
@app.on_event("startup")
def startup_event():
    global models
    models = ModelSingleton()
    try:
        models.load_models()
    except Exception as e:
        # On log l'erreur, l'API répondra avec des erreurs 503 si les modèles ne sont pas chargés.
        logger.error(f"DÉMARRAGE ÉCHOUÉ : Les modèles n'ont pas pu être initialisés. {e}")
        # On met les modèles à None pour pouvoir gérer l'erreur proprement dans les endpoints.
        models.llm = None
        models.vectorstore = None

# --- Définition du point de terminaison de l'API ---
@app.post("/ask", response_model=AnswerResponse)
def ask_question(request: QuestionRequest):
    """
    Ce point de terminaison reçoit une question, utilise le RAG pour trouver
    le contexte et génère une réponse avec le LLM.
    """
    if models.llm is None or models.vectorstore is None:
        raise HTTPException(status_code=503, detail="Service non disponible : les modèles n'ont pas pu être chargés au démarrage.")

    user_question = request.question
    logger.info(f"Requête reçue pour la question : '{user_question}'")

    try:
        # 1. RAG : Récupérer le contexte
        retriever = models.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        docs = retriever.invoke(user_question)
        context = "\n".join([doc.page_content for doc in docs])

        # 2. Prompt Engineering (votre logique exacte)
        system_message = (
            "Tu es un coach pédagogique expert, travaillant avec un système RAG basé sur des documents fournis. "
            "Tu réponds uniquement à partir des informations extraites de ces documents. "
            "Tu ne réponds qu’en français. Tu ne dois jamais inventer de réponse. "
            "Tes réponses doivent être en 1 à 2 phrases maximum, claires et compactes."
        )
        
        prompt = f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_message}
<|eot_id|><|start_header_id|>user<|end_header_id|>
Contexte :
{context}

Question : {user_question}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
        # 3. Génération de la réponse
        logger.info("Génération de la réponse...")
        response = models.llm(
            prompt,
            max_tokens=512, # On réduit pour une réponse plus rapide
            temperature=0.3,
            stop=["<|eot_id|>"],
            echo=False
        )
        answer = response['choices'][0]['text'].strip()
        logger.info("Réponse générée avec succès.")

        return AnswerResponse(answer=answer)

    except Exception as e:
        logger.error(f"Erreur lors de la génération de la réponse : {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
def read_root():
    """Point de terminaison racine pour vérifier que le serveur est en marche."""
    return {"status": "Backend du Coach IA est en ligne. Utilisez le point de terminaison /ask pour poser une question."}