Spaces:
Runtime error
Runtime error
File size: 7,255 Bytes
0f07bde |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# =============================================================================
# BACKEND API POUR COACH PÉDAGOGIQUE IA (VERSION FINALE)
# =============================================================================
# Ce script utilise FastAPI et implémente la méthode de téléchargement dynamique
# pour le modèle LLM, évitant ainsi d'avoir à le stocker dans le dépôt Git.
# =============================================================================
# --- Imports ---
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from huggingface_hub import hf_hub_download
import logging
# Configuration du logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Initialisation de l'API ---
app = FastAPI()
# --- Modèles Pydantic pour la validation des données ---
class QuestionRequest(BaseModel):
question: str
class AnswerResponse(BaseModel):
answer: str
# --- Chargement des modèles au démarrage de l'API ---
# On utilise un "singleton" pour s'assurer que les modèles ne sont chargés qu'une seule fois.
class ModelSingleton:
llm = None
vectorstore = None
embeddings = None
def load_models(self):
if self.llm is None:
try:
# --- Étape 1 : Configuration des chemins vers les artefacts LOCAUX ---
# Ces dossiers (embeddings, faiss) DOIVENT être dans votre dépôt Git.
base_dir = os.path.dirname(__file__)
faiss_index_path = os.path.join(base_dir, "faiss_index_wize")
embedding_model_path = os.path.join(base_dir, "embedding_model_saved")
logger.info("Chargement du modèle d'embeddings local...")
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={'device': 'cpu'} # Sur un Space gratuit, c'est CPU uniquement
)
logger.info("Modèle d'embeddings chargé.")
logger.info("Chargement de la base de connaissances FAISS locale...")
self.vectorstore = FAISS.load_local(
faiss_index_path,
self.embeddings,
allow_dangerous_deserialization=True
)
logger.info("Base de connaissances FAISS chargée.")
# --- Étape 2 : Téléchargement dynamique du gros modèle GGUF ---
# Le fichier n'est PAS dans le dépôt, il est téléchargé depuis le Hub.
model_repo_id = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
model_filename = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"
logger.info(f"Téléchargement du modèle LLM '{model_filename}' depuis le Hub... (peut être long)")
model_path = hf_hub_download(
repo_id=model_repo_id,
filename=model_filename
)
logger.info(f"Modèle téléchargé dans : {model_path}")
# --- Étape 3 : Chargement du LLM depuis le fichier téléchargé ---
logger.info("Chargement du modèle LLM en mémoire (peut échouer par manque de RAM)...")
self.llm = Llama(
model_path=model_path,
n_gpu_layers=0, # 0 car nous sommes sur un CPU
n_ctx=4096,
verbose=False,
chat_format="llama-3"
)
logger.info("✅ Modèle LLM chargé avec succès.")
except Exception as e:
logger.error(f"❌ Erreur critique lors du chargement des modèles: {e}")
# Si le chargement échoue, on lève une exception pour que l'API ne démarre pas incorrectement.
raise RuntimeError(f"Impossible de charger les modèles: {e}")
# Instancier et charger les modèles au démarrage de l'application
# L'événement "startup" de FastAPI est le meilleur endroit pour faire ça.
@app.on_event("startup")
def startup_event():
global models
models = ModelSingleton()
try:
models.load_models()
except Exception as e:
# On log l'erreur, l'API répondra avec des erreurs 503 si les modèles ne sont pas chargés.
logger.error(f"DÉMARRAGE ÉCHOUÉ : Les modèles n'ont pas pu être initialisés. {e}")
# On met les modèles à None pour pouvoir gérer l'erreur proprement dans les endpoints.
models.llm = None
models.vectorstore = None
# --- Définition du point de terminaison de l'API ---
@app.post("/ask", response_model=AnswerResponse)
def ask_question(request: QuestionRequest):
"""
Ce point de terminaison reçoit une question, utilise le RAG pour trouver
le contexte et génère une réponse avec le LLM.
"""
if models.llm is None or models.vectorstore is None:
raise HTTPException(status_code=503, detail="Service non disponible : les modèles n'ont pas pu être chargés au démarrage.")
user_question = request.question
logger.info(f"Requête reçue pour la question : '{user_question}'")
try:
# 1. RAG : Récupérer le contexte
retriever = models.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
docs = retriever.invoke(user_question)
context = "\n".join([doc.page_content for doc in docs])
# 2. Prompt Engineering (votre logique exacte)
system_message = (
"Tu es un coach pédagogique expert, travaillant avec un système RAG basé sur des documents fournis. "
"Tu réponds uniquement à partir des informations extraites de ces documents. "
"Tu ne réponds qu’en français. Tu ne dois jamais inventer de réponse. "
"Tes réponses doivent être en 1 à 2 phrases maximum, claires et compactes."
)
prompt = f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_message}
<|eot_id|><|start_header_id|>user<|end_header_id|>
Contexte :
{context}
Question : {user_question}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
# 3. Génération de la réponse
logger.info("Génération de la réponse...")
response = models.llm(
prompt,
max_tokens=512, # On réduit pour une réponse plus rapide
temperature=0.3,
stop=["<|eot_id|>"],
echo=False
)
answer = response['choices'][0]['text'].strip()
logger.info("Réponse générée avec succès.")
return AnswerResponse(answer=answer)
except Exception as e:
logger.error(f"Erreur lors de la génération de la réponse : {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
def read_root():
"""Point de terminaison racine pour vérifier que le serveur est en marche."""
return {"status": "Backend du Coach IA est en ligne. Utilisez le point de terminaison /ask pour poser une question."}
|