# Vector-store utilities: chunk markdown, embed it with the Mistral API,
# and persist/query the embeddings in a local ChromaDB collection.
import os | |
from dotenv import load_dotenv | |
from mistralai import Mistral | |
import numpy as np | |
import time | |
import chromadb | |
from chromadb.config import Settings | |
import json | |
import hashlib | |
# Pull environment variables from a local .env file, if present.
load_dotenv()

# Mistral API credential, read from the environment.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")

# ChromaDB collection name and on-disk persistence location.
COLLECTION_NAME = "webpages_collection"
PERSIST_DIRECTORY = "./chroma_db"
def vectorize(input_texts, batch_size=5, max_retries=5):
    """
    Embed ``input_texts`` with the Mistral "mistral-embed" model.

    :param input_texts: List of strings to embed.
    :param batch_size: Number of texts sent per API request.
    :param max_retries: Maximum attempts per batch when rate-limited.
                        (The original retried forever and could hang.)
    :return: List of embedding vectors, one per input text; empty list when
             the client cannot be initialized.
    :raises Exception: Re-raises any non-rate-limit API error.
    :raises RuntimeError: When a batch stays rate-limited after
                          ``max_retries`` attempts.
    """
    try:
        client = Mistral(api_key=MISTRAL_API_KEY)
    except Exception as e:
        print(f"Error initializing Mistral client: {e}")
        return []
    embeddings = []
    for i in range(0, len(input_texts), batch_size):
        batch = input_texts[i:i + batch_size]
        for _attempt in range(max_retries):
            try:
                embeddings_batch_response = client.embeddings.create(
                    model="mistral-embed",
                    inputs=batch
                )
                embeddings.extend(
                    data.embedding for data in embeddings_batch_response.data
                )
                # Throttle between requests to stay under the API rate limit.
                time.sleep(1)
                break
            except Exception as e:
                if "rate limit exceeded" in str(e).lower():
                    print("Rate limit exceeded. Retrying after 10 seconds...")
                    time.sleep(10)
                else:
                    print(f"Error in embedding batch: {e}")
                    raise
        else:
            # Bounded retry instead of the original infinite loop.
            raise RuntimeError(
                f"Embedding batch still rate-limited after {max_retries} attempts"
            )
    return embeddings
def chunk_content(markdown_content, chunk_size=2048):
    """
    Split markdown text into chunks of roughly ``chunk_size`` characters,
    extending each chunk forward to the end of the current sentence so
    sentences are never cut in half.

    :param markdown_content: The text to split.
    :param chunk_size: Target chunk length before the sentence-end extension.
    :return: List of non-empty, stripped chunk strings.
    """
    def _find_sentence_end(text, start):
        """Return the index just past the next sentence-ending punctuation
        run at or after ``start`` (``len(text)`` if there is none)."""
        punctuations = {'.', '!', '?'}
        n = len(text)
        end = start
        # Advance to the next sentence-ending character...
        while end < n and text[end] not in punctuations:
            end += 1
        # ...then past the whole punctuation run (e.g. "?!", "...").
        while end < n and text[end] in punctuations:
            end += 1
        # BUGFIX: the original then backtracked while text[end-1] was not
        # punctuation; when the remaining text contained no sentence
        # punctuation at all, that walked end back to start, producing a
        # zero-length step and an infinite loop of empty chunks.
        return end

    chunks = []
    start = 0
    total = len(markdown_content)
    while start < total:
        end = _find_sentence_end(markdown_content, min(start + chunk_size, total))
        chunk = markdown_content[start:end].strip()
        if chunk:  # skip whitespace-only spans
            chunks.append(chunk)
        start = end
    return chunks
def generate_chunk_id(chunk):
    """Return a deterministic SHA-256 hex digest identifying *chunk*."""
    digest = hashlib.sha256(chunk.encode('utf-8'))
    return digest.hexdigest()
def load_in_vector_db(markdown_content, metadatas=None, collection_name=COLLECTION_NAME):
    """
    Chunk ``markdown_content``, embed the new chunks, and store them in a
    persistent ChromaDB collection for similarity search.

    Chunks are identified by a SHA-256 hash of their text, so chunks already
    present in the collection are skipped.

    :param markdown_content: The markdown text to ingest.
    :param metadatas: Optional metadata dict attached to every stored chunk.
    :param collection_name: Target ChromaDB collection name.
    :return: None. Errors are printed and swallowed (best-effort ingestion).
    """
    try:
        client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
    except Exception as e:
        print(f"Error initializing ChromaDB client: {e}")
        return
    try:
        # get_or_create avoids the list-then-create race of the original.
        collection = client.get_or_create_collection(collection_name)
    except Exception as e:
        print(f"Error accessing collection: {e}")
        return
    try:
        existing_items = collection.get()
    except Exception as e:
        print(f"Error retrieving existing items: {e}")
        return
    existing_ids = set(existing_items.get('ids', []))

    # Keep only chunks not already stored; also dedupe repeats within this
    # document so we never embed (or add) the same chunk twice.
    text_to_vectorize = []
    seen_ids = set()
    for chunk in chunk_content(markdown_content):
        chunk_id = generate_chunk_id(chunk)
        if chunk_id not in existing_ids and chunk_id not in seen_ids:
            seen_ids.add(chunk_id)
            text_to_vectorize.append(chunk)
    print(f"New chunks to vectorize: {len(text_to_vectorize)}")
    if not text_to_vectorize:
        return

    embeddings = vectorize(text_to_vectorize)
    for embedding, chunk in zip(embeddings, text_to_vectorize):
        chunk_id = generate_chunk_id(chunk)
        try:
            collection.add(
                embeddings=[embedding],
                documents=[chunk],
                # BUGFIX: the original always passed [metadatas], which is
                # [None] when no metadata is supplied — ChromaDB rejects
                # None entries. Omit metadata entirely in that case.
                metadatas=[metadatas] if metadatas is not None else None,
                ids=[chunk_id]
            )
            existing_ids.add(chunk_id)
        except Exception as e:
            print(f"Error adding embedding to collection: {e}")
def retrieve_from_database(query, collection_name=COLLECTION_NAME, n_results=5, distance_threshold=None):
    """
    Fetch the documents most similar to ``query`` from the vector store.

    :param query: Text query to embed and search with.
    :param collection_name: ChromaDB collection to search.
    :param n_results: Maximum number of matches to return.
    :param distance_threshold: If given, drop matches farther than this.
    :return: Raw ChromaDB query result (nested per-query lists) when no
             threshold is given; a flat filtered dict when one is; a hint
             string when filtering leaves nothing; None on any error.
    """
    try:
        client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
        collection = client.get_collection(collection_name)
    except Exception as e:
        print(f"Error accessing collection: {e}")
        return
    try:
        query_embeddings = vectorize([query])
    except Exception as e:
        print(f"Error vectorizing query: {e}")
        return
    try:
        raw_results = collection.query(
            query_embeddings=query_embeddings,
            n_results=n_results,
            include=["documents", "metadatas", "distances"]
        )
    except Exception as e:
        print(f"Error querying collection: {e}")
        return

    if distance_threshold is None:
        return raw_results

    # NOTE: filtered results are flat lists, whereas the raw results keep
    # ChromaDB's nested per-query list shape.
    filtered = {key: [] for key in ("ids", "distances", "metadatas", "documents")}
    for idx, distance in enumerate(raw_results['distances'][0]):
        if distance > distance_threshold:
            continue
        filtered['ids'].append(raw_results['ids'][0][idx])
        filtered['distances'].append(distance)
        filtered['metadatas'].append(raw_results['metadatas'][0][idx])
        filtered['documents'].append(raw_results['documents'][0][idx])
    if not filtered['documents']:
        return "No relevant data found in the knowledge database. Have you checked any webpages? If so, please try to find more relevant data."
    return filtered
def search_documents(collection_name=COLLECTION_NAME, query=None, query_embedding=None, metadata_filter=None, n_results=10):
    """
    Search a ChromaDB collection by text query, precomputed embedding, or
    metadata filter alone.

    :param collection_name: The name of the collection to search within.
    :param query: Text query to embed and search for (takes precedence).
    :param query_embedding: Precomputed embedding to search with (optional).
    :param metadata_filter: ChromaDB ``where`` filter on metadata (optional).
    :param n_results: Maximum number of results to return (default 10).
    :return: The search results.
    """
    client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
    collection = client.get_collection(collection_name)
    if query:
        # A text query overrides any supplied embedding.
        query_embedding = vectorize([query])[0]
    if not query_embedding:
        # No embedding available: fall back to a metadata-only scan.
        return collection.get(where=metadata_filter, limit=n_results)
    return collection.query(query_embeddings=[query_embedding], n_results=n_results, where=metadata_filter)
def delete_documents(collection_name=COLLECTION_NAME, ids=None):
    """
    Remove documents from a ChromaDB collection by their IDs.

    :param collection_name: The name of the collection.
    :param ids: A list of IDs of the documents to delete.
    """
    chroma_client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
    target_collection = chroma_client.get_collection(collection_name)
    target_collection.delete(ids=ids)
    print(f"Documents with IDs {ids} have been deleted from the collection {collection_name}.")
def delete_collection(collection_name=COLLECTION_NAME):
    """
    Drop an entire ChromaDB collection.

    :param collection_name: The name of the collection to delete.
    """
    chroma_client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
    chroma_client.delete_collection(collection_name)
    print(f"Collection {collection_name} has been deleted.")