from dotenv import load_dotenv
load_dotenv()
import os

from langchain_ollama import OllamaEmbeddings
from langchain_openai import ChatOpenAI
from langchain_chroma import Chroma
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain import hub

# ——— CONFIG ———
# Vector store: directory passed to Chroma as its persist directory.
PERSIST_DIR = "chroma_db/"

# Embeddings: Ollama server URL (None when OLLAMA_SERVER is not set in the
# environment) and the name of the embedding model to request from it.
OLLAMA_URL = os.environ.get("OLLAMA_SERVER")
EMBED_MODEL = "nomic-embed-text:latest"

# Chat model: credentials, endpoint (overridable through LLM_API_BASE) and
# model name for the OpenAI-compatible API.
LLM_API_KEY = os.environ.get("LLM_API_KEY")
LLM_API_BASE = os.environ.get("LLM_API_BASE", "https://llm.chutes.ai/v1")
LLM_MODEL = "chutesai/Llama-4-Scout-17B-16E-Instruct"

# Prompt pulled from the LangChain hub while this module is being loaded.
PROMPT = hub.pull("langchain-ai/retrieval-qa-chat")

# Value used as the retriever's "k" search argument.
TOP_K = 5
# ——————————

# Retrieval chain shared by every run_query() call; built on first use.
_RAG_CHAIN = None


def _get_rag_chain():
    """Return the retrieval chain, constructing it the first time it is needed.

    Building the embedder, the Chroma client and the LLM client is independent
    of the query text, so it is done once instead of on every query.
    """
    global _RAG_CHAIN
    if _RAG_CHAIN is None:
        # 1) embedder for the query text — presumably the same model that was
        #    used when the collection was indexed (verify against the indexer)
        embedder = OllamaEmbeddings(base_url=OLLAMA_URL, model=EMBED_MODEL)

        # 2) open the on-disk DB with the embedder in place
        vectordb = Chroma(
            persist_directory=PERSIST_DIR,
            collection_name="my_docs",
            embedding_function=embedder,
        )

        # 3) retriever + LLM, combined into a single retrieval chain
        retriever = vectordb.as_retriever(search_kwargs={"k": TOP_K})
        llm = ChatOpenAI(api_key=LLM_API_KEY, base_url=LLM_API_BASE, model=LLM_MODEL)
        combine = create_stuff_documents_chain(llm=llm, prompt=PROMPT)
        _RAG_CHAIN = create_retrieval_chain(retriever, combine)
    return _RAG_CHAIN


def run_query(query: str):
    """Answer *query* with the retrieval chain and print the answer text.

    Args:
        query: The question to look up in the vector store and send to the LLM.

    Returns:
        The dictionary produced by the retrieval chain. Only its ``"answer"``
        entry is printed; the full dictionary is returned so callers can
        inspect the rest of the chain output.
    """
    print(f"🔍 Query: {query}")
    result = _get_rag_chain().invoke({"input": query})
    # The chain returns a dict; show just the generated answer rather than
    # dumping the whole structure (input, retrieved documents, answer).
    print("\n📄 Answer:\n", result["answer"])
    return result

if __name__ == "__main__":
    # Interactive prompt: keep answering queries until the user types 'exit'
    # or closes the input stream.
    while True:
        try:
            user_input = input("Enter your query (or 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
            print()
            break
        if not user_input:
            # Nothing to search for; ask again instead of sending an empty query.
            continue
        if user_input.lower() == 'exit':
            break
        run_query(user_input)