"""Build and query a Chroma vector database from a local document corpus."""

import os
from typing import List

import chromadb
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader

load_dotenv()

# Configure paths (corpus and database live one level above this file)
CORPUS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "corpus")
DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "vectordb")

# Ensure directories exist
os.makedirs(CORPUS_DIR, exist_ok=True)
os.makedirs(DB_DIR, exist_ok=True)


def load_documents(corpus_dir: str = CORPUS_DIR) -> List:
    """Load documents from the corpus directory."""
    if not os.path.exists(corpus_dir):
        raise FileNotFoundError(f"Corpus directory not found: {corpus_dir}")

    print(f"Loading documents from {corpus_dir}...")

    # Initialize loaders for different file types
    loaders = {
        # "txt": DirectoryLoader(corpus_dir, glob="**/*.txt", loader_cls=TextLoader),
        "pdf": DirectoryLoader(corpus_dir, glob="**/*.pdf", loader_cls=PyPDFLoader),
        # "docx": DirectoryLoader(corpus_dir, glob="**/*.docx", loader_cls=Docx2txtLoader),
    }

    documents = []
    for file_type, loader in loaders.items():
        try:
            docs = loader.load()
            print(f"Loaded {len(docs)} {file_type} documents")
            documents.extend(docs)
        except Exception as e:
            print(f"Error loading {file_type} documents: {e}")

    return documents


def split_documents(documents, chunk_size=1000, chunk_overlap=200):
    """Split documents into overlapping chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    splits = text_splitter.split_documents(documents)
    print(f"Split {len(documents)} documents into {len(splits)} chunks")
    return splits


def create_chroma_db(documents, collection_name="corpus_collection", db_dir=DB_DIR):
    """Create (or update) a Chroma vector database from document chunks."""
    # Initialize the Gemini embedding function
    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=os.getenv("GOOGLE_API_KEY"),
        model_name="models/embedding-001",
    )

    # Initialize Chroma client
    client = chromadb.PersistentClient(path=db_dir)

    # Create or get collection
    try:
        collection = client.get_collection(name=collection_name, embedding_function=gemini_ef)
        print(f"Using existing collection: {collection_name}")
    except Exception:
        collection = client.create_collection(name=collection_name, embedding_function=gemini_ef)
        print(f"Created new collection: {collection_name}")

    # Add chunk texts with their metadata (e.g. source file) and generated ids
    try:
        collection.add(
            documents=[doc.page_content for doc in documents],
            metadatas=[doc.metadata for doc in documents],
            ids=[f"chunk_{i}" for i in range(len(documents))],
        )
        print(f"Added {len(documents)} chunks to the collection.")
    except Exception as e:
        print(f"Error adding documents to collection: {e}")

    return collection


def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
    """Query the Chroma vector database."""
    # Initialize the Gemini embedding function
    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=os.getenv("GOOGLE_API_KEY"),
        model_name="models/embedding-001",
    )

    # Initialize Chroma client
    client = chromadb.PersistentClient(path=db_dir)

    # Get collection
    collection = client.get_collection(name=collection_name, embedding_function=gemini_ef)

    # Query collection
    results = collection.query(
        query_texts=[query],
        n_results=n_results,
    )

    return results


def main():
    """Main function to create and test the vector database."""
    print("Starting vector database creation...")

    # Load documents
    documents = load_documents()
    if not documents:
        print("No documents found in corpus directory. Please add documents to proceed.")
        return

    # Split documents
    splits = split_documents(documents)

    # Create vector database
    create_chroma_db(splits)

    # Test query
    test_query = "What is this corpus about?"
    print(f"\nTesting query: '{test_query}'")
    results = query_chroma_db(test_query)

    print(f"Found {len(results['documents'][0])} matching documents")
    for i, (doc, metadata) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
        print(f"\nResult {i+1}:")
        print(f"Document: {doc[:150]}...")
        print(f"Source: {metadata.get('source', 'Unknown')}")

    print("\nVector database creation and testing complete!")


if __name__ == "__main__":
    main()
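# ---------------------------------------------------------------------------
# Usage sketch (assumptions: the database has already been built by running
# this script once, and the file is importable as "vectordb_builder" -- the
# module name is illustrative, not part of the code above). It shows how
# another module could reuse the persisted index for retrieval:
#
#     from vectordb_builder import query_chroma_db
#
#     results = query_chroma_db("How are embeddings generated?", n_results=3)
#     for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
#         print(meta.get("source", "Unknown"), "->", doc[:120])
# ---------------------------------------------------------------------------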