import os import json import faiss import torch from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from sentence_transformers import SentenceTransformer from huggingface_hub import HfApi # Pode ser removido se não usar para commits externos import logging # Configuração de Logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # --- Configurações Globais --- MODEL_NAME = "stabilityai/stablelm-3b-4e1t" EMBED_MODEL = "all-MiniLM-L6-v2" # Usando nome curto que funciona SUMMARIZER_ID = "sshleifer/distilbart-cnn-12-6" MEM_FILE = "memory.json" # Salvará no armazenamento do Space IDX_FILE = "index.faiss" # Salvará no armazenamento do Space # HF_TOKEN = os.getenv("HF_TOKEN") # Se precisar para algo no futuro # REPO_ID = os.getenv("SPACE_ID") # Se precisar para algo no futuro # --- Variáveis Globais para Modelos e Memória (Carregados uma vez) --- tokenizer = None model = None chat_pipe = None embedder = None summarizer = None memory = [] index = None faiss_dimension = 384 # Dimensão padrão para all-MiniLM-L6-v2 def load_models_and_memory(): """Carrega modelos, pipelines e a memória/índice FAISS.""" global tokenizer, model, chat_pipe, embedder, summarizer, memory, index, faiss_dimension logger.info("Iniciando carregamento de modelos e memória...") # Carrega Tokenizer e Modelo LLM (FP16) logger.info(f"Carregando tokenizer: {MODEL_NAME}") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) logger.info(f"Carregando modelo: {MODEL_NAME} (FP16, CPU)") model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.float16, low_cpu_mem_usage=True, # Importante para CPU com RAM limitada device_map="auto" # Deixa accelerate decidir (provavelmente 'cpu') ) logger.info("Modelo LLM carregado.") # Cria Pipeline de Chat logger.info("Criando pipeline de text-generation...") chat_pipe = pipeline( "text-generation", model=model, tokenizer=tokenizer, # max_length=512, # Definir max_new_tokens na chamada é melhor do_sample=True, top_p=0.9, ) logger.info("Pipeline de chat criado.") # Carrega Embedder logger.info(f"Carregando embedder: {EMBED_MODEL}") embedder = SentenceTransformer(EMBED_MODEL) faiss_dimension = embedder.get_sentence_embedding_dimension() # Atualiza a dimensão logger.info(f"Embedder carregado. Dimensão: {faiss_dimension}") # Carrega Summarizer logger.info(f"Carregando summarizer: {SUMMARIZER_ID}") summarizer = pipeline("summarization", model=SUMMARIZER_ID) logger.info("Summarizer carregado.") # Carrega/Inicializa Memória e Índice FAISS logger.info("Carregando/Inicializando memória e índice FAISS...") if os.path.exists(MEM_FILE): try: with open(MEM_FILE, "r") as f: memory = json.load(f) logger.info(f"Arquivo de memória '{MEM_FILE}' carregado ({len(memory)} entradas).") except Exception as e: logger.error(f"Erro ao carregar {MEM_FILE}: {e}. Iniciando memória vazia.") memory = [] else: logger.info(f"Arquivo '{MEM_FILE}' não encontrado. Iniciando memória vazia.") memory = [] if os.path.exists(IDX_FILE): try: index = faiss.read_index(IDX_FILE) logger.info(f"Índice FAISS '{IDX_FILE}' carregado ({index.ntotal} vetores).") # Validação simples da dimensão if index.ntotal > 0 and index.d != faiss_dimension: logger.warning(f"Dimensão do índice FAISS ({index.d}) diferente da dimensão do embedder ({faiss_dimension})! Recriando índice.") index = faiss.IndexFlatL2(faiss_dimension) # Idealmente, aqui você re-indexaria a memória existente, mas vamos simplificar por agora except Exception as e: logger.error(f"Erro ao carregar {IDX_FILE}: {e}. Recriando índice FAISS.") index = faiss.IndexFlatL2(faiss_dimension) else: logger.info(f"Arquivo '{IDX_FILE}' não encontrado. Criando novo índice FAISS.") index = faiss.IndexFlatL2(faiss_dimension) logger.info("Carregamento de modelos e memória concluído.") # --- Funções de Lógica (Adaptadas) --- def save_state(): """Salva o estado atual da memória e do índice FAISS.""" logger.info("Salvando estado (memória e índice)...") try: with open(MEM_FILE, "w") as f: json.dump(memory, f, indent=2) # Adicionado indent para legibilidade faiss.write_index(index, IDX_FILE) logger.info("Estado salvo com sucesso.") except Exception as e: logger.error(f"Erro ao salvar estado: {e}") def summarize_block(txt): logger.info("Chamando summarizer...") instr = ("Resuma este trecho de diálogo de RPG preservando personagens, locais, itens e eventos:\n\n") try: summary = summarizer(instr + txt, max_length=150, min_length=50)[0]["summary_text"] # Reduzi min_length logger.info("Resumo gerado.") return summary except Exception as e: logger.error(f"Erro no summarizer: {e}") return f"[[Erro ao resumir: {e}]]" def compact_memory(threshold=50): # Reduzi o threshold para testar mais rápido logger.info(f"Verificando compactação (limite={threshold}), memória atual={len(memory)}...") if len(memory) < threshold: return False # Indica que não compactou logger.info(f"Compactando memória ({threshold} itens)...") bloco = memory[:threshold] texto = "\n".join( itm["text"] if itm["type"]=="summary" else f"Usuário: {itm['user']}\nIA: {itm['bot']}" for itm in bloco ) resumo = summarize_block(texto) memory[:threshold] = [{"type":"summary","text":resumo}] # Substitui bloco por resumo logger.info("Recriando índice FAISS após compactação...") new_idx = faiss.IndexFlatL2(faiss_dimension) embeddings_to_add = [] for itm in memory: key = itm["text"] if itm["type"]=="summary" else itm["user"] try: # Coleta todos os embeddings primeiro embeddings_to_add.append(embedder.encode([key], convert_to_numpy=True)[0]) except Exception as e: logger.error(f"Erro ao encodar item para reindexação: {e}") if embeddings_to_add: try: new_idx.add(np.array(embeddings_to_add)) # Adiciona em lote global index index = new_idx logger.info("Reindexação FAISS concluída.") save_state() # Salva após compactar e reindexar return True # Indica que compactou except Exception as e: logger.error(f"Erro ao adicionar embeddings ao novo índice: {e}") return False else: logger.info("Nenhum embedding válido para adicionar ao novo índice.") return False def add_to_memory_and_index(user_msg, bot_msg): logger.info("Adicionando nova entrada à memória e índice...") entry = {"type":"dialog", "user":user_msg, "bot":bot_msg, "text":user_msg} memory.append(entry) try: embedding = embedder.encode([user_msg], convert_to_numpy=True) index.add(embedding) logger.info(f"Embedding adicionado ao índice. Total de vetores: {index.ntotal}") save_state() # Salva após adicionar compact_memory() # Verifica se precisa compactar except Exception as e: logger.error(f"Erro ao adicionar embedding ou salvar estado: {e}") def run_chat_logic(user_msg): """Executa a lógica principal do chat: busca contexto, gera resposta.""" logger.info(f"Executando lógica do chat para: {repr(user_msg)}") global memory, index # Garante acesso às variáveis globais if not all([tokenizer, model, chat_pipe, embedder, index]): logger.error("Modelos ou índice não foram carregados corretamente.") return "Desculpe, o sistema de IA não está pronto. Tente novamente mais tarde." # 1. Embedding e Busca FAISS logger.info("Gerando embedding da mensagem do usuário...") try: emb = embedder.encode([user_msg], convert_to_numpy=True) except Exception as e: logger.error(f"Erro ao gerar embedding: {e}") return "Desculpe, houve um erro ao processar sua mensagem (embedding)." context = [] logger.info(f"Buscando no índice FAISS ({index.ntotal} vetores)...") if index.ntotal > 0: try: D, I = index.search(emb, k=5) # Busca os 5 vizinhos mais próximos logger.info(f"Índices FAISS encontrados: {I[0]}") for idx in I[0]: if 0 <= idx < len(memory): # Validação crucial itm = memory[idx] context.append( f"Lembrança: {itm['text']}" if itm["type"]=="summary" else f"Histórico [Usuário: {itm['user']} | IA: {itm['bot']}]" ) else: logger.warning(f"Índice FAISS inválido encontrado: {idx}") except Exception as e: logger.error(f"Erro durante a busca FAISS: {e}") # Continua sem contexto se a busca falhar logger.info(f"Contexto recuperado ({len(context)} itens).") # 2. Monta Prompt context_str = "\n".join(context) prompt = ( f"Você é um Mestre de RPG experiente e criativo. Continue a história de forma envolvente, " f"considerando o seguinte histórico e lembranças:\n{context_str}\n\n" f"Ação do Jogador: {user_msg}\n\nSua Narração:" ) logger.info(f"Prompt enviado ao modelo (primeiros 200 chars):\n{prompt[:200]}...") # 3. Chama o Modelo (Pipeline) try: logger.info("Chamando pipeline text-generation...") # return_full_text=False pega só a continuação outputs = chat_pipe(prompt, max_new_tokens=200, return_full_text=False, num_return_sequences=1) logger.info(f"Saída bruta do pipeline: {outputs}") if not outputs or not outputs[0] or "generated_text" not in outputs[0]: logger.error("Pipeline não retornou 'generated_text' válido.") return "Desculpe, a IA não conseguiu gerar uma resposta válida." bot_msg = outputs[0]["generated_text"].strip() # Limpeza adicional bot_msg = bot_msg.split("<|endoftext|>")[0].strip() # Remover repetições exatas do prompt final se houver if bot_msg.startswith(f"Sua Narração:"): bot_msg = bot_msg[len("Sua Narração:"):].strip() logger.info(f"Resposta processada do bot: {repr(bot_msg)}") # 4. Adiciona na memória e índice add_to_memory_and_index(user_msg, bot_msg) return bot_msg except Exception as e: logger.exception("Erro durante a execução do pipeline ou pós-processamento.") # Loga o traceback completo return f"Desculpe, ocorreu um erro interno ao gerar a resposta da IA: {e}" # --- Carrega tudo na inicialização --- # Esta linha será chamada quando o módulo for importado pela primeira vez # (Ou podemos chamar explicitamente via lifespan do FastAPI) # load_models_and_memory()