KaykySouza commited on
Commit
83de08e
·
verified ·
1 Parent(s): bb68492

Create api_logic.py

Browse files
Files changed (1) hide show
  1. api_logic.py +267 -0
api_logic.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import faiss
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
+ from sentence_transformers import SentenceTransformer
7
+ from huggingface_hub import HfApi # Pode ser removido se não usar para commits externos
8
+ import logging
9
+
10
+ # Configuração de Logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # --- Configurações Globais ---
15
+ MODEL_NAME = "stabilityai/stablelm-3b-4e1t"
16
+ EMBED_MODEL = "all-MiniLM-L6-v2" # Usando nome curto que funciona
17
+ SUMMARIZER_ID = "sshleifer/distilbart-cnn-12-6"
18
+ MEM_FILE = "memory.json" # Salvará no armazenamento do Space
19
+ IDX_FILE = "index.faiss" # Salvará no armazenamento do Space
20
+ # HF_TOKEN = os.getenv("HF_TOKEN") # Se precisar para algo no futuro
21
+ # REPO_ID = os.getenv("SPACE_ID") # Se precisar para algo no futuro
22
+
23
+ # --- Variáveis Globais para Modelos e Memória (Carregados uma vez) ---
24
+ tokenizer = None
25
+ model = None
26
+ chat_pipe = None
27
+ embedder = None
28
+ summarizer = None
29
+ memory = []
30
+ index = None
31
+ faiss_dimension = 384 # Dimensão padrão para all-MiniLM-L6-v2
32
+
33
+ def load_models_and_memory():
34
+ """Carrega modelos, pipelines e a memória/índice FAISS."""
35
+ global tokenizer, model, chat_pipe, embedder, summarizer, memory, index, faiss_dimension
36
+
37
+ logger.info("Iniciando carregamento de modelos e memória...")
38
+
39
+ # Carrega Tokenizer e Modelo LLM (FP16)
40
+ logger.info(f"Carregando tokenizer: {MODEL_NAME}")
41
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
42
+
43
+ logger.info(f"Carregando modelo: {MODEL_NAME} (FP16, CPU)")
44
+ model = AutoModelForCausalLM.from_pretrained(
45
+ MODEL_NAME,
46
+ torch_dtype=torch.float16,
47
+ low_cpu_mem_usage=True, # Importante para CPU com RAM limitada
48
+ device_map="auto" # Deixa accelerate decidir (provavelmente 'cpu')
49
+ )
50
+ logger.info("Modelo LLM carregado.")
51
+
52
+ # Cria Pipeline de Chat
53
+ logger.info("Criando pipeline de text-generation...")
54
+ chat_pipe = pipeline(
55
+ "text-generation",
56
+ model=model,
57
+ tokenizer=tokenizer,
58
+ # max_length=512, # Definir max_new_tokens na chamada é melhor
59
+ do_sample=True,
60
+ top_p=0.9,
61
+ )
62
+ logger.info("Pipeline de chat criado.")
63
+
64
+ # Carrega Embedder
65
+ logger.info(f"Carregando embedder: {EMBED_MODEL}")
66
+ embedder = SentenceTransformer(EMBED_MODEL)
67
+ faiss_dimension = embedder.get_sentence_embedding_dimension() # Atualiza a dimensão
68
+ logger.info(f"Embedder carregado. Dimensão: {faiss_dimension}")
69
+
70
+
71
+ # Carrega Summarizer
72
+ logger.info(f"Carregando summarizer: {SUMMARIZER_ID}")
73
+ summarizer = pipeline("summarization", model=SUMMARIZER_ID)
74
+ logger.info("Summarizer carregado.")
75
+
76
+ # Carrega/Inicializa Memória e Índice FAISS
77
+ logger.info("Carregando/Inicializando memória e índice FAISS...")
78
+ if os.path.exists(MEM_FILE):
79
+ try:
80
+ with open(MEM_FILE, "r") as f:
81
+ memory = json.load(f)
82
+ logger.info(f"Arquivo de memória '{MEM_FILE}' carregado ({len(memory)} entradas).")
83
+ except Exception as e:
84
+ logger.error(f"Erro ao carregar {MEM_FILE}: {e}. Iniciando memória vazia.")
85
+ memory = []
86
+ else:
87
+ logger.info(f"Arquivo '{MEM_FILE}' não encontrado. Iniciando memória vazia.")
88
+ memory = []
89
+
90
+ if os.path.exists(IDX_FILE):
91
+ try:
92
+ index = faiss.read_index(IDX_FILE)
93
+ logger.info(f"Índice FAISS '{IDX_FILE}' carregado ({index.ntotal} vetores).")
94
+ # Validação simples da dimensão
95
+ if index.ntotal > 0 and index.d != faiss_dimension:
96
+ logger.warning(f"Dimensão do índice FAISS ({index.d}) diferente da dimensão do embedder ({faiss_dimension})! Recriando índice.")
97
+ index = faiss.IndexFlatL2(faiss_dimension)
98
+ # Idealmente, aqui você re-indexaria a memória existente, mas vamos simplificar por agora
99
+ except Exception as e:
100
+ logger.error(f"Erro ao carregar {IDX_FILE}: {e}. Recriando índice FAISS.")
101
+ index = faiss.IndexFlatL2(faiss_dimension)
102
+ else:
103
+ logger.info(f"Arquivo '{IDX_FILE}' não encontrado. Criando novo índice FAISS.")
104
+ index = faiss.IndexFlatL2(faiss_dimension)
105
+
106
+ logger.info("Carregamento de modelos e memória concluído.")
107
+
108
+
109
+ # --- Funções de Lógica (Adaptadas) ---
110
+
111
+ def save_state():
112
+ """Salva o estado atual da memória e do índice FAISS."""
113
+ logger.info("Salvando estado (memória e índice)...")
114
+ try:
115
+ with open(MEM_FILE, "w") as f:
116
+ json.dump(memory, f, indent=2) # Adicionado indent para legibilidade
117
+ faiss.write_index(index, IDX_FILE)
118
+ logger.info("Estado salvo com sucesso.")
119
+ except Exception as e:
120
+ logger.error(f"Erro ao salvar estado: {e}")
121
+
122
+ def summarize_block(txt):
123
+ logger.info("Chamando summarizer...")
124
+ instr = ("Resuma este trecho de diálogo de RPG preservando personagens, locais, itens e eventos:\n\n")
125
+ try:
126
+ summary = summarizer(instr + txt, max_length=150, min_length=50)[0]["summary_text"] # Reduzi min_length
127
+ logger.info("Resumo gerado.")
128
+ return summary
129
+ except Exception as e:
130
+ logger.error(f"Erro no summarizer: {e}")
131
+ return f"[[Erro ao resumir: {e}]]"
132
+
133
+ def compact_memory(threshold=50): # Reduzi o threshold para testar mais rápido
134
+ logger.info(f"Verificando compactação (limite={threshold}), memória atual={len(memory)}...")
135
+ if len(memory) < threshold:
136
+ return False # Indica que não compactou
137
+
138
+ logger.info(f"Compactando memória ({threshold} itens)...")
139
+ bloco = memory[:threshold]
140
+ texto = "\n".join(
141
+ itm["text"] if itm["type"]=="summary"
142
+ else f"Usuário: {itm['user']}\nIA: {itm['bot']}"
143
+ for itm in bloco
144
+ )
145
+ resumo = summarize_block(texto)
146
+ memory[:threshold] = [{"type":"summary","text":resumo}] # Substitui bloco por resumo
147
+
148
+ logger.info("Recriando índice FAISS após compactação...")
149
+ new_idx = faiss.IndexFlatL2(faiss_dimension)
150
+ embeddings_to_add = []
151
+ for itm in memory:
152
+ key = itm["text"] if itm["type"]=="summary" else itm["user"]
153
+ try:
154
+ # Coleta todos os embeddings primeiro
155
+ embeddings_to_add.append(embedder.encode([key], convert_to_numpy=True)[0])
156
+ except Exception as e:
157
+ logger.error(f"Erro ao encodar item para reindexação: {e}")
158
+
159
+ if embeddings_to_add:
160
+ try:
161
+ new_idx.add(np.array(embeddings_to_add)) # Adiciona em lote
162
+ global index
163
+ index = new_idx
164
+ logger.info("Reindexação FAISS concluída.")
165
+ save_state() # Salva após compactar e reindexar
166
+ return True # Indica que compactou
167
+ except Exception as e:
168
+ logger.error(f"Erro ao adicionar embeddings ao novo índice: {e}")
169
+ return False
170
+ else:
171
+ logger.info("Nenhum embedding válido para adicionar ao novo índice.")
172
+ return False
173
+
174
+
175
+ def add_to_memory_and_index(user_msg, bot_msg):
176
+ logger.info("Adicionando nova entrada à memória e índice...")
177
+ entry = {"type":"dialog", "user":user_msg, "bot":bot_msg, "text":user_msg}
178
+ memory.append(entry)
179
+ try:
180
+ embedding = embedder.encode([user_msg], convert_to_numpy=True)
181
+ index.add(embedding)
182
+ logger.info(f"Embedding adicionado ao índice. Total de vetores: {index.ntotal}")
183
+ save_state() # Salva após adicionar
184
+ compact_memory() # Verifica se precisa compactar
185
+ except Exception as e:
186
+ logger.error(f"Erro ao adicionar embedding ou salvar estado: {e}")
187
+
188
+
189
+ def run_chat_logic(user_msg):
190
+ """Executa a lógica principal do chat: busca contexto, gera resposta."""
191
+ logger.info(f"Executando lógica do chat para: {repr(user_msg)}")
192
+ global memory, index # Garante acesso às variáveis globais
193
+
194
+ if not all([tokenizer, model, chat_pipe, embedder, index]):
195
+ logger.error("Modelos ou índice não foram carregados corretamente.")
196
+ return "Desculpe, o sistema de IA não está pronto. Tente novamente mais tarde."
197
+
198
+ # 1. Embedding e Busca FAISS
199
+ logger.info("Gerando embedding da mensagem do usuário...")
200
+ try:
201
+ emb = embedder.encode([user_msg], convert_to_numpy=True)
202
+ except Exception as e:
203
+ logger.error(f"Erro ao gerar embedding: {e}")
204
+ return "Desculpe, houve um erro ao processar sua mensagem (embedding)."
205
+
206
+ context = []
207
+ logger.info(f"Buscando no índice FAISS ({index.ntotal} vetores)...")
208
+ if index.ntotal > 0:
209
+ try:
210
+ D, I = index.search(emb, k=5) # Busca os 5 vizinhos mais próximos
211
+ logger.info(f"Índices FAISS encontrados: {I[0]}")
212
+ for idx in I[0]:
213
+ if 0 <= idx < len(memory): # Validação crucial
214
+ itm = memory[idx]
215
+ context.append(
216
+ f"Lembrança: {itm['text']}" if itm["type"]=="summary"
217
+ else f"Histórico [Usuário: {itm['user']} | IA: {itm['bot']}]"
218
+ )
219
+ else:
220
+ logger.warning(f"Índice FAISS inválido encontrado: {idx}")
221
+ except Exception as e:
222
+ logger.error(f"Erro durante a busca FAISS: {e}")
223
+ # Continua sem contexto se a busca falhar
224
+
225
+ logger.info(f"Contexto recuperado ({len(context)} itens).")
226
+
227
+ # 2. Monta Prompt
228
+ context_str = "\n".join(context)
229
+ prompt = (
230
+ f"Você é um Mestre de RPG experiente e criativo. Continue a história de forma envolvente, "
231
+ f"considerando o seguinte histórico e lembranças:\n{context_str}\n\n"
232
+ f"Ação do Jogador: {user_msg}\n\nSua Narração:"
233
+ )
234
+ logger.info(f"Prompt enviado ao modelo (primeiros 200 chars):\n{prompt[:200]}...")
235
+
236
+ # 3. Chama o Modelo (Pipeline)
237
+ try:
238
+ logger.info("Chamando pipeline text-generation...")
239
+ # return_full_text=False pega só a continuação
240
+ outputs = chat_pipe(prompt, max_new_tokens=200, return_full_text=False, num_return_sequences=1)
241
+ logger.info(f"Saída bruta do pipeline: {outputs}")
242
+
243
+ if not outputs or not outputs[0] or "generated_text" not in outputs[0]:
244
+ logger.error("Pipeline não retornou 'generated_text' válido.")
245
+ return "Desculpe, a IA não conseguiu gerar uma resposta válida."
246
+
247
+ bot_msg = outputs[0]["generated_text"].strip()
248
+ # Limpeza adicional
249
+ bot_msg = bot_msg.split("<|endoftext|>")[0].strip()
250
+ # Remover repetições exatas do prompt final se houver
251
+ if bot_msg.startswith(f"Sua Narração:"): bot_msg = bot_msg[len("Sua Narração:"):].strip()
252
+
253
+ logger.info(f"Resposta processada do bot: {repr(bot_msg)}")
254
+
255
+ # 4. Adiciona na memória e índice
256
+ add_to_memory_and_index(user_msg, bot_msg)
257
+
258
+ return bot_msg
259
+
260
+ except Exception as e:
261
+ logger.exception("Erro durante a execução do pipeline ou pós-processamento.") # Loga o traceback completo
262
+ return f"Desculpe, ocorreu um erro interno ao gerar a resposta da IA: {e}"
263
+
264
+ # --- Carrega tudo na inicialização ---
265
+ # Esta linha será chamada quando o módulo for importado pela primeira vez
266
+ # (Ou podemos chamar explicitamente via lifespan do FastAPI)
267
+ # load_models_and_memory()