# app.py # ============================================================================== # 1. IMPORTS # ============================================================================== import torch from transformers import AutoModelForCausalLM, AutoTokenizer import gradio as gr import os # Utile pour gérer les chemins si nécessaire, mais non utilisé ici. # ============================================================================== # 2. DÉFINITION DES CONSTANTES ET CHARGEMENT (Optimisé pour un Space) # ============================================================================== MODEL_NAME = "Clemylia/LAM-4-ZERO-F" # Dans un Space, le chargement se fait généralement sur l'appareil disponible (GPU si spécifié, sinon CPU). # Nous allons laisser Transformers gérer le device_map, car 'cpu' est l'option par défaut si pas de GPU. # Pour la robustesse, on peut quand même spécifier l'appareil si on veut être explicite. DEVICE = "cpu" print(f"Chargement du modèle {MODEL_NAME} sur {DEVICE}...") # Chargement du tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # Assurez-vous d'avoir un token de padding/fin de séquence if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Chargement du modèle. On utilise device_map="auto" pour laisser Transformers # gérer l'emplacement optimal, mais comme nous n'avons pas spécifié de GPU # l'inférence se fera sur le CPU par défaut si on n'a pas accès à un GPU. try: model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.float32, device_map="auto" # Laisse Hugging Face gérer le device pour l'hébergement ) # Assurez-vous qu'il est sur le bon appareil après le chargement model.to(DEVICE) print("Modèle chargé avec succès.") except Exception as e: print(f"Erreur lors du chargement du modèle : {e}") # Vous pouvez ajouter ici un modèle de secours ou une erreur plus claire exit() # ============================================================================== # 3. FONCTION D'INFÉRENCE POUR GRADIO CHAT # ============================================================================== def format_prompt(history, message): """ Formate la conversation complète pour le modèle SLM dans le format : ### Instruction: [HISTORIQUE DE LA CONVERSATION] [NOUVELLE QUESTION] ### Response: """ # 1. Construire l'historique complet pour le contexte full_history = "" # history est une liste de paires [utilisateur, bot] for user_msg, bot_msg in history: # On assume un format simple de Questions/Réponses dans l'historique full_history += f"### Instruction: {user_msg}\n\n### Response: {bot_msg}\n" # 2. Ajouter la nouvelle question de l'utilisateur full_prompt = ( f"{full_history}" # L'historique des tours précédents f"### Instruction: {message}\n\n" # La nouvelle question f"### Response:\n" # Le modèle doit continuer à partir d'ici ) return full_prompt.strip() def generate_response(message, history): """ Fonction principale appelée par l'interface Gradio Chat. """ # 1. Formatage du prompt complet avec l'historique prompt = format_prompt(history, message) # 2. Tokenization et placement sur le CPU inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) # 3. Génération with torch.no_grad(): # Économise la mémoire et accélère légèrement l'inférence output_tokens = model.generate( **inputs, max_new_tokens=150, # Augmenté légèrement pour un chat do_sample=True, temperature=0.5, top_k=40, eos_token_id=tokenizer.eos_token_id, # Attention : désactiver l'historique de la clé/valeur pour le chat # est souvent une bonne idée pour éviter les erreurs de dimension use_cache=True, ) # 4. Décodage et nettoyage generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True) # Trouver et extraire uniquement la réponse du modèle # On se base sur le dernier marqueur '### Response:' assistant_prefix = "### Response:\n" if assistant_prefix in generated_text: # On splitte et prend la dernière partie, puis on nettoie les préfixes de fin clean_response = generated_text.split(assistant_prefix)[-1].strip() # Le modèle peut parfois répéter les marqueurs de prompt dans sa réponse # On s'assure de n'afficher que ce qui vient après le prompt if '### Instruction:' in clean_response: clean_response = clean_response.split('### Instruction:')[0].strip() else: # Fallback si le format n'est pas parfait clean_response = generated_text.replace(prompt, "").strip() return clean_response # ============================================================================== # 4. CRÉATION DE L'INTERFACE GRADIO # ============================================================================== # Le composant `gr.ChatInterface` est le plus simple et le plus adapté. # Il gère l'historique (le paramètre `history` de `generate_response`) # et l'affichage de la conversation pour nous. chat_interface = gr.ChatInterface( fn=generate_response, chatbot=gr.Chatbot(height=500), # Taille du champ de chat textbox=gr.Textbox(placeholder="Posez votre question à Lam-4-zero-f...", container=False, scale=7), title=f"Chat avec {MODEL_NAME} (SLM)", description="Interface de discussion pour le SLM Lam-4-zero-f. L'inférence est effectuée sur CPU.", theme="soft" ) # Lancement de l'application # 'share=False' et 'inbrowser=True' sont pour un test local. # Pour le Space, il suffit d'appeler .launch(). if __name__ == "__main__": chat_interface.launch()