import json
import os
import statistics
import time
import uuid
from typing import Dict, List, Optional

import litellm
import tiktoken
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.embeddings import Embeddings
from langchain_text_splitters import Language
from langfuse import Langfuse
from litellm import embedding
from tqdm import tqdm

from mllm_tools.utils import _prepare_text_inputs
from task_generator import get_prompt_detect_plugins


class RAGVectorStore:
    """A class for managing vector stores for RAG (Retrieval-Augmented Generation).

    This class handles creation, loading, and querying of vector stores for
    both Manim core and plugin documentation.

    Args:
        chroma_db_path (str): Path to the ChromaDB storage directory
        manim_docs_path (str): Path to the Manim documentation files
        embedding_model (str): Name of the embedding model to use
        trace_id (str, optional): Trace identifier for logging. Defaults to None
        session_id (str, optional): Session identifier. Defaults to None
        use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True
        helper_model: Helper model for processing. Defaults to None
    """

    def __init__(self,
                 chroma_db_path: str = "chroma_db",
                 manim_docs_path: str = "rag/manim_docs",
                 embedding_model: str = "text-embedding-ada-002",
                 trace_id: Optional[str] = None,
                 session_id: Optional[str] = None,
                 use_langfuse: bool = True,
                 helper_model=None):
        self.chroma_db_path = chroma_db_path
        self.manim_docs_path = manim_docs_path
        self.embedding_model = embedding_model
        self.trace_id = trace_id
        self.session_id = session_id
        self.use_langfuse = use_langfuse
        self.helper_model = helper_model
        self.enc = tiktoken.encoding_for_model("gpt-4")
        self.plugin_stores = {}
        self.vector_store = self._load_or_create_vector_store()
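    # Expected on-disk layout, inferred from the os.path.join calls below
    # (names are this module's defaults, not requirements):
    #
    #   <chroma_db_path>/manim_core                -- persisted core store
    #   <chroma_db_path>/manim_plugin_<name>       -- one persisted store per plugin
    #   <manim_docs_path>/manim_core/**/*.md|.py   -- core documentation sources
    #   <manim_docs_path>/plugin_docs/<name>/      -- per-plugin documentation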
    def _load_or_create_vector_store(self):
        """Loads existing or creates new ChromaDB vector stores.

        Creates/loads vector stores for both Manim core documentation and any
        available plugins. Stores are persisted to disk for future reuse.

        Returns:
            Chroma: The core Manim vector store instance
        """
        print("Entering _load_or_create_vector_store with trace_id:", self.trace_id)
        core_path = os.path.join(self.chroma_db_path, "manim_core")

        # Load or create the core vector store
        if os.path.exists(core_path):
            print("Loading existing core ChromaDB...")
            self.core_vector_store = Chroma(
                collection_name="manim_core",
                persist_directory=core_path,
                embedding_function=self._get_embedding_function()
            )
        else:
            print("Creating new core ChromaDB...")
            self.core_vector_store = self._create_core_store()

        # Plugin documentation lives under <manim_docs_path>/plugin_docs
        plugin_docs_path = os.path.join(self.manim_docs_path, "plugin_docs")
        print(f"Plugin docs path: {plugin_docs_path}")
        if os.path.exists(plugin_docs_path):
            for plugin_name in os.listdir(plugin_docs_path):
                plugin_store_path = os.path.join(self.chroma_db_path, f"manim_plugin_{plugin_name}")
                if os.path.exists(plugin_store_path):
                    print(f"Loading existing plugin store: {plugin_name}")
                    self.plugin_stores[plugin_name] = Chroma(
                        collection_name=f"manim_plugin_{plugin_name}",
                        persist_directory=plugin_store_path,
                        embedding_function=self._get_embedding_function()
                    )
                else:
                    print(f"Creating new plugin store: {plugin_name}")
                    plugin_path = os.path.join(plugin_docs_path, plugin_name)
                    if os.path.isdir(plugin_path):
                        plugin_store = Chroma(
                            collection_name=f"manim_plugin_{plugin_name}",
                            embedding_function=self._get_embedding_function(),
                            persist_directory=plugin_store_path
                        )
                        plugin_docs = self._process_documentation_folder(plugin_path)
                        if plugin_docs:
                            self._add_documents_to_store(plugin_store, plugin_docs, plugin_name)
                        self.plugin_stores[plugin_name] = plugin_store

        return self.core_vector_store  # Return the core store for backward compatibility

    def _get_embedding_function(self) -> Embeddings:
        """Creates an embedding function backed by litellm.

        Returns:
            Embeddings: A LangChain Embeddings instance that wraps litellm
        """
        class LiteLLMEmbeddings(Embeddings):
            def __init__(self, embedding_model):
                self.embedding_model = embedding_model

            def _embed(self, texts: List[str]) -> List[List[float]]:
                # Disable Langfuse callbacks around the embedding call itself,
                # then restore them afterwards.
                litellm.success_callback = []
                litellm.failure_callback = []
                response = embedding(
                    model=self.embedding_model,
                    input=texts,
                    task_type=("CODE_RETRIEVAL_QUERY"
                               if self.embedding_model == "vertex_ai/text-embedding-005"
                               else None)
                )
                litellm.success_callback = ["langfuse"]
                litellm.failure_callback = ["langfuse"]
                return [r["embedding"] for r in response["data"]]

            def embed_documents(self, texts: List[str]) -> List[List[float]]:
                return self._embed(texts)

            def embed_query(self, text: str) -> List[float]:
                return self._embed([text])[0]

        return LiteLLMEmbeddings(self.embedding_model)
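    # A rough usage sketch for the wrapper above, given an already-constructed
    # RAGVectorStore `store` (illustrative inputs; requires valid litellm
    # credentials for the configured embedding model):
    #
    #     embedder = store._get_embedding_function()
    #     vectors = embedder.embed_documents(["FadeIn a Circle", "Axes with labels"])
    #     query_vec = embedder.embed_query("How do I animate a square?")
    #     assert len(vectors) == 2  # one embedding vector per input text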
    def _create_core_store(self):
        """Creates the main ChromaDB vector store for Manim core documentation.

        Returns:
            Chroma: The initialized and populated core vector store
        """
        core_vector_store = Chroma(
            collection_name="manim_core",
            embedding_function=self._get_embedding_function(),
            persist_directory=os.path.join(self.chroma_db_path, "manim_core")
        )

        # Process the Manim core docs
        core_docs = self._process_documentation_folder(os.path.join(self.manim_docs_path, "manim_core"))
        if core_docs:
            self._add_documents_to_store(core_vector_store, core_docs, "manim_core")

        return core_vector_store

    def _process_documentation_folder(self, folder_path: str) -> List[Document]:
        """Processes documentation files from a folder into LangChain documents.

        Args:
            folder_path (str): Path to the folder containing documentation files

        Returns:
            List[Document]: List of processed LangChain documents
        """
        all_docs = []
        for root, _, files in os.walk(folder_path):
            for file in files:
                if file.endswith(('.md', '.py')):
                    file_path = os.path.join(root, file)
                    try:
                        loader = TextLoader(file_path)
                        documents = loader.load()
                        for doc in documents:
                            doc.metadata['source'] = file_path
                        all_docs.extend(documents)
                    except Exception as e:
                        print(f"Error loading file {file_path}: {e}")

        if not all_docs:
            print(f"No markdown or python files found in {folder_path}")
            return []

        # Split documents with language-aware splitters
        split_docs = []
        markdown_splitter = RecursiveCharacterTextSplitter.from_language(
            language=Language.MARKDOWN
        )
        python_splitter = RecursiveCharacterTextSplitter.from_language(
            language=Language.PYTHON
        )
        for doc in all_docs:
            if doc.metadata['source'].endswith('.md'):
                splitter = markdown_splitter
            else:  # '.py' (only .md and .py files are collected above)
                splitter = python_splitter
            temp_docs = splitter.split_documents([doc])
            for temp_doc in temp_docs:
                # Prepend the source path so retrieved chunks stay attributable
                temp_doc.page_content = f"Source: {doc.metadata['source']}\n\n{temp_doc.page_content}"
            split_docs.extend(temp_docs)

        return split_docs

    def _add_documents_to_store(self, vector_store: Chroma, documents: List[Document], store_name: str):
        """Adds documents to a vector store in batches with rate limiting.

        Args:
            vector_store (Chroma): The vector store to add documents to
            documents (List[Document]): List of documents to add
            store_name (str): Name of the store for logging purposes
        """
        print(f"Adding documents to {store_name} store")

        # Token-length statistics for the chunks being embedded
        token_lengths = [len(self.enc.encode(doc.page_content)) for doc in documents]
        # Guard: stdev needs at least two samples
        std = statistics.stdev(token_lengths) if len(token_lengths) > 1 else 0.0
        print(f"Token length statistics for {store_name}: "
              f"Min: {min(token_lengths)}, Max: {max(token_lengths)}, "
              f"Mean: {sum(token_lengths) / len(token_lengths):.1f}, "
              f"Median: {statistics.median(token_lengths)}, "
              f"Std: {std:.1f}")

        batch_size = 10
        request_count = 0
        for i in tqdm(range(0, len(documents), batch_size), desc=f"Processing {store_name} batches"):
            batch_docs = documents[i:i + batch_size]
            batch_ids = [str(uuid.uuid4()) for _ in batch_docs]
            vector_store.add_documents(documents=batch_docs, ids=batch_ids)
            request_count += 1
            if request_count % 100 == 0:
                time.sleep(60)  # Pause for 60 seconds every 100 batches to respect rate limits
        vector_store.persist()
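    # Shape of the `queries` argument accepted by find_relevant_docs below:
    # "type" routes each query to a store ("manim-core" or a plugin folder
    # name) and "query" is the search text. The plugin name here is only
    # illustrative:
    #
    #     queries = [
    #         {"type": "manim-core", "query": "How to use ValueTracker?"},
    #         {"type": "manim-physics", "query": "rigid body simulation"},
    #     ]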
    def find_relevant_docs(self,
                           queries: List[Dict],
                           k: int = 5,
                           trace_id: Optional[str] = None,
                           topic: Optional[str] = None,
                           scene_number: Optional[int] = None) -> str:
        """Finds relevant documentation based on the provided queries.

        Args:
            queries (List[Dict]): List of query dictionaries with 'type' and 'query' keys
            k (int, optional): Number of results to return per query. Defaults to 5
            trace_id (str, optional): Trace identifier for logging. Defaults to None
            topic (str, optional): Topic name for logging. Defaults to None
            scene_number (int, optional): Scene number for logging. Defaults to None

        Returns:
            str: Formatted string containing the relevant documentation snippets
        """
        manim_core_formatted_results = []
        manim_plugin_formatted_results = []

        # Create a Langfuse span if enabled
        span = None
        if self.use_langfuse:
            langfuse = Langfuse()
            span = langfuse.span(
                trace_id=trace_id,  # Use the passed trace_id
                name=f"RAG search for {topic} - scene {scene_number}",
                metadata={
                    "topic": topic,
                    "scene_number": scene_number,
                    "session_id": self.session_id
                }
            )

        # Separate queries by type
        manim_core_queries = [query for query in queries if query["type"] == "manim-core"]
        manim_plugin_queries = [query for query in queries
                                if query["type"] != "manim-core" and query["type"] in self.plugin_stores]
        if len([q for q in queries if q["type"] != "manim-core"]) != len(manim_plugin_queries):
            print("Warning: Some plugin queries were skipped because their types weren't found in the available plugin stores")

        # Search the core Manim docs
        for query in manim_core_queries:
            query_text = query["query"]
            if span is not None:
                self.core_vector_store._embedding_function.parent_observation_id = span.id
            manim_core_results = self.core_vector_store.similarity_search_with_relevance_scores(
                query=query_text,
                k=k,
                score_threshold=0.5
            )
            for result in manim_core_results:
                manim_core_formatted_results.append({
                    "query": query_text,
                    "source": result[0].metadata['source'],
                    "content": result[0].page_content,
                    "score": result[1]
                })

        # Search the relevant plugin docs (queries were already filtered to
        # plugins with an available store)
        for query in manim_plugin_queries:
            plugin_name = query["type"]
            query_text = query["query"]
            if span is not None:
                self.plugin_stores[plugin_name]._embedding_function.parent_observation_id = span.id
            plugin_results = self.plugin_stores[plugin_name].similarity_search_with_relevance_scores(
                query=query_text,
                k=k,
                score_threshold=0.5
            )
            for result in plugin_results:
                manim_plugin_formatted_results.append({
                    "query": query_text,
                    "source": result[0].metadata['source'],
                    "content": result[0].page_content,
                    "score": result[1]
                })

        print(f"Number of results before removing duplicates: {len(manim_core_formatted_results) + len(manim_plugin_formatted_results)}")

        # Remove duplicates based on content
        manim_core_unique_results = []
        manim_plugin_unique_results = []
        seen = set()
        for item in manim_core_formatted_results:
            key = item['content']
            if key not in seen:
                manim_core_unique_results.append(item)
                seen.add(key)
        for item in manim_plugin_formatted_results:
            key = item['content']
            if key not in seen:
                manim_plugin_unique_results.append(item)
                seen.add(key)
        print(f"Number of results after removing duplicates: {len(manim_core_unique_results) + len(manim_plugin_unique_results)}")

        total_tokens = sum(len(self.enc.encode(res['content']))
                           for res in manim_core_unique_results + manim_plugin_unique_results)
        print(f"Total tokens for the RAG search: {total_tokens}")

        # Update Langfuse with the deduplicated results
        if span is not None:
            filtered_results_json = json.dumps(manim_core_unique_results + manim_plugin_unique_results, indent=2)
            span.update(  # Use span.update, not span.end
                output=filtered_results_json,
                metadata={
                    "total_tokens": total_tokens,
                    "initial_results_count": len(manim_core_formatted_results) + len(manim_plugin_formatted_results),
                    "filtered_results_count": len(manim_core_unique_results) + len(manim_plugin_unique_results)
                }
            )
        # Assemble the final prompt text; four-backtick fences keep any triple
        # backticks inside the retrieved content intact
        manim_core_results = (
            "Please refer to the following Manim core documentation that may be "
            "helpful for the code generation:\n\n"
            + "\n\n".join([f"Content:\n````text\n{res['content']}\n````\nScore: {res['score']}"
                           for res in manim_core_unique_results])
        )
        manim_plugin_results = (
            "Please refer to the following Manim plugin documentation that may be "
            "helpful for the code generation:\n\n"
            + "\n\n".join([f"Content:\n````text\n{res['content']}\n````\nScore: {res['score']}"
                           for res in manim_plugin_unique_results])
        )
        return manim_core_results + "\n\n" + manim_plugin_results
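

# Minimal smoke-test sketch (illustrative paths and query; assumes the
# documentation folders exist and litellm embedding credentials are set):
if __name__ == "__main__":
    store = RAGVectorStore(
        chroma_db_path="chroma_db",
        manim_docs_path="rag/manim_docs",
        embedding_model="text-embedding-ada-002",
        use_langfuse=False,  # skip Langfuse spans for a local run
    )
    snippets = store.find_relevant_docs(
        queries=[{"type": "manim-core", "query": "How do I animate a Circle with FadeIn?"}],
        k=3,
        topic="demo",
        scene_number=1,
    )
    print(snippets)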