Spaces:
Sleeping
Sleeping
Update embedding model to Google Generative AI and enhance vector store functionality
Browse files
- Changed the embedding model from OpenAI to Google Generative AI in the EmbeddingManager class.
- Updated the configuration to reflect the new embedding model path.
- Modified validation checks to warn when the Google API key required for RAG embeddings is missing.
- Added a new method to reset the vector store, allowing for a complete clear and recreation of the collection.
- Enhanced logging to provide clearer feedback on embedding model initialization and vector store operations.
- src/core/config.py +3 -3
- src/rag/embeddings.py +17 -19
- src/rag/vector_store.py +32 -1
src/core/config.py
CHANGED
@@ -86,7 +86,7 @@ class RAGConfig:
|
|
86 |
chat_history_path: str = "./data/chat_history"
|
87 |
|
88 |
# Embedding settings
|
89 |
-
embedding_model: str = "text-embedding-
|
90 |
embedding_chunk_size: int = 1000
|
91 |
|
92 |
# Chunking settings
|
@@ -182,8 +182,8 @@ class Config:
|
|
182 |
validation_results["warnings"].append("Mistral API key not found - Mistral parser will be unavailable")
|
183 |
|
184 |
# Check RAG dependencies
|
185 |
-
if not self.api.
|
186 |
-
validation_results["warnings"].append("
|
187 |
|
188 |
if not self.api.google_api_key:
|
189 |
validation_results["warnings"].append("Google API key not found - RAG chat will be unavailable")
|
|
|
86 |
chat_history_path: str = "./data/chat_history"
|
87 |
|
88 |
# Embedding settings
|
89 |
+
embedding_model: str = "models/text-embedding-004"
|
90 |
embedding_chunk_size: int = 1000
|
91 |
|
92 |
# Chunking settings
|
|
|
182 |
validation_results["warnings"].append("Mistral API key not found - Mistral parser will be unavailable")
|
183 |
|
184 |
# Check RAG dependencies
|
185 |
+
if not self.api.google_api_key:
|
186 |
+
validation_results["warnings"].append("Google API key not found - RAG embeddings will be unavailable")
|
187 |
|
188 |
if not self.api.google_api_key:
|
189 |
validation_results["warnings"].append("Google API key not found - RAG chat will be unavailable")
|
src/rag/embeddings.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
|
3 |
import os
|
4 |
from typing import Optional
|
5 |
-
from
|
6 |
from src.core.config import config
|
7 |
from src.core.logging_config import get_logger
|
8 |
|
@@ -12,30 +12,28 @@ class EmbeddingManager:
|
|
12 |
"""Manages embedding models for document vectorization."""
|
13 |
|
14 |
def __init__(self):
|
15 |
-
self._embedding_model: Optional[
|
16 |
|
17 |
-
def get_embedding_model(self) ->
|
18 |
-
"""Get or create the
|
19 |
if self._embedding_model is None:
|
20 |
try:
|
21 |
-
# Get
|
22 |
-
|
23 |
|
24 |
-
if not
|
25 |
-
raise ValueError("
|
26 |
|
27 |
-
self._embedding_model =
|
28 |
-
model=
|
29 |
-
|
30 |
-
|
31 |
-
max_retries=3,
|
32 |
-
timeout=30
|
33 |
)
|
34 |
|
35 |
-
logger.info("
|
36 |
|
37 |
except Exception as e:
|
38 |
-
logger.error(f"Failed to initialize
|
39 |
raise
|
40 |
|
41 |
return self._embedding_model
|
@@ -50,14 +48,14 @@ class EmbeddingManager:
|
|
50 |
|
51 |
# Check if we got a valid embedding (list of floats)
|
52 |
if isinstance(embedding, list) and len(embedding) > 0 and isinstance(embedding[0], float):
|
53 |
-
logger.info("
|
54 |
return True
|
55 |
else:
|
56 |
-
logger.error("
|
57 |
return False
|
58 |
|
59 |
except Exception as e:
|
60 |
-
logger.error(f"
|
61 |
return False
|
62 |
|
63 |
# Global embedding manager instance
|
|
|
2 |
|
3 |
import os
|
4 |
from typing import Optional
|
5 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
6 |
from src.core.config import config
|
7 |
from src.core.logging_config import get_logger
|
8 |
|
|
|
12 |
"""Manages embedding models for document vectorization."""
|
13 |
|
14 |
def __init__(self):
|
15 |
+
self._embedding_model: Optional[GoogleGenerativeAIEmbeddings] = None
|
16 |
|
17 |
+
def get_embedding_model(self) -> GoogleGenerativeAIEmbeddings:
|
18 |
+
"""Get or create the Gemini embedding model."""
|
19 |
if self._embedding_model is None:
|
20 |
try:
|
21 |
+
# Get Google API key from config/environment
|
22 |
+
google_api_key = config.api.google_api_key or os.getenv("GOOGLE_API_KEY")
|
23 |
|
24 |
+
if not google_api_key:
|
25 |
+
raise ValueError("Google API key not found. Please set GOOGLE_API_KEY in environment variables.")
|
26 |
|
27 |
+
self._embedding_model = GoogleGenerativeAIEmbeddings(
|
28 |
+
model=config.rag.embedding_model,
|
29 |
+
google_api_key=google_api_key,
|
30 |
+
task_type="RETRIEVAL_DOCUMENT"
|
|
|
|
|
31 |
)
|
32 |
|
33 |
+
logger.info(f"Gemini embedding model ({config.rag.embedding_model}) initialized successfully")
|
34 |
|
35 |
except Exception as e:
|
36 |
+
logger.error(f"Failed to initialize Gemini embedding model: {e}")
|
37 |
raise
|
38 |
|
39 |
return self._embedding_model
|
|
|
48 |
|
49 |
# Check if we got a valid embedding (list of floats)
|
50 |
if isinstance(embedding, list) and len(embedding) > 0 and isinstance(embedding[0], float):
|
51 |
+
logger.info("Gemini embedding model test successful")
|
52 |
return True
|
53 |
else:
|
54 |
+
logger.error("Gemini embedding model test failed: Invalid embedding format")
|
55 |
return False
|
56 |
|
57 |
except Exception as e:
|
58 |
+
logger.error(f"Gemini embedding model test failed: {e}")
|
59 |
return False
|
60 |
|
61 |
# Global embedding manager instance
|
src/rag/vector_store.py
CHANGED
@@ -70,6 +70,7 @@ class VectorStoreManager:
|
|
70 |
|
71 |
logger.info(f"VectorStoreManager initialized with persist_directory={self.persist_directory}")
|
72 |
|
|
|
73 |
def get_vector_store(self) -> Chroma:
|
74 |
"""Get or create the Chroma vector store."""
|
75 |
if self._vector_store is None:
|
@@ -314,7 +315,7 @@ class VectorStoreManager:
|
|
314 |
"collection_name": self.collection_name,
|
315 |
"persist_directory": self.persist_directory,
|
316 |
"document_count": count,
|
317 |
-
"embedding_model":
|
318 |
}
|
319 |
|
320 |
logger.info(f"Collection info: {info}")
|
@@ -371,6 +372,36 @@ class VectorStoreManager:
|
|
371 |
logger.error(f"Error searching with metadata filter: {e}")
|
372 |
return []
|
373 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
def clear_all_documents(self) -> bool:
|
375 |
"""
|
376 |
Clear all documents from the vector store collection.
|
|
|
70 |
|
71 |
logger.info(f"VectorStoreManager initialized with persist_directory={self.persist_directory}")
|
72 |
|
73 |
+
|
74 |
def get_vector_store(self) -> Chroma:
|
75 |
"""Get or create the Chroma vector store."""
|
76 |
if self._vector_store is None:
|
|
|
315 |
"collection_name": self.collection_name,
|
316 |
"persist_directory": self.persist_directory,
|
317 |
"document_count": count,
|
318 |
+
"embedding_model": config.rag.embedding_model
|
319 |
}
|
320 |
|
321 |
logger.info(f"Collection info: {info}")
|
|
|
372 |
logger.error(f"Error searching with metadata filter: {e}")
|
373 |
return []
|
374 |
|
375 |
+
def reset_vector_store(self) -> bool:
|
376 |
+
"""
|
377 |
+
Reset the vector store completely.
|
378 |
+
This will clear all documents and recreate the collection.
|
379 |
+
|
380 |
+
Returns:
|
381 |
+
True if successful, False otherwise
|
382 |
+
"""
|
383 |
+
try:
|
384 |
+
logger.info("Resetting vector store...")
|
385 |
+
|
386 |
+
# Clear all documents and reset the vector store
|
387 |
+
success = self.clear_all_documents()
|
388 |
+
|
389 |
+
if success:
|
390 |
+
# Also delete the collection to ensure clean state
|
391 |
+
if self._vector_store is not None:
|
392 |
+
self._vector_store.delete_collection()
|
393 |
+
self._vector_store = None
|
394 |
+
|
395 |
+
logger.info("Vector store reset successfully")
|
396 |
+
return True
|
397 |
+
else:
|
398 |
+
logger.error("Failed to reset vector store")
|
399 |
+
return False
|
400 |
+
|
401 |
+
except Exception as e:
|
402 |
+
logger.error(f"Error resetting vector store: {e}")
|
403 |
+
return False
|
404 |
+
|
405 |
def clear_all_documents(self) -> bool:
|
406 |
"""
|
407 |
Clear all documents from the vector store collection.
|