khadijaaao's picture
Create app.py
0f07bde verified
# =============================================================================
# BACKEND API POUR COACH PÉDAGOGIQUE IA (VERSION FINALE)
# =============================================================================
# Ce script utilise FastAPI et implémente la méthode de téléchargement dynamique
# pour le modèle LLM, évitant ainsi d'avoir à le stocker dans le dépôt Git.
# =============================================================================
# --- Imports ---
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from huggingface_hub import hf_hub_download
import logging
# Configuration du logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Initialisation de l'API ---
app = FastAPI()
# --- Modèles Pydantic pour la validation des données ---
class QuestionRequest(BaseModel):
question: str
class AnswerResponse(BaseModel):
answer: str
# --- Chargement des modèles au démarrage de l'API ---
# On utilise un "singleton" pour s'assurer que les modèles ne sont chargés qu'une seule fois.
class ModelSingleton:
llm = None
vectorstore = None
embeddings = None
def load_models(self):
if self.llm is None:
try:
# --- Étape 1 : Configuration des chemins vers les artefacts LOCAUX ---
# Ces dossiers (embeddings, faiss) DOIVENT être dans votre dépôt Git.
base_dir = os.path.dirname(__file__)
faiss_index_path = os.path.join(base_dir, "faiss_index_wize")
embedding_model_path = os.path.join(base_dir, "embedding_model_saved")
logger.info("Chargement du modèle d'embeddings local...")
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={'device': 'cpu'} # Sur un Space gratuit, c'est CPU uniquement
)
logger.info("Modèle d'embeddings chargé.")
logger.info("Chargement de la base de connaissances FAISS locale...")
self.vectorstore = FAISS.load_local(
faiss_index_path,
self.embeddings,
allow_dangerous_deserialization=True
)
logger.info("Base de connaissances FAISS chargée.")
# --- Étape 2 : Téléchargement dynamique du gros modèle GGUF ---
# Le fichier n'est PAS dans le dépôt, il est téléchargé depuis le Hub.
model_repo_id = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
model_filename = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"
logger.info(f"Téléchargement du modèle LLM '{model_filename}' depuis le Hub... (peut être long)")
model_path = hf_hub_download(
repo_id=model_repo_id,
filename=model_filename
)
logger.info(f"Modèle téléchargé dans : {model_path}")
# --- Étape 3 : Chargement du LLM depuis le fichier téléchargé ---
logger.info("Chargement du modèle LLM en mémoire (peut échouer par manque de RAM)...")
self.llm = Llama(
model_path=model_path,
n_gpu_layers=0, # 0 car nous sommes sur un CPU
n_ctx=4096,
verbose=False,
chat_format="llama-3"
)
logger.info("✅ Modèle LLM chargé avec succès.")
except Exception as e:
logger.error(f"❌ Erreur critique lors du chargement des modèles: {e}")
# Si le chargement échoue, on lève une exception pour que l'API ne démarre pas incorrectement.
raise RuntimeError(f"Impossible de charger les modèles: {e}")
# Instancier et charger les modèles au démarrage de l'application
# L'événement "startup" de FastAPI est le meilleur endroit pour faire ça.
@app.on_event("startup")
def startup_event():
global models
models = ModelSingleton()
try:
models.load_models()
except Exception as e:
# On log l'erreur, l'API répondra avec des erreurs 503 si les modèles ne sont pas chargés.
logger.error(f"DÉMARRAGE ÉCHOUÉ : Les modèles n'ont pas pu être initialisés. {e}")
# On met les modèles à None pour pouvoir gérer l'erreur proprement dans les endpoints.
models.llm = None
models.vectorstore = None
# --- Définition du point de terminaison de l'API ---
@app.post("/ask", response_model=AnswerResponse)
def ask_question(request: QuestionRequest):
"""
Ce point de terminaison reçoit une question, utilise le RAG pour trouver
le contexte et génère une réponse avec le LLM.
"""
if models.llm is None or models.vectorstore is None:
raise HTTPException(status_code=503, detail="Service non disponible : les modèles n'ont pas pu être chargés au démarrage.")
user_question = request.question
logger.info(f"Requête reçue pour la question : '{user_question}'")
try:
# 1. RAG : Récupérer le contexte
retriever = models.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
docs = retriever.invoke(user_question)
context = "\n".join([doc.page_content for doc in docs])
# 2. Prompt Engineering (votre logique exacte)
system_message = (
"Tu es un coach pédagogique expert, travaillant avec un système RAG basé sur des documents fournis. "
"Tu réponds uniquement à partir des informations extraites de ces documents. "
"Tu ne réponds qu’en français. Tu ne dois jamais inventer de réponse. "
"Tes réponses doivent être en 1 à 2 phrases maximum, claires et compactes."
)
prompt = f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_message}
<|eot_id|><|start_header_id|>user<|end_header_id|>
Contexte :
{context}
Question : {user_question}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
# 3. Génération de la réponse
logger.info("Génération de la réponse...")
response = models.llm(
prompt,
max_tokens=512, # On réduit pour une réponse plus rapide
temperature=0.3,
stop=["<|eot_id|>"],
echo=False
)
answer = response['choices'][0]['text'].strip()
logger.info("Réponse générée avec succès.")
return AnswerResponse(answer=answer)
except Exception as e:
logger.error(f"Erreur lors de la génération de la réponse : {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
def read_root():
"""Point de terminaison racine pour vérifier que le serveur est en marche."""
return {"status": "Backend du Coach IA est en ligne. Utilisez le point de terminaison /ask pour poser une question."}