Spaces:
Sleeping
Sleeping
import os | |
import json | |
import faiss | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
from sentence_transformers import SentenceTransformer | |
from huggingface_hub import HfApi # Pode ser removido se não usar para commits externos | |
import logging | |
# Configuração de Logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# --- Configurações Globais --- | |
MODEL_NAME = "stabilityai/stablelm-3b-4e1t" | |
EMBED_MODEL = "all-MiniLM-L6-v2" # Usando nome curto que funciona | |
SUMMARIZER_ID = "sshleifer/distilbart-cnn-12-6" | |
MEM_FILE = "memory.json" # Salvará no armazenamento do Space | |
IDX_FILE = "index.faiss" # Salvará no armazenamento do Space | |
# HF_TOKEN = os.getenv("HF_TOKEN") # Se precisar para algo no futuro | |
# REPO_ID = os.getenv("SPACE_ID") # Se precisar para algo no futuro | |
# --- Variáveis Globais para Modelos e Memória (Carregados uma vez) --- | |
tokenizer = None | |
model = None | |
chat_pipe = None | |
embedder = None | |
summarizer = None | |
memory = [] | |
index = None | |
faiss_dimension = 384 # Dimensão padrão para all-MiniLM-L6-v2 | |
def load_models_and_memory(): | |
"""Carrega modelos, pipelines e a memória/índice FAISS.""" | |
global tokenizer, model, chat_pipe, embedder, summarizer, memory, index, faiss_dimension | |
logger.info("Iniciando carregamento de modelos e memória...") | |
# Carrega Tokenizer e Modelo LLM (FP16) | |
logger.info(f"Carregando tokenizer: {MODEL_NAME}") | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
logger.info(f"Carregando modelo: {MODEL_NAME} (FP16, CPU)") | |
model = AutoModelForCausalLM.from_pretrained( | |
MODEL_NAME, | |
torch_dtype=torch.float16, | |
low_cpu_mem_usage=True, # Importante para CPU com RAM limitada | |
device_map="auto" # Deixa accelerate decidir (provavelmente 'cpu') | |
) | |
logger.info("Modelo LLM carregado.") | |
# Cria Pipeline de Chat | |
logger.info("Criando pipeline de text-generation...") | |
chat_pipe = pipeline( | |
"text-generation", | |
model=model, | |
tokenizer=tokenizer, | |
# max_length=512, # Definir max_new_tokens na chamada é melhor | |
do_sample=True, | |
top_p=0.9, | |
) | |
logger.info("Pipeline de chat criado.") | |
# Carrega Embedder | |
logger.info(f"Carregando embedder: {EMBED_MODEL}") | |
embedder = SentenceTransformer(EMBED_MODEL) | |
faiss_dimension = embedder.get_sentence_embedding_dimension() # Atualiza a dimensão | |
logger.info(f"Embedder carregado. Dimensão: {faiss_dimension}") | |
# Carrega Summarizer | |
logger.info(f"Carregando summarizer: {SUMMARIZER_ID}") | |
summarizer = pipeline("summarization", model=SUMMARIZER_ID) | |
logger.info("Summarizer carregado.") | |
# Carrega/Inicializa Memória e Índice FAISS | |
logger.info("Carregando/Inicializando memória e índice FAISS...") | |
if os.path.exists(MEM_FILE): | |
try: | |
with open(MEM_FILE, "r") as f: | |
memory = json.load(f) | |
logger.info(f"Arquivo de memória '{MEM_FILE}' carregado ({len(memory)} entradas).") | |
except Exception as e: | |
logger.error(f"Erro ao carregar {MEM_FILE}: {e}. Iniciando memória vazia.") | |
memory = [] | |
else: | |
logger.info(f"Arquivo '{MEM_FILE}' não encontrado. Iniciando memória vazia.") | |
memory = [] | |
if os.path.exists(IDX_FILE): | |
try: | |
index = faiss.read_index(IDX_FILE) | |
logger.info(f"Índice FAISS '{IDX_FILE}' carregado ({index.ntotal} vetores).") | |
# Validação simples da dimensão | |
if index.ntotal > 0 and index.d != faiss_dimension: | |
logger.warning(f"Dimensão do índice FAISS ({index.d}) diferente da dimensão do embedder ({faiss_dimension})! Recriando índice.") | |
index = faiss.IndexFlatL2(faiss_dimension) | |
# Idealmente, aqui você re-indexaria a memória existente, mas vamos simplificar por agora | |
except Exception as e: | |
logger.error(f"Erro ao carregar {IDX_FILE}: {e}. Recriando índice FAISS.") | |
index = faiss.IndexFlatL2(faiss_dimension) | |
else: | |
logger.info(f"Arquivo '{IDX_FILE}' não encontrado. Criando novo índice FAISS.") | |
index = faiss.IndexFlatL2(faiss_dimension) | |
logger.info("Carregamento de modelos e memória concluído.") | |
# --- Funções de Lógica (Adaptadas) --- | |
def save_state(): | |
"""Salva o estado atual da memória e do índice FAISS.""" | |
logger.info("Salvando estado (memória e índice)...") | |
try: | |
with open(MEM_FILE, "w") as f: | |
json.dump(memory, f, indent=2) # Adicionado indent para legibilidade | |
faiss.write_index(index, IDX_FILE) | |
logger.info("Estado salvo com sucesso.") | |
except Exception as e: | |
logger.error(f"Erro ao salvar estado: {e}") | |
def summarize_block(txt): | |
logger.info("Chamando summarizer...") | |
instr = ("Resuma este trecho de diálogo de RPG preservando personagens, locais, itens e eventos:\n\n") | |
try: | |
summary = summarizer(instr + txt, max_length=150, min_length=50)[0]["summary_text"] # Reduzi min_length | |
logger.info("Resumo gerado.") | |
return summary | |
except Exception as e: | |
logger.error(f"Erro no summarizer: {e}") | |
return f"[[Erro ao resumir: {e}]]" | |
def compact_memory(threshold=50): # Reduzi o threshold para testar mais rápido | |
logger.info(f"Verificando compactação (limite={threshold}), memória atual={len(memory)}...") | |
if len(memory) < threshold: | |
return False # Indica que não compactou | |
logger.info(f"Compactando memória ({threshold} itens)...") | |
bloco = memory[:threshold] | |
texto = "\n".join( | |
itm["text"] if itm["type"]=="summary" | |
else f"Usuário: {itm['user']}\nIA: {itm['bot']}" | |
for itm in bloco | |
) | |
resumo = summarize_block(texto) | |
memory[:threshold] = [{"type":"summary","text":resumo}] # Substitui bloco por resumo | |
logger.info("Recriando índice FAISS após compactação...") | |
new_idx = faiss.IndexFlatL2(faiss_dimension) | |
embeddings_to_add = [] | |
for itm in memory: | |
key = itm["text"] if itm["type"]=="summary" else itm["user"] | |
try: | |
# Coleta todos os embeddings primeiro | |
embeddings_to_add.append(embedder.encode([key], convert_to_numpy=True)[0]) | |
except Exception as e: | |
logger.error(f"Erro ao encodar item para reindexação: {e}") | |
if embeddings_to_add: | |
try: | |
new_idx.add(np.array(embeddings_to_add)) # Adiciona em lote | |
global index | |
index = new_idx | |
logger.info("Reindexação FAISS concluída.") | |
save_state() # Salva após compactar e reindexar | |
return True # Indica que compactou | |
except Exception as e: | |
logger.error(f"Erro ao adicionar embeddings ao novo índice: {e}") | |
return False | |
else: | |
logger.info("Nenhum embedding válido para adicionar ao novo índice.") | |
return False | |
def add_to_memory_and_index(user_msg, bot_msg): | |
logger.info("Adicionando nova entrada à memória e índice...") | |
entry = {"type":"dialog", "user":user_msg, "bot":bot_msg, "text":user_msg} | |
memory.append(entry) | |
try: | |
embedding = embedder.encode([user_msg], convert_to_numpy=True) | |
index.add(embedding) | |
logger.info(f"Embedding adicionado ao índice. Total de vetores: {index.ntotal}") | |
save_state() # Salva após adicionar | |
compact_memory() # Verifica se precisa compactar | |
except Exception as e: | |
logger.error(f"Erro ao adicionar embedding ou salvar estado: {e}") | |
def run_chat_logic(user_msg): | |
"""Executa a lógica principal do chat: busca contexto, gera resposta.""" | |
logger.info(f"Executando lógica do chat para: {repr(user_msg)}") | |
global memory, index # Garante acesso às variáveis globais | |
if not all([tokenizer, model, chat_pipe, embedder, index]): | |
logger.error("Modelos ou índice não foram carregados corretamente.") | |
return "Desculpe, o sistema de IA não está pronto. Tente novamente mais tarde." | |
# 1. Embedding e Busca FAISS | |
logger.info("Gerando embedding da mensagem do usuário...") | |
try: | |
emb = embedder.encode([user_msg], convert_to_numpy=True) | |
except Exception as e: | |
logger.error(f"Erro ao gerar embedding: {e}") | |
return "Desculpe, houve um erro ao processar sua mensagem (embedding)." | |
context = [] | |
logger.info(f"Buscando no índice FAISS ({index.ntotal} vetores)...") | |
if index.ntotal > 0: | |
try: | |
D, I = index.search(emb, k=5) # Busca os 5 vizinhos mais próximos | |
logger.info(f"Índices FAISS encontrados: {I[0]}") | |
for idx in I[0]: | |
if 0 <= idx < len(memory): # Validação crucial | |
itm = memory[idx] | |
context.append( | |
f"Lembrança: {itm['text']}" if itm["type"]=="summary" | |
else f"Histórico [Usuário: {itm['user']} | IA: {itm['bot']}]" | |
) | |
else: | |
logger.warning(f"Índice FAISS inválido encontrado: {idx}") | |
except Exception as e: | |
logger.error(f"Erro durante a busca FAISS: {e}") | |
# Continua sem contexto se a busca falhar | |
logger.info(f"Contexto recuperado ({len(context)} itens).") | |
# 2. Monta Prompt | |
context_str = "\n".join(context) | |
prompt = ( | |
f"Você é um Mestre de RPG experiente e criativo. Continue a história de forma envolvente, " | |
f"considerando o seguinte histórico e lembranças:\n{context_str}\n\n" | |
f"Ação do Jogador: {user_msg}\n\nSua Narração:" | |
) | |
logger.info(f"Prompt enviado ao modelo (primeiros 200 chars):\n{prompt[:200]}...") | |
# 3. Chama o Modelo (Pipeline) | |
try: | |
logger.info("Chamando pipeline text-generation...") | |
# return_full_text=False pega só a continuação | |
outputs = chat_pipe(prompt, max_new_tokens=200, return_full_text=False, num_return_sequences=1) | |
logger.info(f"Saída bruta do pipeline: {outputs}") | |
if not outputs or not outputs[0] or "generated_text" not in outputs[0]: | |
logger.error("Pipeline não retornou 'generated_text' válido.") | |
return "Desculpe, a IA não conseguiu gerar uma resposta válida." | |
bot_msg = outputs[0]["generated_text"].strip() | |
# Limpeza adicional | |
bot_msg = bot_msg.split("<|endoftext|>")[0].strip() | |
# Remover repetições exatas do prompt final se houver | |
if bot_msg.startswith(f"Sua Narração:"): bot_msg = bot_msg[len("Sua Narração:"):].strip() | |
logger.info(f"Resposta processada do bot: {repr(bot_msg)}") | |
# 4. Adiciona na memória e índice | |
add_to_memory_and_index(user_msg, bot_msg) | |
return bot_msg | |
except Exception as e: | |
logger.exception("Erro durante a execução do pipeline ou pós-processamento.") # Loga o traceback completo | |
return f"Desculpe, ocorreu um erro interno ao gerar a resposta da IA: {e}" | |
# --- Carrega tudo na inicialização --- | |
# Esta linha será chamada quando o módulo for importado pela primeira vez | |
# (Ou podemos chamar explicitamente via lifespan do FastAPI) | |
# load_models_and_memory() |