Clemylia's picture
Update app.py
87b5a6c verified
# app.py
# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import os # Utile pour gérer les chemins si nécessaire, mais non utilisé ici.
# ==============================================================================
# 2. DÉFINITION DES CONSTANTES ET CHARGEMENT (Optimisé pour un Space)
# ==============================================================================
MODEL_NAME = "Clemylia/LAM-4-ZERO-F"
# Dans un Space, le chargement se fait généralement sur l'appareil disponible (GPU si spécifié, sinon CPU).
# Nous allons laisser Transformers gérer le device_map, car 'cpu' est l'option par défaut si pas de GPU.
# Pour la robustesse, on peut quand même spécifier l'appareil si on veut être explicite.
DEVICE = "cpu"
print(f"Chargement du modèle {MODEL_NAME} sur {DEVICE}...")
# Chargement du tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Assurez-vous d'avoir un token de padding/fin de séquence
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Chargement du modèle. On utilise device_map="auto" pour laisser Transformers
# gérer l'emplacement optimal, mais comme nous n'avons pas spécifié de GPU
# l'inférence se fera sur le CPU par défaut si on n'a pas accès à un GPU.
try:
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float32,
device_map="auto" # Laisse Hugging Face gérer le device pour l'hébergement
)
# Assurez-vous qu'il est sur le bon appareil après le chargement
model.to(DEVICE)
print("Modèle chargé avec succès.")
except Exception as e:
print(f"Erreur lors du chargement du modèle : {e}")
# Vous pouvez ajouter ici un modèle de secours ou une erreur plus claire
exit()
# ==============================================================================
# 3. FONCTION D'INFÉRENCE POUR GRADIO CHAT
# ==============================================================================
def format_prompt(history, message):
"""
Formate la conversation complète pour le modèle SLM dans le format :
### Instruction:
[HISTORIQUE DE LA CONVERSATION]
[NOUVELLE QUESTION]
### Response:
"""
# 1. Construire l'historique complet pour le contexte
full_history = ""
# history est une liste de paires [utilisateur, bot]
for user_msg, bot_msg in history:
# On assume un format simple de Questions/Réponses dans l'historique
full_history += f"### Instruction: {user_msg}\n\n### Response: {bot_msg}\n"
# 2. Ajouter la nouvelle question de l'utilisateur
full_prompt = (
f"{full_history}" # L'historique des tours précédents
f"### Instruction: {message}\n\n" # La nouvelle question
f"### Response:\n" # Le modèle doit continuer à partir d'ici
)
return full_prompt.strip()
def generate_response(message, history):
"""
Fonction principale appelée par l'interface Gradio Chat.
"""
# 1. Formatage du prompt complet avec l'historique
prompt = format_prompt(history, message)
# 2. Tokenization et placement sur le CPU
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
# 3. Génération
with torch.no_grad(): # Économise la mémoire et accélère légèrement l'inférence
output_tokens = model.generate(
**inputs,
max_new_tokens=150, # Augmenté légèrement pour un chat
do_sample=True,
temperature=0.5,
top_k=40,
eos_token_id=tokenizer.eos_token_id,
# Attention : désactiver l'historique de la clé/valeur pour le chat
# est souvent une bonne idée pour éviter les erreurs de dimension
use_cache=True,
)
# 4. Décodage et nettoyage
generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
# Trouver et extraire uniquement la réponse du modèle
# On se base sur le dernier marqueur '### Response:'
assistant_prefix = "### Response:\n"
if assistant_prefix in generated_text:
# On splitte et prend la dernière partie, puis on nettoie les préfixes de fin
clean_response = generated_text.split(assistant_prefix)[-1].strip()
# Le modèle peut parfois répéter les marqueurs de prompt dans sa réponse
# On s'assure de n'afficher que ce qui vient après le prompt
if '### Instruction:' in clean_response:
clean_response = clean_response.split('### Instruction:')[0].strip()
else:
# Fallback si le format n'est pas parfait
clean_response = generated_text.replace(prompt, "").strip()
return clean_response
# ==============================================================================
# 4. CRÉATION DE L'INTERFACE GRADIO
# ==============================================================================
# Le composant `gr.ChatInterface` est le plus simple et le plus adapté.
# Il gère l'historique (le paramètre `history` de `generate_response`)
# et l'affichage de la conversation pour nous.
chat_interface = gr.ChatInterface(
fn=generate_response,
chatbot=gr.Chatbot(height=500), # Taille du champ de chat
textbox=gr.Textbox(placeholder="Posez votre question à Lam-4-zero-f...", container=False, scale=7),
title=f"Chat avec {MODEL_NAME} (SLM)",
description="Interface de discussion pour le SLM Lam-4-zero-f. L'inférence est effectuée sur CPU.",
theme="soft"
)
# Lancement de l'application
# 'share=False' et 'inbrowser=True' sont pour un test local.
# Pour le Space, il suffit d'appeler .launch().
if __name__ == "__main__":
chat_interface.launch()