rpg-api-backend / api_logic.py
KaykySouza's picture
Create api_logic.py
83de08e verified
import os
import json
import faiss
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from huggingface_hub import HfApi # Pode ser removido se não usar para commits externos
import logging
# Configuração de Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configurações Globais ---
MODEL_NAME = "stabilityai/stablelm-3b-4e1t"
EMBED_MODEL = "all-MiniLM-L6-v2" # Usando nome curto que funciona
SUMMARIZER_ID = "sshleifer/distilbart-cnn-12-6"
MEM_FILE = "memory.json" # Salvará no armazenamento do Space
IDX_FILE = "index.faiss" # Salvará no armazenamento do Space
# HF_TOKEN = os.getenv("HF_TOKEN") # Se precisar para algo no futuro
# REPO_ID = os.getenv("SPACE_ID") # Se precisar para algo no futuro
# --- Variáveis Globais para Modelos e Memória (Carregados uma vez) ---
tokenizer = None
model = None
chat_pipe = None
embedder = None
summarizer = None
memory = []
index = None
faiss_dimension = 384 # Dimensão padrão para all-MiniLM-L6-v2
def load_models_and_memory():
"""Carrega modelos, pipelines e a memória/índice FAISS."""
global tokenizer, model, chat_pipe, embedder, summarizer, memory, index, faiss_dimension
logger.info("Iniciando carregamento de modelos e memória...")
# Carrega Tokenizer e Modelo LLM (FP16)
logger.info(f"Carregando tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
logger.info(f"Carregando modelo: {MODEL_NAME} (FP16, CPU)")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16,
low_cpu_mem_usage=True, # Importante para CPU com RAM limitada
device_map="auto" # Deixa accelerate decidir (provavelmente 'cpu')
)
logger.info("Modelo LLM carregado.")
# Cria Pipeline de Chat
logger.info("Criando pipeline de text-generation...")
chat_pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
# max_length=512, # Definir max_new_tokens na chamada é melhor
do_sample=True,
top_p=0.9,
)
logger.info("Pipeline de chat criado.")
# Carrega Embedder
logger.info(f"Carregando embedder: {EMBED_MODEL}")
embedder = SentenceTransformer(EMBED_MODEL)
faiss_dimension = embedder.get_sentence_embedding_dimension() # Atualiza a dimensão
logger.info(f"Embedder carregado. Dimensão: {faiss_dimension}")
# Carrega Summarizer
logger.info(f"Carregando summarizer: {SUMMARIZER_ID}")
summarizer = pipeline("summarization", model=SUMMARIZER_ID)
logger.info("Summarizer carregado.")
# Carrega/Inicializa Memória e Índice FAISS
logger.info("Carregando/Inicializando memória e índice FAISS...")
if os.path.exists(MEM_FILE):
try:
with open(MEM_FILE, "r") as f:
memory = json.load(f)
logger.info(f"Arquivo de memória '{MEM_FILE}' carregado ({len(memory)} entradas).")
except Exception as e:
logger.error(f"Erro ao carregar {MEM_FILE}: {e}. Iniciando memória vazia.")
memory = []
else:
logger.info(f"Arquivo '{MEM_FILE}' não encontrado. Iniciando memória vazia.")
memory = []
if os.path.exists(IDX_FILE):
try:
index = faiss.read_index(IDX_FILE)
logger.info(f"Índice FAISS '{IDX_FILE}' carregado ({index.ntotal} vetores).")
# Validação simples da dimensão
if index.ntotal > 0 and index.d != faiss_dimension:
logger.warning(f"Dimensão do índice FAISS ({index.d}) diferente da dimensão do embedder ({faiss_dimension})! Recriando índice.")
index = faiss.IndexFlatL2(faiss_dimension)
# Idealmente, aqui você re-indexaria a memória existente, mas vamos simplificar por agora
except Exception as e:
logger.error(f"Erro ao carregar {IDX_FILE}: {e}. Recriando índice FAISS.")
index = faiss.IndexFlatL2(faiss_dimension)
else:
logger.info(f"Arquivo '{IDX_FILE}' não encontrado. Criando novo índice FAISS.")
index = faiss.IndexFlatL2(faiss_dimension)
logger.info("Carregamento de modelos e memória concluído.")
# --- Funções de Lógica (Adaptadas) ---
def save_state():
"""Salva o estado atual da memória e do índice FAISS."""
logger.info("Salvando estado (memória e índice)...")
try:
with open(MEM_FILE, "w") as f:
json.dump(memory, f, indent=2) # Adicionado indent para legibilidade
faiss.write_index(index, IDX_FILE)
logger.info("Estado salvo com sucesso.")
except Exception as e:
logger.error(f"Erro ao salvar estado: {e}")
def summarize_block(txt):
logger.info("Chamando summarizer...")
instr = ("Resuma este trecho de diálogo de RPG preservando personagens, locais, itens e eventos:\n\n")
try:
summary = summarizer(instr + txt, max_length=150, min_length=50)[0]["summary_text"] # Reduzi min_length
logger.info("Resumo gerado.")
return summary
except Exception as e:
logger.error(f"Erro no summarizer: {e}")
return f"[[Erro ao resumir: {e}]]"
def compact_memory(threshold=50): # Reduzi o threshold para testar mais rápido
logger.info(f"Verificando compactação (limite={threshold}), memória atual={len(memory)}...")
if len(memory) < threshold:
return False # Indica que não compactou
logger.info(f"Compactando memória ({threshold} itens)...")
bloco = memory[:threshold]
texto = "\n".join(
itm["text"] if itm["type"]=="summary"
else f"Usuário: {itm['user']}\nIA: {itm['bot']}"
for itm in bloco
)
resumo = summarize_block(texto)
memory[:threshold] = [{"type":"summary","text":resumo}] # Substitui bloco por resumo
logger.info("Recriando índice FAISS após compactação...")
new_idx = faiss.IndexFlatL2(faiss_dimension)
embeddings_to_add = []
for itm in memory:
key = itm["text"] if itm["type"]=="summary" else itm["user"]
try:
# Coleta todos os embeddings primeiro
embeddings_to_add.append(embedder.encode([key], convert_to_numpy=True)[0])
except Exception as e:
logger.error(f"Erro ao encodar item para reindexação: {e}")
if embeddings_to_add:
try:
new_idx.add(np.array(embeddings_to_add)) # Adiciona em lote
global index
index = new_idx
logger.info("Reindexação FAISS concluída.")
save_state() # Salva após compactar e reindexar
return True # Indica que compactou
except Exception as e:
logger.error(f"Erro ao adicionar embeddings ao novo índice: {e}")
return False
else:
logger.info("Nenhum embedding válido para adicionar ao novo índice.")
return False
def add_to_memory_and_index(user_msg, bot_msg):
logger.info("Adicionando nova entrada à memória e índice...")
entry = {"type":"dialog", "user":user_msg, "bot":bot_msg, "text":user_msg}
memory.append(entry)
try:
embedding = embedder.encode([user_msg], convert_to_numpy=True)
index.add(embedding)
logger.info(f"Embedding adicionado ao índice. Total de vetores: {index.ntotal}")
save_state() # Salva após adicionar
compact_memory() # Verifica se precisa compactar
except Exception as e:
logger.error(f"Erro ao adicionar embedding ou salvar estado: {e}")
def run_chat_logic(user_msg):
"""Executa a lógica principal do chat: busca contexto, gera resposta."""
logger.info(f"Executando lógica do chat para: {repr(user_msg)}")
global memory, index # Garante acesso às variáveis globais
if not all([tokenizer, model, chat_pipe, embedder, index]):
logger.error("Modelos ou índice não foram carregados corretamente.")
return "Desculpe, o sistema de IA não está pronto. Tente novamente mais tarde."
# 1. Embedding e Busca FAISS
logger.info("Gerando embedding da mensagem do usuário...")
try:
emb = embedder.encode([user_msg], convert_to_numpy=True)
except Exception as e:
logger.error(f"Erro ao gerar embedding: {e}")
return "Desculpe, houve um erro ao processar sua mensagem (embedding)."
context = []
logger.info(f"Buscando no índice FAISS ({index.ntotal} vetores)...")
if index.ntotal > 0:
try:
D, I = index.search(emb, k=5) # Busca os 5 vizinhos mais próximos
logger.info(f"Índices FAISS encontrados: {I[0]}")
for idx in I[0]:
if 0 <= idx < len(memory): # Validação crucial
itm = memory[idx]
context.append(
f"Lembrança: {itm['text']}" if itm["type"]=="summary"
else f"Histórico [Usuário: {itm['user']} | IA: {itm['bot']}]"
)
else:
logger.warning(f"Índice FAISS inválido encontrado: {idx}")
except Exception as e:
logger.error(f"Erro durante a busca FAISS: {e}")
# Continua sem contexto se a busca falhar
logger.info(f"Contexto recuperado ({len(context)} itens).")
# 2. Monta Prompt
context_str = "\n".join(context)
prompt = (
f"Você é um Mestre de RPG experiente e criativo. Continue a história de forma envolvente, "
f"considerando o seguinte histórico e lembranças:\n{context_str}\n\n"
f"Ação do Jogador: {user_msg}\n\nSua Narração:"
)
logger.info(f"Prompt enviado ao modelo (primeiros 200 chars):\n{prompt[:200]}...")
# 3. Chama o Modelo (Pipeline)
try:
logger.info("Chamando pipeline text-generation...")
# return_full_text=False pega só a continuação
outputs = chat_pipe(prompt, max_new_tokens=200, return_full_text=False, num_return_sequences=1)
logger.info(f"Saída bruta do pipeline: {outputs}")
if not outputs or not outputs[0] or "generated_text" not in outputs[0]:
logger.error("Pipeline não retornou 'generated_text' válido.")
return "Desculpe, a IA não conseguiu gerar uma resposta válida."
bot_msg = outputs[0]["generated_text"].strip()
# Limpeza adicional
bot_msg = bot_msg.split("<|endoftext|>")[0].strip()
# Remover repetições exatas do prompt final se houver
if bot_msg.startswith(f"Sua Narração:"): bot_msg = bot_msg[len("Sua Narração:"):].strip()
logger.info(f"Resposta processada do bot: {repr(bot_msg)}")
# 4. Adiciona na memória e índice
add_to_memory_and_index(user_msg, bot_msg)
return bot_msg
except Exception as e:
logger.exception("Erro durante a execução do pipeline ou pós-processamento.") # Loga o traceback completo
return f"Desculpe, ocorreu um erro interno ao gerar a resposta da IA: {e}"
# --- Carrega tudo na inicialização ---
# Esta linha será chamada quando o módulo for importado pela primeira vez
# (Ou podemos chamar explicitamente via lifespan do FastAPI)
# load_models_and_memory()