Spaces:
Runtime error
Runtime error
# ============================================================================= | |
# BACKEND API POUR COACH PÉDAGOGIQUE IA (VERSION FINALE) | |
# ============================================================================= | |
# Ce script utilise FastAPI et implémente la méthode de téléchargement dynamique | |
# pour le modèle LLM, évitant ainsi d'avoir à le stocker dans le dépôt Git. | |
# ============================================================================= | |
# --- Imports --- | |
import os | |
import torch | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from llama_cpp import Llama | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from huggingface_hub import hf_hub_download | |
import logging | |
# Configuration du logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# --- Initialisation de l'API --- | |
app = FastAPI() | |
# --- Modèles Pydantic pour la validation des données --- | |
class QuestionRequest(BaseModel): | |
question: str | |
class AnswerResponse(BaseModel): | |
answer: str | |
# --- Chargement des modèles au démarrage de l'API --- | |
# On utilise un "singleton" pour s'assurer que les modèles ne sont chargés qu'une seule fois. | |
class ModelSingleton: | |
llm = None | |
vectorstore = None | |
embeddings = None | |
def load_models(self): | |
if self.llm is None: | |
try: | |
# --- Étape 1 : Configuration des chemins vers les artefacts LOCAUX --- | |
# Ces dossiers (embeddings, faiss) DOIVENT être dans votre dépôt Git. | |
base_dir = os.path.dirname(__file__) | |
faiss_index_path = os.path.join(base_dir, "faiss_index_wize") | |
embedding_model_path = os.path.join(base_dir, "embedding_model_saved") | |
logger.info("Chargement du modèle d'embeddings local...") | |
self.embeddings = HuggingFaceEmbeddings( | |
model_name=embedding_model_path, | |
model_kwargs={'device': 'cpu'} # Sur un Space gratuit, c'est CPU uniquement | |
) | |
logger.info("Modèle d'embeddings chargé.") | |
logger.info("Chargement de la base de connaissances FAISS locale...") | |
self.vectorstore = FAISS.load_local( | |
faiss_index_path, | |
self.embeddings, | |
allow_dangerous_deserialization=True | |
) | |
logger.info("Base de connaissances FAISS chargée.") | |
# --- Étape 2 : Téléchargement dynamique du gros modèle GGUF --- | |
# Le fichier n'est PAS dans le dépôt, il est téléchargé depuis le Hub. | |
model_repo_id = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF" | |
model_filename = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf" | |
logger.info(f"Téléchargement du modèle LLM '{model_filename}' depuis le Hub... (peut être long)") | |
model_path = hf_hub_download( | |
repo_id=model_repo_id, | |
filename=model_filename | |
) | |
logger.info(f"Modèle téléchargé dans : {model_path}") | |
# --- Étape 3 : Chargement du LLM depuis le fichier téléchargé --- | |
logger.info("Chargement du modèle LLM en mémoire (peut échouer par manque de RAM)...") | |
self.llm = Llama( | |
model_path=model_path, | |
n_gpu_layers=0, # 0 car nous sommes sur un CPU | |
n_ctx=4096, | |
verbose=False, | |
chat_format="llama-3" | |
) | |
logger.info("✅ Modèle LLM chargé avec succès.") | |
except Exception as e: | |
logger.error(f"❌ Erreur critique lors du chargement des modèles: {e}") | |
# Si le chargement échoue, on lève une exception pour que l'API ne démarre pas incorrectement. | |
raise RuntimeError(f"Impossible de charger les modèles: {e}") | |
# Instancier et charger les modèles au démarrage de l'application | |
# L'événement "startup" de FastAPI est le meilleur endroit pour faire ça. | |
def startup_event(): | |
global models | |
models = ModelSingleton() | |
try: | |
models.load_models() | |
except Exception as e: | |
# On log l'erreur, l'API répondra avec des erreurs 503 si les modèles ne sont pas chargés. | |
logger.error(f"DÉMARRAGE ÉCHOUÉ : Les modèles n'ont pas pu être initialisés. {e}") | |
# On met les modèles à None pour pouvoir gérer l'erreur proprement dans les endpoints. | |
models.llm = None | |
models.vectorstore = None | |
# --- Définition du point de terminaison de l'API --- | |
def ask_question(request: QuestionRequest): | |
""" | |
Ce point de terminaison reçoit une question, utilise le RAG pour trouver | |
le contexte et génère une réponse avec le LLM. | |
""" | |
if models.llm is None or models.vectorstore is None: | |
raise HTTPException(status_code=503, detail="Service non disponible : les modèles n'ont pas pu être chargés au démarrage.") | |
user_question = request.question | |
logger.info(f"Requête reçue pour la question : '{user_question}'") | |
try: | |
# 1. RAG : Récupérer le contexte | |
retriever = models.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3}) | |
docs = retriever.invoke(user_question) | |
context = "\n".join([doc.page_content for doc in docs]) | |
# 2. Prompt Engineering (votre logique exacte) | |
system_message = ( | |
"Tu es un coach pédagogique expert, travaillant avec un système RAG basé sur des documents fournis. " | |
"Tu réponds uniquement à partir des informations extraites de ces documents. " | |
"Tu ne réponds qu’en français. Tu ne dois jamais inventer de réponse. " | |
"Tes réponses doivent être en 1 à 2 phrases maximum, claires et compactes." | |
) | |
prompt = f""" | |
<|begin_of_text|><|start_header_id|>system<|end_header_id|> | |
{system_message} | |
<|eot_id|><|start_header_id|>user<|end_header_id|> | |
Contexte : | |
{context} | |
Question : {user_question} | |
<|eot_id|><|start_header_id|>assistant<|end_header_id|> | |
""" | |
# 3. Génération de la réponse | |
logger.info("Génération de la réponse...") | |
response = models.llm( | |
prompt, | |
max_tokens=512, # On réduit pour une réponse plus rapide | |
temperature=0.3, | |
stop=["<|eot_id|>"], | |
echo=False | |
) | |
answer = response['choices'][0]['text'].strip() | |
logger.info("Réponse générée avec succès.") | |
return AnswerResponse(answer=answer) | |
except Exception as e: | |
logger.error(f"Erreur lors de la génération de la réponse : {e}") | |
raise HTTPException(status_code=500, detail=str(e)) | |
def read_root(): | |
"""Point de terminaison racine pour vérifier que le serveur est en marche.""" | |
return {"status": "Backend du Coach IA est en ligne. Utilisez le point de terminaison /ask pour poser une question."} | |