Clemylia commited on
Commit
055bee9
·
verified ·
1 Parent(s): 2e4be4e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # ==============================================================================
3
+ # 1. IMPORTS
4
+ # ==============================================================================
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ import gradio as gr
8
+ import os # Utile pour gérer les chemins si nécessaire, mais non utilisé ici.
9
+
10
+ # ==============================================================================
11
+ # 2. DÉFINITION DES CONSTANTES ET CHARGEMENT (Optimisé pour un Space)
12
+ # ==============================================================================
13
+ MODEL_NAME = "Clemylia/LAM-4-ZERO-F"
14
+ # Dans un Space, le chargement se fait généralement sur l'appareil disponible (GPU si spécifié, sinon CPU).
15
+ # Nous allons laisser Transformers gérer le device_map, car 'cpu' est l'option par défaut si pas de GPU.
16
+ # Pour la robustesse, on peut quand même spécifier l'appareil si on veut être explicite.
17
+ DEVICE = "cpu"
18
+
19
+ print(f"Chargement du modèle {MODEL_NAME} sur {DEVICE}...")
20
+
21
+ # Chargement du tokenizer
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
+ # Assurez-vous d'avoir un token de padding/fin de séquence
24
+ if tokenizer.pad_token is None:
25
+ tokenizer.pad_token = tokenizer.eos_token
26
+
27
+ # Chargement du modèle. On utilise device_map="auto" pour laisser Transformers
28
+ # gérer l'emplacement optimal, mais comme nous n'avons pas spécifié de GPU
29
+ # l'inférence se fera sur le CPU par défaut si on n'a pas accès à un GPU.
30
+ try:
31
+ model = AutoModelForCausalLM.from_pretrained(
32
+ MODEL_NAME,
33
+ torch_dtype=torch.float32,
34
+ device_map="auto" # Laisse Hugging Face gérer le device pour l'hébergement
35
+ )
36
+ # Assurez-vous qu'il est sur le bon appareil après le chargement
37
+ model.to(DEVICE)
38
+ print("Modèle chargé avec succès.")
39
+
40
+ except Exception as e:
41
+ print(f"Erreur lors du chargement du modèle : {e}")
42
+ # Vous pouvez ajouter ici un modèle de secours ou une erreur plus claire
43
+ exit()
44
+
45
+ # ==============================================================================
46
+ # 3. FONCTION D'INFÉRENCE POUR GRADIO CHAT
47
+ # ==============================================================================
48
+
49
+ def format_prompt(history, message):
50
+ """
51
+ Formate la conversation complète pour le modèle SLM dans le format :
52
+ ### Instruction:
53
+ [HISTORIQUE DE LA CONVERSATION]
54
+ [NOUVELLE QUESTION]
55
+
56
+ ### Response:
57
+ """
58
+
59
+ # 1. Construire l'historique complet pour le contexte
60
+ full_history = ""
61
+ # history est une liste de paires [utilisateur, bot]
62
+ for user_msg, bot_msg in history:
63
+ # On assume un format simple de Questions/Réponses dans l'historique
64
+ full_history += f"### Instruction: {user_msg}\n\n### Response: {bot_msg}\n"
65
+
66
+ # 2. Ajouter la nouvelle question de l'utilisateur
67
+ full_prompt = (
68
+ f"{full_history}" # L'historique des tours précédents
69
+ f"### Instruction: {message}\n\n" # La nouvelle question
70
+ f"### Response:\n" # Le modèle doit continuer à partir d'ici
71
+ )
72
+
73
+ return full_prompt.strip()
74
+
75
+
76
+ def generate_response(message, history):
77
+ """
78
+ Fonction principale appelée par l'interface Gradio Chat.
79
+ """
80
+
81
+ # 1. Formatage du prompt complet avec l'historique
82
+ prompt = format_prompt(history, message)
83
+
84
+ # 2. Tokenization et placement sur le CPU
85
+ inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
86
+
87
+ # 3. Génération
88
+ with torch.no_grad(): # Économise la mémoire et accélère légèrement l'inférence
89
+ output_tokens = model.generate(
90
+ **inputs,
91
+ max_new_tokens=150, # Augmenté légèrement pour un chat
92
+ do_sample=True,
93
+ temperature=0.5,
94
+ top_k=40,
95
+ eos_token_id=tokenizer.eos_token_id,
96
+ # Attention : désactiver l'historique de la clé/valeur pour le chat
97
+ # est souvent une bonne idée pour éviter les erreurs de dimension
98
+ use_cache=True,
99
+ )
100
+
101
+ # 4. Décodage et nettoyage
102
+ generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
103
+
104
+ # Trouver et extraire uniquement la réponse du modèle
105
+ # On se base sur le dernier marqueur '### Response:'
106
+ assistant_prefix = "### Response:\n"
107
+ if assistant_prefix in generated_text:
108
+ # On splitte et prend la dernière partie, puis on nettoie les préfixes de fin
109
+ clean_response = generated_text.split(assistant_prefix)[-1].strip()
110
+
111
+ # Le modèle peut parfois répéter les marqueurs de prompt dans sa réponse
112
+ # On s'assure de n'afficher que ce qui vient après le prompt
113
+ if '### Instruction:' in clean_response:
114
+ clean_response = clean_response.split('### Instruction:')[0].strip()
115
+
116
+ else:
117
+ # Fallback si le format n'est pas parfait
118
+ clean_response = generated_text.replace(prompt, "").strip()
119
+
120
+ return clean_response
121
+
122
+ # ==============================================================================
123
+ # 4. CRÉATION DE L'INTERFACE GRADIO
124
+ # ==============================================================================
125
+
126
+ # Le composant `gr.ChatInterface` est le plus simple et le plus adapté.
127
+ # Il gère l'historique (le paramètre `history` de `generate_response`)
128
+ # et l'affichage de la conversation pour nous.
129
+
130
+ chat_interface = gr.ChatInterface(
131
+ fn=generate_response,
132
+ chatbot=gr.Chatbot(height=500), # Taille du champ de chat
133
+ textbox=gr.Textbox(placeholder="Posez votre question à Lam-4-zero-f...", container=False, scale=7),
134
+ title=f"Chat avec {MODEL_NAME} (SLM)",
135
+ description="Interface de discussion pour le SLM Lam-1. L'inférence est effectuée sur CPU.",
136
+ theme="soft"
137
+ )
138
+
139
+ # Lancement de l'application
140
+ # 'share=False' et 'inbrowser=True' sont pour un test local.
141
+ # Pour le Space, il suffit d'appeler .launch().
142
+ if __name__ == "__main__":
143
+ chat_interface.launch()