Spaces:
Running
on
Zero
Running
on
Zero
"""Vector store management using Chroma for document storage and retrieval."""
import hashlib
import os
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple

from langchain.retrievers import EnsembleRetriever
from langchain_chroma import Chroma
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStoreRetriever

from src.core.config import config
from src.core.logging_config import get_logger
from src.rag.embeddings import embedding_manager
# Module-level logger, named after this module via the project's logging helper.
logger = get_logger(__name__)
class VectorStoreManager:
    """Manages Chroma vector store for document storage and retrieval.

    The Chroma store is created lazily on first use. Alongside it, the manager
    keeps an in-memory copy of the stored documents so that a BM25 (keyword)
    retriever can be built and combined with the semantic retriever.
    """

    def __init__(self, persist_directory: Optional[str] = None, collection_name: str = "markit_documents"):
        """
        Initialize the vector store manager.

        Args:
            persist_directory: Directory to persist the vector database.
                When None, ``config.rag.vector_store_path`` is used.
            collection_name: Name of the collection in the vector store
        """
        self.collection_name = collection_name

        # Fall back to the configured location when no directory is given
        if persist_directory is None:
            persist_directory = config.rag.vector_store_path
        self.persist_directory = str(Path(persist_directory).resolve())

        # Ensure the directory exists
        os.makedirs(self.persist_directory, exist_ok=True)

        # The Chroma store itself is created on demand by get_vector_store()
        self._vector_store: Optional[Chroma] = None
        # In-memory copy of the stored documents, used to build the BM25 retriever
        self._documents_cache: List[Document] = []
        self._bm25_retriever: Optional[BM25Retriever] = None

        logger.info(f"VectorStoreManager initialized with persist_directory={self.persist_directory}")

    @staticmethod
    def _content_digest(text: str) -> str:
        """Return a SHA-256 hex digest of *text* that is stable across processes."""
        # The builtin hash() is salted per interpreter run for str, so it cannot
        # be used for identifiers that end up in a persisted database.
        return hashlib.sha256(text.encode("utf-8", errors="surrogatepass")).hexdigest()

    def _reset_keyword_index(self) -> None:
        """Drop the cached documents and BM25 retriever so they are rebuilt lazily."""
        self._documents_cache = []
        self._bm25_retriever = None

    def get_vector_store(self) -> Chroma:
        """
        Get or create the Chroma vector store.

        Returns:
            The Chroma instance bound to this manager's collection

        Raises:
            Exception: Re-raises whatever the embedding manager or Chroma raise
                during initialization (after logging it).
        """
        if self._vector_store is None:
            try:
                embedding_model = embedding_manager.get_embedding_model()
                self._vector_store = Chroma(
                    collection_name=self.collection_name,
                    embedding_function=embedding_model,
                    persist_directory=self.persist_directory,
                    collection_metadata={"hnsw:space": "cosine"},  # Use cosine similarity
                )
                logger.info(f"Vector store initialized with collection '{self.collection_name}'")
            except Exception as e:
                logger.error(f"Failed to initialize vector store: {e}")
                raise
        return self._vector_store

    def add_documents(self, documents: List[Document]) -> List[str]:
        """
        Add documents to the vector store.

        Args:
            documents: List of Document objects to add

        Returns:
            List of document IDs that were added (empty when no documents
            were provided)

        Raises:
            Exception: Re-raises any error from the vector store after logging it.
        """
        try:
            if not documents:
                logger.warning("No documents provided to add to vector store")
                return []

            vector_store = self.get_vector_store()

            # IDs combine the position in this batch with a content digest. The
            # digest is deterministic, so the same batch yields the same IDs in
            # every process; the index keeps IDs unique when a batch contains
            # identical chunks.
            doc_ids = [
                f"doc_{i}_{self._content_digest(doc.page_content)}"
                for i, doc in enumerate(documents)
            ]

            # Add documents to the vector store
            added_ids = vector_store.add_documents(documents=documents, ids=doc_ids)

            # Invalidate the keyword index instead of appending to the cache:
            # the cache may never have been loaded from a pre-existing persisted
            # collection, and appending only this batch would make BM25 search
            # ignore everything already stored. The next get_bm25_retriever()
            # call reloads the full collection.
            self._reset_keyword_index()

            logger.info(f"Added {len(added_ids)} documents to vector store")
            return added_ids
        except Exception as e:
            logger.error(f"Error adding documents to vector store: {e}")
            raise

    def similarity_search(self, query: str, k: int = 4, score_threshold: Optional[float] = None) -> List[Document]:
        """
        Search for similar documents using semantic similarity.

        Args:
            query: Search query
            k: Number of documents to return
            score_threshold: Minimum similarity score threshold; when None a
                plain similarity search is performed

        Returns:
            List of similar documents; an empty list if the search fails
            (the error is logged, not raised)
        """
        try:
            vector_store = self.get_vector_store()

            if score_threshold is not None:
                # Use similarity search with score threshold
                docs_with_scores = vector_store.similarity_search_with_relevance_scores(
                    query=query,
                    k=k,
                    score_threshold=score_threshold,
                )
                documents = [doc for doc, _score in docs_with_scores]
            else:
                # Regular similarity search
                documents = vector_store.similarity_search(query=query, k=k)

            logger.info(f"Found {len(documents)} similar documents for query: '{query[:50]}...'")
            return documents
        except Exception as e:
            # Best-effort API: callers get an empty result instead of an exception
            logger.error(f"Error performing similarity search: {e}")
            return []

    def get_retriever(self, search_type: str = "similarity", search_kwargs: Optional[Dict[str, Any]] = None) -> VectorStoreRetriever:
        """
        Get a retriever for the vector store.

        Args:
            search_type: Type of search ("similarity", "mmr", "similarity_score_threshold")
            search_kwargs: Additional search parameters; defaults to ``{"k": 4}``

        Returns:
            VectorStoreRetriever object

        Raises:
            Exception: Re-raises any error from the vector store after logging it.
        """
        try:
            vector_store = self.get_vector_store()

            if search_kwargs is None:
                search_kwargs = {"k": 4}

            retriever = vector_store.as_retriever(
                search_type=search_type,
                search_kwargs=search_kwargs,
            )

            logger.info(f"Created retriever with search_type='{search_type}' and kwargs={search_kwargs}")
            return retriever
        except Exception as e:
            logger.error(f"Error creating retriever: {e}")
            raise

    def get_bm25_retriever(self, k: int = 4) -> BM25Retriever:
        """
        Get or create a BM25 retriever for keyword-based search.

        The retriever is built from the in-memory document cache. When the
        cache is empty, the documents are first reloaded from the vector store.

        Args:
            k: Number of documents to return

        Returns:
            BM25Retriever object. When no documents are available, a retriever
            over a single empty placeholder document is returned.

        Raises:
            Exception: Re-raises any error from the store or BM25 after logging it.
        """
        try:
            if self._bm25_retriever is None or not self._documents_cache:
                if not self._documents_cache:
                    # Try to load documents from the vector store
                    vector_store = self.get_vector_store()
                    collection = vector_store._collection
                    all_docs = collection.get()

                    if all_docs and all_docs.get('documents') and all_docs.get('metadatas'):
                        # Reconstruct documents from vector store. Entries
                        # without text are skipped and a missing metadata entry
                        # becomes an empty dict, because Document rejects None
                        # for either field.
                        self._documents_cache = [
                            Document(page_content=content, metadata=metadata or {})
                            for content, metadata in zip(all_docs['documents'], all_docs['metadatas'])
                            if content is not None
                        ]

                if self._documents_cache:
                    self._bm25_retriever = BM25Retriever.from_documents(
                        documents=self._documents_cache,
                        k=k,
                    )
                    logger.info(f"Created BM25 retriever with {len(self._documents_cache)} documents")
                else:
                    logger.warning("No documents available for BM25 retriever")
                    # Create empty retriever (a single blank placeholder document)
                    self._bm25_retriever = BM25Retriever.from_documents(
                        documents=[Document(page_content="", metadata={})],
                        k=k,
                    )

            # Update k if different. The retriever is cached and shared, so this
            # also affects previously returned references to it.
            if hasattr(self._bm25_retriever, 'k'):
                self._bm25_retriever.k = k

            return self._bm25_retriever
        except Exception as e:
            logger.error(f"Error creating BM25 retriever: {e}")
            raise

    def get_hybrid_retriever(self,
                             k: int = 4,
                             semantic_weight: float = 0.7,
                             keyword_weight: float = 0.3,
                             search_type: str = "similarity",
                             search_kwargs: Optional[Dict[str, Any]] = None) -> EnsembleRetriever:
        """
        Get a hybrid retriever that combines semantic (vector) and keyword (BM25) search.

        Args:
            k: Number of documents to return
            semantic_weight: Weight for semantic search (0.0 to 1.0)
            keyword_weight: Weight for keyword search (0.0 to 1.0)
            search_type: Type of semantic search ("similarity", "mmr", "similarity_score_threshold")
            search_kwargs: Additional search parameters for semantic retriever;
                its "k" entry is always overridden by ``k``

        Returns:
            EnsembleRetriever object combining both approaches

        Raises:
            Exception: Re-raises any error from the underlying retrievers after logging it.
        """
        try:
            # Normalize weights so they sum to 1. Negative weights or an
            # all-zero pair cannot be normalized meaningfully, so the defaults
            # are used instead.
            total_weight = semantic_weight + keyword_weight
            if semantic_weight < 0 or keyword_weight < 0 or total_weight == 0:
                if total_weight != 0:
                    logger.warning(
                        f"Invalid hybrid weights (semantic={semantic_weight}, keyword={keyword_weight}); using defaults"
                    )
                semantic_weight, keyword_weight = 0.7, 0.3
            else:
                semantic_weight = semantic_weight / total_weight
                keyword_weight = keyword_weight / total_weight

            # Get semantic retriever. Work on a copy so the caller's dict is
            # never mutated.
            semantic_kwargs = {**(search_kwargs or {}), "k": k}
            semantic_retriever = self.get_retriever(
                search_type=search_type,
                search_kwargs=semantic_kwargs,
            )

            # Get BM25 retriever
            keyword_retriever = self.get_bm25_retriever(k=k)

            # Create ensemble retriever
            ensemble_retriever = EnsembleRetriever(
                retrievers=[semantic_retriever, keyword_retriever],
                weights=[semantic_weight, keyword_weight],
            )

            logger.info(f"Created hybrid retriever with weights: semantic={semantic_weight:.2f}, keyword={keyword_weight:.2f}")
            return ensemble_retriever
        except Exception as e:
            logger.error(f"Error creating hybrid retriever: {e}")
            raise

    def get_collection_info(self) -> Dict[str, Any]:
        """
        Get information about the current collection.

        Returns:
            Dictionary with collection information, or ``{"error": <message>}``
            if the information could not be gathered
        """
        try:
            vector_store = self.get_vector_store()

            # Get collection count
            count = vector_store._collection.count()

            info = {
                "collection_name": self.collection_name,
                "persist_directory": self.persist_directory,
                "document_count": count,
                # NOTE(review): hard-coded label; it is not read from the
                # embedding manager, so confirm it matches the configured model.
                "embedding_model": "text-embedding-3-small",
            }

            logger.info(f"Collection info: {info}")
            return info
        except Exception as e:
            logger.error(f"Error getting collection info: {e}")
            return {"error": str(e)}

    def delete_collection(self) -> bool:
        """
        Delete the current collection and reset the vector store.

        The store is opened first if it has not been used yet, so that a
        collection persisted by an earlier process is deleted as well. The
        cached documents and BM25 retriever are dropped together with it.

        Returns:
            True if successful, False otherwise
        """
        try:
            self.get_vector_store().delete_collection()
            self._vector_store = None
            # The keyword index mirrors the collection, so it must go too
            self._reset_keyword_index()
            logger.info(f"Deleted collection '{self.collection_name}'")
            return True
        except Exception as e:
            logger.error(f"Error deleting collection: {e}")
            return False

    def search_with_metadata_filter(self, query: str, metadata_filter: Dict[str, Any], k: int = 4) -> List[Document]:
        """
        Search documents with metadata filtering.

        Args:
            query: Search query
            metadata_filter: Metadata filter conditions
            k: Number of documents to return

        Returns:
            List of filtered documents; an empty list if the search fails
            (the error is logged, not raised)
        """
        try:
            vector_store = self.get_vector_store()

            documents = vector_store.similarity_search(
                query=query,
                k=k,
                filter=metadata_filter,
            )

            logger.info(f"Found {len(documents)} documents with metadata filter: {metadata_filter}")
            return documents
        except Exception as e:
            logger.error(f"Error searching with metadata filter: {e}")
            return []

    def clear_all_documents(self) -> bool:
        """
        Clear all documents from the vector store collection.

        Returns:
            True if successful (including when the collection was already
            empty), False otherwise
        """
        try:
            vector_store = self.get_vector_store()

            # Get all document IDs first. Only the IDs are needed, so the
            # document texts and metadata are not fetched.
            collection = vector_store._collection
            all_docs = collection.get(include=[])
            doc_ids = all_docs.get('ids') if all_docs else None

            if not doc_ids:
                # Nothing stored; still drop any stale keyword index
                self._reset_keyword_index()
                logger.info("No documents found in vector store to clear")
                return True

            # Delete all documents using their IDs
            collection.delete(ids=doc_ids)

            # Reset the vector store instance to ensure clean state
            self._vector_store = None

            # Clear documents cache and BM25 retriever
            self._reset_keyword_index()

            logger.info(f"Successfully cleared {len(doc_ids)} documents from vector store")
            return True
        except Exception as e:
            logger.error(f"Error clearing all documents: {e}")
            return False
# Global vector store manager instance.
# NOTE: it is constructed at import time with default arguments, so importing
# this module reads config.rag.vector_store_path and creates that directory.
# The Chroma store itself is only opened on first use.
vector_store_manager = VectorStoreManager()