File size: 11,301 Bytes
83de08e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import os
import json
import faiss
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from huggingface_hub import HfApi # Pode ser removido se não usar para commits externos
import logging

# Configuração de Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configurações Globais ---
MODEL_NAME    = "stabilityai/stablelm-3b-4e1t"
EMBED_MODEL   = "all-MiniLM-L6-v2" # Usando nome curto que funciona
SUMMARIZER_ID = "sshleifer/distilbart-cnn-12-6"
MEM_FILE      = "memory.json" # Salvará no armazenamento do Space
IDX_FILE      = "index.faiss" # Salvará no armazenamento do Space
# HF_TOKEN      = os.getenv("HF_TOKEN") # Se precisar para algo no futuro
# REPO_ID       = os.getenv("SPACE_ID") # Se precisar para algo no futuro

# --- Variáveis Globais para Modelos e Memória (Carregados uma vez) ---
tokenizer = None
model = None
chat_pipe = None
embedder = None
summarizer = None
memory = []
index = None
faiss_dimension = 384 # Dimensão padrão para all-MiniLM-L6-v2

def load_models_and_memory():
    """Carrega modelos, pipelines e a memória/índice FAISS."""
    global tokenizer, model, chat_pipe, embedder, summarizer, memory, index, faiss_dimension

    logger.info("Iniciando carregamento de modelos e memória...")

    # Carrega Tokenizer e Modelo LLM (FP16)
    logger.info(f"Carregando tokenizer: {MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    logger.info(f"Carregando modelo: {MODEL_NAME} (FP16, CPU)")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True, # Importante para CPU com RAM limitada
        device_map="auto" # Deixa accelerate decidir (provavelmente 'cpu')
    )
    logger.info("Modelo LLM carregado.")

    # Cria Pipeline de Chat
    logger.info("Criando pipeline de text-generation...")
    chat_pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # max_length=512, # Definir max_new_tokens na chamada é melhor
        do_sample=True,
        top_p=0.9,
    )
    logger.info("Pipeline de chat criado.")

    # Carrega Embedder
    logger.info(f"Carregando embedder: {EMBED_MODEL}")
    embedder = SentenceTransformer(EMBED_MODEL)
    faiss_dimension = embedder.get_sentence_embedding_dimension() # Atualiza a dimensão
    logger.info(f"Embedder carregado. Dimensão: {faiss_dimension}")


    # Carrega Summarizer
    logger.info(f"Carregando summarizer: {SUMMARIZER_ID}")
    summarizer = pipeline("summarization", model=SUMMARIZER_ID)
    logger.info("Summarizer carregado.")

    # Carrega/Inicializa Memória e Índice FAISS
    logger.info("Carregando/Inicializando memória e índice FAISS...")
    if os.path.exists(MEM_FILE):
        try:
            with open(MEM_FILE, "r") as f:
                memory = json.load(f)
            logger.info(f"Arquivo de memória '{MEM_FILE}' carregado ({len(memory)} entradas).")
        except Exception as e:
            logger.error(f"Erro ao carregar {MEM_FILE}: {e}. Iniciando memória vazia.")
            memory = []
    else:
        logger.info(f"Arquivo '{MEM_FILE}' não encontrado. Iniciando memória vazia.")
        memory = []

    if os.path.exists(IDX_FILE):
        try:
            index = faiss.read_index(IDX_FILE)
            logger.info(f"Índice FAISS '{IDX_FILE}' carregado ({index.ntotal} vetores).")
            # Validação simples da dimensão
            if index.ntotal > 0 and index.d != faiss_dimension:
                 logger.warning(f"Dimensão do índice FAISS ({index.d}) diferente da dimensão do embedder ({faiss_dimension})! Recriando índice.")
                 index = faiss.IndexFlatL2(faiss_dimension)
                 # Idealmente, aqui você re-indexaria a memória existente, mas vamos simplificar por agora
        except Exception as e:
            logger.error(f"Erro ao carregar {IDX_FILE}: {e}. Recriando índice FAISS.")
            index = faiss.IndexFlatL2(faiss_dimension)
    else:
        logger.info(f"Arquivo '{IDX_FILE}' não encontrado. Criando novo índice FAISS.")
        index = faiss.IndexFlatL2(faiss_dimension)

    logger.info("Carregamento de modelos e memória concluído.")


# --- Funções de Lógica (Adaptadas) ---

def save_state():
    """Salva o estado atual da memória e do índice FAISS."""
    logger.info("Salvando estado (memória e índice)...")
    try:
        with open(MEM_FILE, "w") as f:
            json.dump(memory, f, indent=2) # Adicionado indent para legibilidade
        faiss.write_index(index, IDX_FILE)
        logger.info("Estado salvo com sucesso.")
    except Exception as e:
        logger.error(f"Erro ao salvar estado: {e}")

def summarize_block(txt):
    logger.info("Chamando summarizer...")
    instr = ("Resuma este trecho de diálogo de RPG preservando personagens, locais, itens e eventos:\n\n")
    try:
        summary = summarizer(instr + txt, max_length=150, min_length=50)[0]["summary_text"] # Reduzi min_length
        logger.info("Resumo gerado.")
        return summary
    except Exception as e:
        logger.error(f"Erro no summarizer: {e}")
        return f"[[Erro ao resumir: {e}]]"

def compact_memory(threshold=50): # Reduzi o threshold para testar mais rápido
    logger.info(f"Verificando compactação (limite={threshold}), memória atual={len(memory)}...")
    if len(memory) < threshold:
        return False # Indica que não compactou

    logger.info(f"Compactando memória ({threshold} itens)...")
    bloco = memory[:threshold]
    texto = "\n".join(
        itm["text"] if itm["type"]=="summary"
        else f"Usuário: {itm['user']}\nIA: {itm['bot']}"
        for itm in bloco
    )
    resumo = summarize_block(texto)
    memory[:threshold] = [{"type":"summary","text":resumo}] # Substitui bloco por resumo

    logger.info("Recriando índice FAISS após compactação...")
    new_idx = faiss.IndexFlatL2(faiss_dimension)
    embeddings_to_add = []
    for itm in memory:
        key = itm["text"] if itm["type"]=="summary" else itm["user"]
        try:
            # Coleta todos os embeddings primeiro
            embeddings_to_add.append(embedder.encode([key], convert_to_numpy=True)[0])
        except Exception as e:
            logger.error(f"Erro ao encodar item para reindexação: {e}")

    if embeddings_to_add:
        try:
            new_idx.add(np.array(embeddings_to_add)) # Adiciona em lote
            global index
            index = new_idx
            logger.info("Reindexação FAISS concluída.")
            save_state() # Salva após compactar e reindexar
            return True # Indica que compactou
        except Exception as e:
             logger.error(f"Erro ao adicionar embeddings ao novo índice: {e}")
             return False
    else:
        logger.info("Nenhum embedding válido para adicionar ao novo índice.")
        return False


def add_to_memory_and_index(user_msg, bot_msg):
    logger.info("Adicionando nova entrada à memória e índice...")
    entry = {"type":"dialog", "user":user_msg, "bot":bot_msg, "text":user_msg}
    memory.append(entry)
    try:
        embedding = embedder.encode([user_msg], convert_to_numpy=True)
        index.add(embedding)
        logger.info(f"Embedding adicionado ao índice. Total de vetores: {index.ntotal}")
        save_state() # Salva após adicionar
        compact_memory() # Verifica se precisa compactar
    except Exception as e:
        logger.error(f"Erro ao adicionar embedding ou salvar estado: {e}")


def run_chat_logic(user_msg):
    """Executa a lógica principal do chat: busca contexto, gera resposta."""
    logger.info(f"Executando lógica do chat para: {repr(user_msg)}")
    global memory, index # Garante acesso às variáveis globais

    if not all([tokenizer, model, chat_pipe, embedder, index]):
         logger.error("Modelos ou índice não foram carregados corretamente.")
         return "Desculpe, o sistema de IA não está pronto. Tente novamente mais tarde."

    # 1. Embedding e Busca FAISS
    logger.info("Gerando embedding da mensagem do usuário...")
    try:
        emb = embedder.encode([user_msg], convert_to_numpy=True)
    except Exception as e:
        logger.error(f"Erro ao gerar embedding: {e}")
        return "Desculpe, houve um erro ao processar sua mensagem (embedding)."

    context = []
    logger.info(f"Buscando no índice FAISS ({index.ntotal} vetores)...")
    if index.ntotal > 0:
        try:
            D, I = index.search(emb, k=5) # Busca os 5 vizinhos mais próximos
            logger.info(f"Índices FAISS encontrados: {I[0]}")
            for idx in I[0]:
                if 0 <= idx < len(memory): # Validação crucial
                    itm = memory[idx]
                    context.append(
                        f"Lembrança: {itm['text']}" if itm["type"]=="summary"
                        else f"Histórico [Usuário: {itm['user']} | IA: {itm['bot']}]"
                    )
                else:
                    logger.warning(f"Índice FAISS inválido encontrado: {idx}")
        except Exception as e:
            logger.error(f"Erro durante a busca FAISS: {e}")
            # Continua sem contexto se a busca falhar

    logger.info(f"Contexto recuperado ({len(context)} itens).")

    # 2. Monta Prompt
    context_str = "\n".join(context)
    prompt = (
        f"Você é um Mestre de RPG experiente e criativo. Continue a história de forma envolvente, "
        f"considerando o seguinte histórico e lembranças:\n{context_str}\n\n"
        f"Ação do Jogador: {user_msg}\n\nSua Narração:"
    )
    logger.info(f"Prompt enviado ao modelo (primeiros 200 chars):\n{prompt[:200]}...")

    # 3. Chama o Modelo (Pipeline)
    try:
        logger.info("Chamando pipeline text-generation...")
        # return_full_text=False pega só a continuação
        outputs = chat_pipe(prompt, max_new_tokens=200, return_full_text=False, num_return_sequences=1)
        logger.info(f"Saída bruta do pipeline: {outputs}")

        if not outputs or not outputs[0] or "generated_text" not in outputs[0]:
             logger.error("Pipeline não retornou 'generated_text' válido.")
             return "Desculpe, a IA não conseguiu gerar uma resposta válida."

        bot_msg = outputs[0]["generated_text"].strip()
        # Limpeza adicional
        bot_msg = bot_msg.split("<|endoftext|>")[0].strip()
        # Remover repetições exatas do prompt final se houver
        if bot_msg.startswith(f"Sua Narração:"): bot_msg = bot_msg[len("Sua Narração:"):].strip()

        logger.info(f"Resposta processada do bot: {repr(bot_msg)}")

        # 4. Adiciona na memória e índice
        add_to_memory_and_index(user_msg, bot_msg)

        return bot_msg

    except Exception as e:
        logger.exception("Erro durante a execução do pipeline ou pós-processamento.") # Loga o traceback completo
        return f"Desculpe, ocorreu um erro interno ao gerar a resposta da IA: {e}"

# --- Carrega tudo na inicialização ---
# Esta linha será chamada quando o módulo for importado pela primeira vez
# (Ou podemos chamar explicitamente via lifespan do FastAPI)
# load_models_and_memory()