import json
import os
import statistics
import time
import uuid
from typing import Dict, List, Optional

import litellm
import tiktoken
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.embeddings import Embeddings
from langchain_text_splitters import Language
from langfuse import Langfuse
from litellm import embedding
from tqdm import tqdm

from mllm_tools.utils import _prepare_text_inputs
from task_generator import get_prompt_detect_plugins


class RAGVectorStore:
    """A class for managing vector stores for RAG (Retrieval-Augmented Generation).

    This class handles creation, loading, and querying of vector stores for both
    Manim core and plugin documentation.

    Args:
        chroma_db_path (str): Path to the ChromaDB storage directory
        manim_docs_path (str): Path to the Manim documentation files
        embedding_model (str): Name of the embedding model to use
        trace_id (str, optional): Trace identifier for logging. Defaults to None
        session_id (str, optional): Session identifier. Defaults to None
        use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True
        helper_model: Helper model for processing. Defaults to None
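
    Example:
        A minimal usage sketch (the paths and query below are illustrative,
        not values shipped with this module)::

            store = RAGVectorStore(
                chroma_db_path="chroma_db",
                manim_docs_path="rag/manim_docs",
                use_langfuse=False,
            )
            context = store.find_relevant_docs(
                queries=[{"type": "manim-core", "query": "how to animate a Circle"}],
                k=5,
            )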
""" | |
def __init__(self, | |
chroma_db_path: str = "chroma_db", | |
manim_docs_path: str = "rag/manim_docs", | |
embedding_model: str = "text-embedding-ada-002", | |
trace_id: str = None, | |
session_id: str = None, | |
use_langfuse: bool = True, | |
helper_model = None): | |
self.chroma_db_path = chroma_db_path | |
self.manim_docs_path = manim_docs_path | |
self.embedding_model = embedding_model | |
self.trace_id = trace_id | |
self.session_id = session_id | |
self.use_langfuse = use_langfuse | |
self.helper_model = helper_model | |
self.enc = tiktoken.encoding_for_model("gpt-4") | |
self.plugin_stores = {} | |
self.vector_store = self._load_or_create_vector_store() | |

    def _load_or_create_vector_store(self):
        """Loads existing ChromaDB vector stores or creates new ones.

        Creates/loads vector stores for both the Manim core documentation and any
        available plugins. Stores are persisted to disk for future reuse.

        Returns:
            Chroma: The core Manim vector store instance
        """
        print("Entering _load_or_create_vector_store with trace_id:", self.trace_id)
        core_path = os.path.join(self.chroma_db_path, "manim_core")

        # Load or create the core vector store
        if os.path.exists(core_path):
            print("Loading existing core ChromaDB...")
            self.core_vector_store = Chroma(
                collection_name="manim_core",
                persist_directory=core_path,
                embedding_function=self._get_embedding_function()
            )
        else:
            print("Creating new core ChromaDB...")
            self.core_vector_store = self._create_core_store()

        # Plugin docs live under <manim_docs_path>/plugin_docs/<plugin_name>
        plugin_docs_path = os.path.join(self.manim_docs_path, "plugin_docs")
        print(f"Plugin docs path: {plugin_docs_path}")
        if os.path.exists(plugin_docs_path):
            for plugin_name in os.listdir(plugin_docs_path):
                plugin_store_path = os.path.join(self.chroma_db_path, f"manim_plugin_{plugin_name}")
                if os.path.exists(plugin_store_path):
                    print(f"Loading existing plugin store: {plugin_name}")
                    self.plugin_stores[plugin_name] = Chroma(
                        collection_name=f"manim_plugin_{plugin_name}",
                        persist_directory=plugin_store_path,
                        embedding_function=self._get_embedding_function()
                    )
                else:
                    print(f"Creating new plugin store: {plugin_name}")
                    plugin_path = os.path.join(plugin_docs_path, plugin_name)
                    if os.path.isdir(plugin_path):
                        plugin_store = Chroma(
                            collection_name=f"manim_plugin_{plugin_name}",
                            embedding_function=self._get_embedding_function(),
                            persist_directory=plugin_store_path
                        )
                        plugin_docs = self._process_documentation_folder(plugin_path)
                        if plugin_docs:
                            self._add_documents_to_store(plugin_store, plugin_docs, plugin_name)
                        self.plugin_stores[plugin_name] = plugin_store

        return self.core_vector_store  # Return the core store for backward compatibility

    def _get_embedding_function(self) -> Embeddings:
        """Creates an embedding function using litellm.

        Returns:
            Embeddings: A LangChain Embeddings instance that wraps litellm functionality
        """
        class LiteLLMEmbeddings(Embeddings):
            def __init__(self, embedding_model):
                self.embedding_model = embedding_model

            def _embed(self, texts: list[str]) -> list[list[float]]:
                # Temporarily disable Langfuse callbacks so individual embedding
                # calls are not logged, then restore them afterwards.
                litellm.success_callback = []
                litellm.failure_callback = []
                kwargs = {}
                # Vertex AI's text-embedding-005 accepts a task_type hint for
                # code retrieval; only pass it for that model.
                if self.embedding_model == "vertex_ai/text-embedding-005":
                    kwargs["task_type"] = "CODE_RETRIEVAL_QUERY"
                response = embedding(
                    model=self.embedding_model,
                    input=texts,
                    **kwargs
                )
                litellm.success_callback = ["langfuse"]
                litellm.failure_callback = ["langfuse"]
                return [r["embedding"] for r in response["data"]]

            def embed_documents(self, texts: list[str]) -> list[list[float]]:
                return self._embed(texts)

            def embed_query(self, text: str) -> list[float]:
                return self._embed([text])[0]

        return LiteLLMEmbeddings(self.embedding_model)

    def _create_core_store(self):
        """Creates the main ChromaDB vector store for the Manim core documentation.

        Returns:
            Chroma: The initialized and populated core vector store
        """
        core_vector_store = Chroma(
            collection_name="manim_core",
            embedding_function=self._get_embedding_function(),
            persist_directory=os.path.join(self.chroma_db_path, "manim_core")
        )
        # Process the Manim core docs
        core_docs = self._process_documentation_folder(os.path.join(self.manim_docs_path, "manim_core"))
        if core_docs:
            self._add_documents_to_store(core_vector_store, core_docs, "manim_core")
        return core_vector_store

    def _process_documentation_folder(self, folder_path: str) -> List[Document]:
        """Processes documentation files from a folder into LangChain documents.

        Args:
            folder_path (str): Path to the folder containing documentation files

        Returns:
            List[Document]: List of processed LangChain documents
        """
        all_docs = []
        for root, _, files in os.walk(folder_path):
            for file in files:
                if file.endswith(('.md', '.py')):
                    file_path = os.path.join(root, file)
                    try:
                        loader = TextLoader(file_path)
                        documents = loader.load()
                        for doc in documents:
                            doc.metadata['source'] = file_path
                        all_docs.extend(documents)
                    except Exception as e:
                        print(f"Error loading file {file_path}: {e}")

        if not all_docs:
            print(f"No Markdown or Python files found in {folder_path}")
            return []

        # Split documents with a language-aware splitter per file type
        split_docs = []
        markdown_splitter = RecursiveCharacterTextSplitter.from_language(
            language=Language.MARKDOWN
        )
        python_splitter = RecursiveCharacterTextSplitter.from_language(
            language=Language.PYTHON
        )
        for doc in all_docs:
            splitter = markdown_splitter if doc.metadata['source'].endswith('.md') else python_splitter
            temp_docs = splitter.split_documents([doc])
            # Prepend the source path so each chunk is self-describing
            for temp_doc in temp_docs:
                temp_doc.page_content = f"Source: {doc.metadata['source']}\n\n{temp_doc.page_content}"
            split_docs.extend(temp_docs)
        return split_docs

    def _add_documents_to_store(self, vector_store: Chroma, documents: List[Document], store_name: str):
        """Adds documents to a vector store in batches with rate limiting.

        Args:
            vector_store (Chroma): The vector store to add documents to
            documents (List[Document]): List of documents to add
            store_name (str): Name of the store for logging purposes
        """
        print(f"Adding documents to {store_name} store")
        # Log token statistics (stdev requires at least two samples)
        token_lengths = [len(self.enc.encode(doc.page_content)) for doc in documents]
        std = statistics.stdev(token_lengths) if len(token_lengths) > 1 else 0.0
        print(f"Token length statistics for {store_name}: "
              f"Min: {min(token_lengths)}, Max: {max(token_lengths)}, "
              f"Mean: {sum(token_lengths) / len(token_lengths):.1f}, "
              f"Median: {statistics.median(token_lengths)}, "
              f"Std: {std:.1f}")

        batch_size = 10
        request_count = 0
        for i in tqdm(range(0, len(documents), batch_size), desc=f"Processing {store_name} batches"):
            batch_docs = documents[i:i + batch_size]
            batch_ids = [str(uuid.uuid4()) for _ in batch_docs]
            vector_store.add_documents(documents=batch_docs, ids=batch_ids)
            request_count += 1
            if request_count % 100 == 0:
                time.sleep(60)  # Pause for 60 seconds every 100 batches to respect rate limits

        vector_store.persist()

    def find_relevant_docs(self, queries: List[Dict], k: int = 5, trace_id: Optional[str] = None, topic: Optional[str] = None, scene_number: Optional[int] = None) -> str:
        """Finds relevant documentation based on the provided queries.

        Args:
            queries (List[Dict]): List of query dictionaries with 'type' and 'query' keys,
                where 'type' is either "manim-core" or a plugin name
            k (int, optional): Number of results to return per query. Defaults to 5
            trace_id (str, optional): Trace identifier for logging. Defaults to None
            topic (str, optional): Topic name for logging. Defaults to None
            scene_number (int, optional): Scene number for logging. Defaults to None

        Returns:
            str: A formatted string containing relevant documentation snippets
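
        Example:
            A sketch of the expected query shape (the plugin name is illustrative)::

                queries = [
                    {"type": "manim-core", "query": "animate a circle"},
                    {"type": "manim-physics", "query": "rigid body simulation"},
                ]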
""" | |
manim_core_formatted_results = [] | |
manim_plugin_formatted_results = [] | |
# Create a Langfuse span if enabled | |
if self.use_langfuse: | |
langfuse = Langfuse() | |
span = langfuse.span( | |
trace_id=trace_id, # Use the passed trace_id | |
name=f"RAG search for {topic} - scene {scene_number}", | |
metadata={ | |
"topic": topic, | |
"scene_number": scene_number, | |
"session_id": self.session_id | |
} | |
) | |
# Separate queries by type | |
manim_core_queries = [query for query in queries if query["type"] == "manim-core"] | |
manim_plugin_queries = [query for query in queries if query["type"] != "manim-core" and query["type"] in self.plugin_stores] | |
if len([q for q in queries if q["type"] != "manim-core"]) != len(manim_plugin_queries): | |
print("Warning: Some plugin queries were skipped because their types weren't found in available plugin stores") | |

        # Search the core Manim docs
        for query in manim_core_queries:
            query_text = query["query"]
            if span is not None:
                self.core_vector_store._embedding_function.parent_observation_id = span.id
            manim_core_results = self.core_vector_store.similarity_search_with_relevance_scores(
                query=query_text,
                k=k,
                score_threshold=0.5
            )
            for result in manim_core_results:
                manim_core_formatted_results.append({
                    "query": query_text,
                    "source": result[0].metadata['source'],
                    "content": result[0].page_content,
                    "score": result[1]
                })

        # Search the relevant plugin docs (manim_plugin_queries is already
        # filtered to plugins with an available store)
        for query in manim_plugin_queries:
            plugin_name = query["type"]
            query_text = query["query"]
            if span is not None:
                self.plugin_stores[plugin_name]._embedding_function.parent_observation_id = span.id
            plugin_results = self.plugin_stores[plugin_name].similarity_search_with_relevance_scores(
                query=query_text,
                k=k,
                score_threshold=0.5
            )
            for result in plugin_results:
                manim_plugin_formatted_results.append({
                    "query": query_text,
                    "source": result[0].metadata['source'],
                    "content": result[0].page_content,
                    "score": result[1]
                })
print(f"Number of results before removing duplicates: {len(manim_core_formatted_results) + len(manim_plugin_formatted_results)}") | |
# Remove duplicates based on content | |
manim_core_unique_results = [] | |
manim_plugin_unique_results = [] | |
seen = set() | |
for item in manim_core_formatted_results: | |
key = item['content'] | |
if key not in seen: | |
manim_core_unique_results.append(item) | |
seen.add(key) | |
for item in manim_plugin_formatted_results: | |
key = item['content'] | |
if key not in seen: | |
manim_plugin_unique_results.append(item) | |
seen.add(key) | |
print(f"Number of results after removing duplicates: {len(manim_core_unique_results) + len(manim_plugin_unique_results)}") | |
total_tokens = sum(len(self.enc.encode(res['content'])) for res in manim_core_unique_results + manim_plugin_unique_results) | |
print(f"Total tokens for the RAG search: {total_tokens}") | |

        # Update Langfuse with the deduplicated results
        if self.use_langfuse and span is not None:
            filtered_results_json = json.dumps(manim_core_unique_results + manim_plugin_unique_results, indent=2)
            span.update(  # Use span.update, not span.end
                output=filtered_results_json,
                metadata={
                    "total_tokens": total_tokens,
                    "initial_results_count": len(manim_core_formatted_results) + len(manim_plugin_formatted_results),
                    "filtered_results_count": len(manim_core_unique_results) + len(manim_plugin_unique_results)
                }
            )

        manim_core_results = (
            "Please refer to the following Manim core documentation that may be helpful for the code generation:\n\n"
            + "\n\n".join(
                f"Content:\n````text\n{res['content']}\n````\nScore: {res['score']}"
                for res in manim_core_unique_results
            )
        )
        manim_plugin_results = (
            "Please refer to the following Manim plugin documentation that may be helpful for the code generation:\n\n"
            + "\n\n".join(
                f"Content:\n````text\n{res['content']}\n````\nScore: {res['score']}"
                for res in manim_plugin_unique_results
            )
        )
        return manim_core_results + "\n\n" + manim_plugin_results
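

# A minimal end-to-end sketch. The paths, embedding model, and query below are
# illustrative assumptions, not values shipped with this module; Langfuse is
# disabled so the example runs without a tracing backend.
if __name__ == "__main__":
    store = RAGVectorStore(
        chroma_db_path="chroma_db",
        manim_docs_path="rag/manim_docs",
        embedding_model="text-embedding-ada-002",
        use_langfuse=False,
    )
    context = store.find_relevant_docs(
        queries=[{"type": "manim-core", "query": "how to fade in a Text mobject"}],
        k=3,
        topic="demo",
        scene_number=1,
    )
    print(context)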