Spaces:
Sleeping
Sleeping
Create api_logic.py
Browse files- api_logic.py +267 -0
api_logic.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import faiss
|
4 |
+
import torch
|
5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from huggingface_hub import HfApi # Pode ser removido se não usar para commits externos
|
8 |
+
import logging
|
9 |
+
|
10 |
+
# Configuração de Logging
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
# --- Configurações Globais ---
|
15 |
+
MODEL_NAME = "stabilityai/stablelm-3b-4e1t"
|
16 |
+
EMBED_MODEL = "all-MiniLM-L6-v2" # Usando nome curto que funciona
|
17 |
+
SUMMARIZER_ID = "sshleifer/distilbart-cnn-12-6"
|
18 |
+
MEM_FILE = "memory.json" # Salvará no armazenamento do Space
|
19 |
+
IDX_FILE = "index.faiss" # Salvará no armazenamento do Space
|
20 |
+
# HF_TOKEN = os.getenv("HF_TOKEN") # Se precisar para algo no futuro
|
21 |
+
# REPO_ID = os.getenv("SPACE_ID") # Se precisar para algo no futuro
|
22 |
+
|
23 |
+
# --- Variáveis Globais para Modelos e Memória (Carregados uma vez) ---
|
24 |
+
tokenizer = None
|
25 |
+
model = None
|
26 |
+
chat_pipe = None
|
27 |
+
embedder = None
|
28 |
+
summarizer = None
|
29 |
+
memory = []
|
30 |
+
index = None
|
31 |
+
faiss_dimension = 384 # Dimensão padrão para all-MiniLM-L6-v2
|
32 |
+
|
33 |
+
def load_models_and_memory():
|
34 |
+
"""Carrega modelos, pipelines e a memória/índice FAISS."""
|
35 |
+
global tokenizer, model, chat_pipe, embedder, summarizer, memory, index, faiss_dimension
|
36 |
+
|
37 |
+
logger.info("Iniciando carregamento de modelos e memória...")
|
38 |
+
|
39 |
+
# Carrega Tokenizer e Modelo LLM (FP16)
|
40 |
+
logger.info(f"Carregando tokenizer: {MODEL_NAME}")
|
41 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
42 |
+
|
43 |
+
logger.info(f"Carregando modelo: {MODEL_NAME} (FP16, CPU)")
|
44 |
+
model = AutoModelForCausalLM.from_pretrained(
|
45 |
+
MODEL_NAME,
|
46 |
+
torch_dtype=torch.float16,
|
47 |
+
low_cpu_mem_usage=True, # Importante para CPU com RAM limitada
|
48 |
+
device_map="auto" # Deixa accelerate decidir (provavelmente 'cpu')
|
49 |
+
)
|
50 |
+
logger.info("Modelo LLM carregado.")
|
51 |
+
|
52 |
+
# Cria Pipeline de Chat
|
53 |
+
logger.info("Criando pipeline de text-generation...")
|
54 |
+
chat_pipe = pipeline(
|
55 |
+
"text-generation",
|
56 |
+
model=model,
|
57 |
+
tokenizer=tokenizer,
|
58 |
+
# max_length=512, # Definir max_new_tokens na chamada é melhor
|
59 |
+
do_sample=True,
|
60 |
+
top_p=0.9,
|
61 |
+
)
|
62 |
+
logger.info("Pipeline de chat criado.")
|
63 |
+
|
64 |
+
# Carrega Embedder
|
65 |
+
logger.info(f"Carregando embedder: {EMBED_MODEL}")
|
66 |
+
embedder = SentenceTransformer(EMBED_MODEL)
|
67 |
+
faiss_dimension = embedder.get_sentence_embedding_dimension() # Atualiza a dimensão
|
68 |
+
logger.info(f"Embedder carregado. Dimensão: {faiss_dimension}")
|
69 |
+
|
70 |
+
|
71 |
+
# Carrega Summarizer
|
72 |
+
logger.info(f"Carregando summarizer: {SUMMARIZER_ID}")
|
73 |
+
summarizer = pipeline("summarization", model=SUMMARIZER_ID)
|
74 |
+
logger.info("Summarizer carregado.")
|
75 |
+
|
76 |
+
# Carrega/Inicializa Memória e Índice FAISS
|
77 |
+
logger.info("Carregando/Inicializando memória e índice FAISS...")
|
78 |
+
if os.path.exists(MEM_FILE):
|
79 |
+
try:
|
80 |
+
with open(MEM_FILE, "r") as f:
|
81 |
+
memory = json.load(f)
|
82 |
+
logger.info(f"Arquivo de memória '{MEM_FILE}' carregado ({len(memory)} entradas).")
|
83 |
+
except Exception as e:
|
84 |
+
logger.error(f"Erro ao carregar {MEM_FILE}: {e}. Iniciando memória vazia.")
|
85 |
+
memory = []
|
86 |
+
else:
|
87 |
+
logger.info(f"Arquivo '{MEM_FILE}' não encontrado. Iniciando memória vazia.")
|
88 |
+
memory = []
|
89 |
+
|
90 |
+
if os.path.exists(IDX_FILE):
|
91 |
+
try:
|
92 |
+
index = faiss.read_index(IDX_FILE)
|
93 |
+
logger.info(f"Índice FAISS '{IDX_FILE}' carregado ({index.ntotal} vetores).")
|
94 |
+
# Validação simples da dimensão
|
95 |
+
if index.ntotal > 0 and index.d != faiss_dimension:
|
96 |
+
logger.warning(f"Dimensão do índice FAISS ({index.d}) diferente da dimensão do embedder ({faiss_dimension})! Recriando índice.")
|
97 |
+
index = faiss.IndexFlatL2(faiss_dimension)
|
98 |
+
# Idealmente, aqui você re-indexaria a memória existente, mas vamos simplificar por agora
|
99 |
+
except Exception as e:
|
100 |
+
logger.error(f"Erro ao carregar {IDX_FILE}: {e}. Recriando índice FAISS.")
|
101 |
+
index = faiss.IndexFlatL2(faiss_dimension)
|
102 |
+
else:
|
103 |
+
logger.info(f"Arquivo '{IDX_FILE}' não encontrado. Criando novo índice FAISS.")
|
104 |
+
index = faiss.IndexFlatL2(faiss_dimension)
|
105 |
+
|
106 |
+
logger.info("Carregamento de modelos e memória concluído.")
|
107 |
+
|
108 |
+
|
109 |
+
# --- Funções de Lógica (Adaptadas) ---
|
110 |
+
|
111 |
+
def save_state():
|
112 |
+
"""Salva o estado atual da memória e do índice FAISS."""
|
113 |
+
logger.info("Salvando estado (memória e índice)...")
|
114 |
+
try:
|
115 |
+
with open(MEM_FILE, "w") as f:
|
116 |
+
json.dump(memory, f, indent=2) # Adicionado indent para legibilidade
|
117 |
+
faiss.write_index(index, IDX_FILE)
|
118 |
+
logger.info("Estado salvo com sucesso.")
|
119 |
+
except Exception as e:
|
120 |
+
logger.error(f"Erro ao salvar estado: {e}")
|
121 |
+
|
122 |
+
def summarize_block(txt):
|
123 |
+
logger.info("Chamando summarizer...")
|
124 |
+
instr = ("Resuma este trecho de diálogo de RPG preservando personagens, locais, itens e eventos:\n\n")
|
125 |
+
try:
|
126 |
+
summary = summarizer(instr + txt, max_length=150, min_length=50)[0]["summary_text"] # Reduzi min_length
|
127 |
+
logger.info("Resumo gerado.")
|
128 |
+
return summary
|
129 |
+
except Exception as e:
|
130 |
+
logger.error(f"Erro no summarizer: {e}")
|
131 |
+
return f"[[Erro ao resumir: {e}]]"
|
132 |
+
|
133 |
+
def compact_memory(threshold=50): # Reduzi o threshold para testar mais rápido
|
134 |
+
logger.info(f"Verificando compactação (limite={threshold}), memória atual={len(memory)}...")
|
135 |
+
if len(memory) < threshold:
|
136 |
+
return False # Indica que não compactou
|
137 |
+
|
138 |
+
logger.info(f"Compactando memória ({threshold} itens)...")
|
139 |
+
bloco = memory[:threshold]
|
140 |
+
texto = "\n".join(
|
141 |
+
itm["text"] if itm["type"]=="summary"
|
142 |
+
else f"Usuário: {itm['user']}\nIA: {itm['bot']}"
|
143 |
+
for itm in bloco
|
144 |
+
)
|
145 |
+
resumo = summarize_block(texto)
|
146 |
+
memory[:threshold] = [{"type":"summary","text":resumo}] # Substitui bloco por resumo
|
147 |
+
|
148 |
+
logger.info("Recriando índice FAISS após compactação...")
|
149 |
+
new_idx = faiss.IndexFlatL2(faiss_dimension)
|
150 |
+
embeddings_to_add = []
|
151 |
+
for itm in memory:
|
152 |
+
key = itm["text"] if itm["type"]=="summary" else itm["user"]
|
153 |
+
try:
|
154 |
+
# Coleta todos os embeddings primeiro
|
155 |
+
embeddings_to_add.append(embedder.encode([key], convert_to_numpy=True)[0])
|
156 |
+
except Exception as e:
|
157 |
+
logger.error(f"Erro ao encodar item para reindexação: {e}")
|
158 |
+
|
159 |
+
if embeddings_to_add:
|
160 |
+
try:
|
161 |
+
new_idx.add(np.array(embeddings_to_add)) # Adiciona em lote
|
162 |
+
global index
|
163 |
+
index = new_idx
|
164 |
+
logger.info("Reindexação FAISS concluída.")
|
165 |
+
save_state() # Salva após compactar e reindexar
|
166 |
+
return True # Indica que compactou
|
167 |
+
except Exception as e:
|
168 |
+
logger.error(f"Erro ao adicionar embeddings ao novo índice: {e}")
|
169 |
+
return False
|
170 |
+
else:
|
171 |
+
logger.info("Nenhum embedding válido para adicionar ao novo índice.")
|
172 |
+
return False
|
173 |
+
|
174 |
+
|
175 |
+
def add_to_memory_and_index(user_msg, bot_msg):
|
176 |
+
logger.info("Adicionando nova entrada à memória e índice...")
|
177 |
+
entry = {"type":"dialog", "user":user_msg, "bot":bot_msg, "text":user_msg}
|
178 |
+
memory.append(entry)
|
179 |
+
try:
|
180 |
+
embedding = embedder.encode([user_msg], convert_to_numpy=True)
|
181 |
+
index.add(embedding)
|
182 |
+
logger.info(f"Embedding adicionado ao índice. Total de vetores: {index.ntotal}")
|
183 |
+
save_state() # Salva após adicionar
|
184 |
+
compact_memory() # Verifica se precisa compactar
|
185 |
+
except Exception as e:
|
186 |
+
logger.error(f"Erro ao adicionar embedding ou salvar estado: {e}")
|
187 |
+
|
188 |
+
|
189 |
+
def run_chat_logic(user_msg):
|
190 |
+
"""Executa a lógica principal do chat: busca contexto, gera resposta."""
|
191 |
+
logger.info(f"Executando lógica do chat para: {repr(user_msg)}")
|
192 |
+
global memory, index # Garante acesso às variáveis globais
|
193 |
+
|
194 |
+
if not all([tokenizer, model, chat_pipe, embedder, index]):
|
195 |
+
logger.error("Modelos ou índice não foram carregados corretamente.")
|
196 |
+
return "Desculpe, o sistema de IA não está pronto. Tente novamente mais tarde."
|
197 |
+
|
198 |
+
# 1. Embedding e Busca FAISS
|
199 |
+
logger.info("Gerando embedding da mensagem do usuário...")
|
200 |
+
try:
|
201 |
+
emb = embedder.encode([user_msg], convert_to_numpy=True)
|
202 |
+
except Exception as e:
|
203 |
+
logger.error(f"Erro ao gerar embedding: {e}")
|
204 |
+
return "Desculpe, houve um erro ao processar sua mensagem (embedding)."
|
205 |
+
|
206 |
+
context = []
|
207 |
+
logger.info(f"Buscando no índice FAISS ({index.ntotal} vetores)...")
|
208 |
+
if index.ntotal > 0:
|
209 |
+
try:
|
210 |
+
D, I = index.search(emb, k=5) # Busca os 5 vizinhos mais próximos
|
211 |
+
logger.info(f"Índices FAISS encontrados: {I[0]}")
|
212 |
+
for idx in I[0]:
|
213 |
+
if 0 <= idx < len(memory): # Validação crucial
|
214 |
+
itm = memory[idx]
|
215 |
+
context.append(
|
216 |
+
f"Lembrança: {itm['text']}" if itm["type"]=="summary"
|
217 |
+
else f"Histórico [Usuário: {itm['user']} | IA: {itm['bot']}]"
|
218 |
+
)
|
219 |
+
else:
|
220 |
+
logger.warning(f"Índice FAISS inválido encontrado: {idx}")
|
221 |
+
except Exception as e:
|
222 |
+
logger.error(f"Erro durante a busca FAISS: {e}")
|
223 |
+
# Continua sem contexto se a busca falhar
|
224 |
+
|
225 |
+
logger.info(f"Contexto recuperado ({len(context)} itens).")
|
226 |
+
|
227 |
+
# 2. Monta Prompt
|
228 |
+
context_str = "\n".join(context)
|
229 |
+
prompt = (
|
230 |
+
f"Você é um Mestre de RPG experiente e criativo. Continue a história de forma envolvente, "
|
231 |
+
f"considerando o seguinte histórico e lembranças:\n{context_str}\n\n"
|
232 |
+
f"Ação do Jogador: {user_msg}\n\nSua Narração:"
|
233 |
+
)
|
234 |
+
logger.info(f"Prompt enviado ao modelo (primeiros 200 chars):\n{prompt[:200]}...")
|
235 |
+
|
236 |
+
# 3. Chama o Modelo (Pipeline)
|
237 |
+
try:
|
238 |
+
logger.info("Chamando pipeline text-generation...")
|
239 |
+
# return_full_text=False pega só a continuação
|
240 |
+
outputs = chat_pipe(prompt, max_new_tokens=200, return_full_text=False, num_return_sequences=1)
|
241 |
+
logger.info(f"Saída bruta do pipeline: {outputs}")
|
242 |
+
|
243 |
+
if not outputs or not outputs[0] or "generated_text" not in outputs[0]:
|
244 |
+
logger.error("Pipeline não retornou 'generated_text' válido.")
|
245 |
+
return "Desculpe, a IA não conseguiu gerar uma resposta válida."
|
246 |
+
|
247 |
+
bot_msg = outputs[0]["generated_text"].strip()
|
248 |
+
# Limpeza adicional
|
249 |
+
bot_msg = bot_msg.split("<|endoftext|>")[0].strip()
|
250 |
+
# Remover repetições exatas do prompt final se houver
|
251 |
+
if bot_msg.startswith(f"Sua Narração:"): bot_msg = bot_msg[len("Sua Narração:"):].strip()
|
252 |
+
|
253 |
+
logger.info(f"Resposta processada do bot: {repr(bot_msg)}")
|
254 |
+
|
255 |
+
# 4. Adiciona na memória e índice
|
256 |
+
add_to_memory_and_index(user_msg, bot_msg)
|
257 |
+
|
258 |
+
return bot_msg
|
259 |
+
|
260 |
+
except Exception as e:
|
261 |
+
logger.exception("Erro durante a execução do pipeline ou pós-processamento.") # Loga o traceback completo
|
262 |
+
return f"Desculpe, ocorreu um erro interno ao gerar a resposta da IA: {e}"
|
263 |
+
|
264 |
+
# --- Carrega tudo na inicialização ---
|
265 |
+
# Esta linha será chamada quando o módulo for importado pela primeira vez
|
266 |
+
# (Ou podemos chamar explicitamente via lifespan do FastAPI)
|
267 |
+
# load_models_and_memory()
|