Spaces:
Sleeping
Sleeping
File size: 11,301 Bytes
83de08e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 |
import os
import json
import faiss
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from huggingface_hub import HfApi # Pode ser removido se não usar para commits externos
import logging
# Configuração de Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configurações Globais ---
MODEL_NAME = "stabilityai/stablelm-3b-4e1t"
EMBED_MODEL = "all-MiniLM-L6-v2" # Usando nome curto que funciona
SUMMARIZER_ID = "sshleifer/distilbart-cnn-12-6"
MEM_FILE = "memory.json" # Salvará no armazenamento do Space
IDX_FILE = "index.faiss" # Salvará no armazenamento do Space
# HF_TOKEN = os.getenv("HF_TOKEN") # Se precisar para algo no futuro
# REPO_ID = os.getenv("SPACE_ID") # Se precisar para algo no futuro
# --- Variáveis Globais para Modelos e Memória (Carregados uma vez) ---
tokenizer = None
model = None
chat_pipe = None
embedder = None
summarizer = None
memory = []
index = None
faiss_dimension = 384 # Dimensão padrão para all-MiniLM-L6-v2
def load_models_and_memory():
"""Carrega modelos, pipelines e a memória/índice FAISS."""
global tokenizer, model, chat_pipe, embedder, summarizer, memory, index, faiss_dimension
logger.info("Iniciando carregamento de modelos e memória...")
# Carrega Tokenizer e Modelo LLM (FP16)
logger.info(f"Carregando tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
logger.info(f"Carregando modelo: {MODEL_NAME} (FP16, CPU)")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16,
low_cpu_mem_usage=True, # Importante para CPU com RAM limitada
device_map="auto" # Deixa accelerate decidir (provavelmente 'cpu')
)
logger.info("Modelo LLM carregado.")
# Cria Pipeline de Chat
logger.info("Criando pipeline de text-generation...")
chat_pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
# max_length=512, # Definir max_new_tokens na chamada é melhor
do_sample=True,
top_p=0.9,
)
logger.info("Pipeline de chat criado.")
# Carrega Embedder
logger.info(f"Carregando embedder: {EMBED_MODEL}")
embedder = SentenceTransformer(EMBED_MODEL)
faiss_dimension = embedder.get_sentence_embedding_dimension() # Atualiza a dimensão
logger.info(f"Embedder carregado. Dimensão: {faiss_dimension}")
# Carrega Summarizer
logger.info(f"Carregando summarizer: {SUMMARIZER_ID}")
summarizer = pipeline("summarization", model=SUMMARIZER_ID)
logger.info("Summarizer carregado.")
# Carrega/Inicializa Memória e Índice FAISS
logger.info("Carregando/Inicializando memória e índice FAISS...")
if os.path.exists(MEM_FILE):
try:
with open(MEM_FILE, "r") as f:
memory = json.load(f)
logger.info(f"Arquivo de memória '{MEM_FILE}' carregado ({len(memory)} entradas).")
except Exception as e:
logger.error(f"Erro ao carregar {MEM_FILE}: {e}. Iniciando memória vazia.")
memory = []
else:
logger.info(f"Arquivo '{MEM_FILE}' não encontrado. Iniciando memória vazia.")
memory = []
if os.path.exists(IDX_FILE):
try:
index = faiss.read_index(IDX_FILE)
logger.info(f"Índice FAISS '{IDX_FILE}' carregado ({index.ntotal} vetores).")
# Validação simples da dimensão
if index.ntotal > 0 and index.d != faiss_dimension:
logger.warning(f"Dimensão do índice FAISS ({index.d}) diferente da dimensão do embedder ({faiss_dimension})! Recriando índice.")
index = faiss.IndexFlatL2(faiss_dimension)
# Idealmente, aqui você re-indexaria a memória existente, mas vamos simplificar por agora
except Exception as e:
logger.error(f"Erro ao carregar {IDX_FILE}: {e}. Recriando índice FAISS.")
index = faiss.IndexFlatL2(faiss_dimension)
else:
logger.info(f"Arquivo '{IDX_FILE}' não encontrado. Criando novo índice FAISS.")
index = faiss.IndexFlatL2(faiss_dimension)
logger.info("Carregamento de modelos e memória concluído.")
# --- Funções de Lógica (Adaptadas) ---
def save_state():
"""Salva o estado atual da memória e do índice FAISS."""
logger.info("Salvando estado (memória e índice)...")
try:
with open(MEM_FILE, "w") as f:
json.dump(memory, f, indent=2) # Adicionado indent para legibilidade
faiss.write_index(index, IDX_FILE)
logger.info("Estado salvo com sucesso.")
except Exception as e:
logger.error(f"Erro ao salvar estado: {e}")
def summarize_block(txt):
logger.info("Chamando summarizer...")
instr = ("Resuma este trecho de diálogo de RPG preservando personagens, locais, itens e eventos:\n\n")
try:
summary = summarizer(instr + txt, max_length=150, min_length=50)[0]["summary_text"] # Reduzi min_length
logger.info("Resumo gerado.")
return summary
except Exception as e:
logger.error(f"Erro no summarizer: {e}")
return f"[[Erro ao resumir: {e}]]"
def compact_memory(threshold=50): # Reduzi o threshold para testar mais rápido
logger.info(f"Verificando compactação (limite={threshold}), memória atual={len(memory)}...")
if len(memory) < threshold:
return False # Indica que não compactou
logger.info(f"Compactando memória ({threshold} itens)...")
bloco = memory[:threshold]
texto = "\n".join(
itm["text"] if itm["type"]=="summary"
else f"Usuário: {itm['user']}\nIA: {itm['bot']}"
for itm in bloco
)
resumo = summarize_block(texto)
memory[:threshold] = [{"type":"summary","text":resumo}] # Substitui bloco por resumo
logger.info("Recriando índice FAISS após compactação...")
new_idx = faiss.IndexFlatL2(faiss_dimension)
embeddings_to_add = []
for itm in memory:
key = itm["text"] if itm["type"]=="summary" else itm["user"]
try:
# Coleta todos os embeddings primeiro
embeddings_to_add.append(embedder.encode([key], convert_to_numpy=True)[0])
except Exception as e:
logger.error(f"Erro ao encodar item para reindexação: {e}")
if embeddings_to_add:
try:
new_idx.add(np.array(embeddings_to_add)) # Adiciona em lote
global index
index = new_idx
logger.info("Reindexação FAISS concluída.")
save_state() # Salva após compactar e reindexar
return True # Indica que compactou
except Exception as e:
logger.error(f"Erro ao adicionar embeddings ao novo índice: {e}")
return False
else:
logger.info("Nenhum embedding válido para adicionar ao novo índice.")
return False
def add_to_memory_and_index(user_msg, bot_msg):
logger.info("Adicionando nova entrada à memória e índice...")
entry = {"type":"dialog", "user":user_msg, "bot":bot_msg, "text":user_msg}
memory.append(entry)
try:
embedding = embedder.encode([user_msg], convert_to_numpy=True)
index.add(embedding)
logger.info(f"Embedding adicionado ao índice. Total de vetores: {index.ntotal}")
save_state() # Salva após adicionar
compact_memory() # Verifica se precisa compactar
except Exception as e:
logger.error(f"Erro ao adicionar embedding ou salvar estado: {e}")
def run_chat_logic(user_msg):
"""Executa a lógica principal do chat: busca contexto, gera resposta."""
logger.info(f"Executando lógica do chat para: {repr(user_msg)}")
global memory, index # Garante acesso às variáveis globais
if not all([tokenizer, model, chat_pipe, embedder, index]):
logger.error("Modelos ou índice não foram carregados corretamente.")
return "Desculpe, o sistema de IA não está pronto. Tente novamente mais tarde."
# 1. Embedding e Busca FAISS
logger.info("Gerando embedding da mensagem do usuário...")
try:
emb = embedder.encode([user_msg], convert_to_numpy=True)
except Exception as e:
logger.error(f"Erro ao gerar embedding: {e}")
return "Desculpe, houve um erro ao processar sua mensagem (embedding)."
context = []
logger.info(f"Buscando no índice FAISS ({index.ntotal} vetores)...")
if index.ntotal > 0:
try:
D, I = index.search(emb, k=5) # Busca os 5 vizinhos mais próximos
logger.info(f"Índices FAISS encontrados: {I[0]}")
for idx in I[0]:
if 0 <= idx < len(memory): # Validação crucial
itm = memory[idx]
context.append(
f"Lembrança: {itm['text']}" if itm["type"]=="summary"
else f"Histórico [Usuário: {itm['user']} | IA: {itm['bot']}]"
)
else:
logger.warning(f"Índice FAISS inválido encontrado: {idx}")
except Exception as e:
logger.error(f"Erro durante a busca FAISS: {e}")
# Continua sem contexto se a busca falhar
logger.info(f"Contexto recuperado ({len(context)} itens).")
# 2. Monta Prompt
context_str = "\n".join(context)
prompt = (
f"Você é um Mestre de RPG experiente e criativo. Continue a história de forma envolvente, "
f"considerando o seguinte histórico e lembranças:\n{context_str}\n\n"
f"Ação do Jogador: {user_msg}\n\nSua Narração:"
)
logger.info(f"Prompt enviado ao modelo (primeiros 200 chars):\n{prompt[:200]}...")
# 3. Chama o Modelo (Pipeline)
try:
logger.info("Chamando pipeline text-generation...")
# return_full_text=False pega só a continuação
outputs = chat_pipe(prompt, max_new_tokens=200, return_full_text=False, num_return_sequences=1)
logger.info(f"Saída bruta do pipeline: {outputs}")
if not outputs or not outputs[0] or "generated_text" not in outputs[0]:
logger.error("Pipeline não retornou 'generated_text' válido.")
return "Desculpe, a IA não conseguiu gerar uma resposta válida."
bot_msg = outputs[0]["generated_text"].strip()
# Limpeza adicional
bot_msg = bot_msg.split("<|endoftext|>")[0].strip()
# Remover repetições exatas do prompt final se houver
if bot_msg.startswith(f"Sua Narração:"): bot_msg = bot_msg[len("Sua Narração:"):].strip()
logger.info(f"Resposta processada do bot: {repr(bot_msg)}")
# 4. Adiciona na memória e índice
add_to_memory_and_index(user_msg, bot_msg)
return bot_msg
except Exception as e:
logger.exception("Erro durante a execução do pipeline ou pós-processamento.") # Loga o traceback completo
return f"Desculpe, ocorreu um erro interno ao gerar a resposta da IA: {e}"
# --- Carrega tudo na inicialização ---
# Esta linha será chamada quando o módulo for importado pela primeira vez
# (Ou podemos chamar explicitamente via lifespan do FastAPI)
# load_models_and_memory() |