AnseMin committed on
Commit
4dfec96
·
1 Parent(s): bf4414c

Update embedding model to Google Generative AI and enhance vector store functionality

Browse files

- Changed the embedding model from OpenAI to Google Generative AI in the EmbeddingManager class.
- Updated the configuration to reflect the new embedding model path.
- Modified validation checks to ensure the presence of the Google API key for RAG embeddings.
- Added a new method to reset the vector store, allowing for a complete clear and recreation of the collection.
- Enhanced logging to provide clearer feedback on embedding model initialization and vector store operations.

src/core/config.py CHANGED
@@ -86,7 +86,7 @@ class RAGConfig:
86
  chat_history_path: str = "./data/chat_history"
87
 
88
  # Embedding settings
89
- embedding_model: str = "text-embedding-3-small"
90
  embedding_chunk_size: int = 1000
91
 
92
  # Chunking settings
@@ -182,8 +182,8 @@ class Config:
182
  validation_results["warnings"].append("Mistral API key not found - Mistral parser will be unavailable")
183
 
184
  # Check RAG dependencies
185
- if not self.api.openai_api_key:
186
- validation_results["warnings"].append("OpenAI API key not found - RAG embeddings will be unavailable")
187
 
188
  if not self.api.google_api_key:
189
  validation_results["warnings"].append("Google API key not found - RAG chat will be unavailable")
 
86
  chat_history_path: str = "./data/chat_history"
87
 
88
  # Embedding settings
89
+ embedding_model: str = "models/text-embedding-004"
90
  embedding_chunk_size: int = 1000
91
 
92
  # Chunking settings
 
182
  validation_results["warnings"].append("Mistral API key not found - Mistral parser will be unavailable")
183
 
184
  # Check RAG dependencies
185
+ if not self.api.google_api_key:
186
+ validation_results["warnings"].append("Google API key not found - RAG embeddings will be unavailable")
187
 
188
  if not self.api.google_api_key:
189
  validation_results["warnings"].append("Google API key not found - RAG chat will be unavailable")
src/rag/embeddings.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  import os
4
  from typing import Optional
5
- from langchain_openai import OpenAIEmbeddings
6
  from src.core.config import config
7
  from src.core.logging_config import get_logger
8
 
@@ -12,30 +12,28 @@ class EmbeddingManager:
12
  """Manages embedding models for document vectorization."""
13
 
14
  def __init__(self):
15
- self._embedding_model: Optional[OpenAIEmbeddings] = None
16
 
17
- def get_embedding_model(self) -> OpenAIEmbeddings:
18
- """Get or create the OpenAI embedding model."""
19
  if self._embedding_model is None:
20
  try:
21
- # Get OpenAI API key from config/environment
22
- openai_api_key = config.api.openai_api_key or os.getenv("OPENAI_API_KEY")
23
 
24
- if not openai_api_key:
25
- raise ValueError("OpenAI API key not found. Please set OPENAI_API_KEY in environment variables.")
26
 
27
- self._embedding_model = OpenAIEmbeddings(
28
- model="text-embedding-3-small",
29
- openai_api_key=openai_api_key,
30
- chunk_size=1000, # Process documents in chunks
31
- max_retries=3,
32
- timeout=30
33
  )
34
 
35
- logger.info("OpenAI embedding model initialized successfully")
36
 
37
  except Exception as e:
38
- logger.error(f"Failed to initialize OpenAI embedding model: {e}")
39
  raise
40
 
41
  return self._embedding_model
@@ -50,14 +48,14 @@ class EmbeddingManager:
50
 
51
  # Check if we got a valid embedding (list of floats)
52
  if isinstance(embedding, list) and len(embedding) > 0 and isinstance(embedding[0], float):
53
- logger.info("Embedding model test successful")
54
  return True
55
  else:
56
- logger.error("Embedding model test failed: Invalid embedding format")
57
  return False
58
 
59
  except Exception as e:
60
- logger.error(f"Embedding model test failed: {e}")
61
  return False
62
 
63
  # Global embedding manager instance
 
2
 
3
  import os
4
  from typing import Optional
5
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
6
  from src.core.config import config
7
  from src.core.logging_config import get_logger
8
 
 
12
  """Manages embedding models for document vectorization."""
13
 
14
  def __init__(self):
15
+ self._embedding_model: Optional[GoogleGenerativeAIEmbeddings] = None
16
 
17
+ def get_embedding_model(self) -> GoogleGenerativeAIEmbeddings:
18
+ """Get or create the Gemini embedding model."""
19
  if self._embedding_model is None:
20
  try:
21
+ # Get Google API key from config/environment
22
+ google_api_key = config.api.google_api_key or os.getenv("GOOGLE_API_KEY")
23
 
24
+ if not google_api_key:
25
+ raise ValueError("Google API key not found. Please set GOOGLE_API_KEY in environment variables.")
26
 
27
+ self._embedding_model = GoogleGenerativeAIEmbeddings(
28
+ model=config.rag.embedding_model,
29
+ google_api_key=google_api_key,
30
+ task_type="RETRIEVAL_DOCUMENT"
 
 
31
  )
32
 
33
+ logger.info(f"Gemini embedding model ({config.rag.embedding_model}) initialized successfully")
34
 
35
  except Exception as e:
36
+ logger.error(f"Failed to initialize Gemini embedding model: {e}")
37
  raise
38
 
39
  return self._embedding_model
 
48
 
49
  # Check if we got a valid embedding (list of floats)
50
  if isinstance(embedding, list) and len(embedding) > 0 and isinstance(embedding[0], float):
51
+ logger.info("Gemini embedding model test successful")
52
  return True
53
  else:
54
+ logger.error("Gemini embedding model test failed: Invalid embedding format")
55
  return False
56
 
57
  except Exception as e:
58
+ logger.error(f"Gemini embedding model test failed: {e}")
59
  return False
60
 
61
  # Global embedding manager instance
src/rag/vector_store.py CHANGED
@@ -70,6 +70,7 @@ class VectorStoreManager:
70
 
71
  logger.info(f"VectorStoreManager initialized with persist_directory={self.persist_directory}")
72
 
 
73
  def get_vector_store(self) -> Chroma:
74
  """Get or create the Chroma vector store."""
75
  if self._vector_store is None:
@@ -314,7 +315,7 @@ class VectorStoreManager:
314
  "collection_name": self.collection_name,
315
  "persist_directory": self.persist_directory,
316
  "document_count": count,
317
- "embedding_model": "text-embedding-3-small"
318
  }
319
 
320
  logger.info(f"Collection info: {info}")
@@ -371,6 +372,36 @@ class VectorStoreManager:
371
  logger.error(f"Error searching with metadata filter: {e}")
372
  return []
373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  def clear_all_documents(self) -> bool:
375
  """
376
  Clear all documents from the vector store collection.
 
70
 
71
  logger.info(f"VectorStoreManager initialized with persist_directory={self.persist_directory}")
72
 
73
+
74
  def get_vector_store(self) -> Chroma:
75
  """Get or create the Chroma vector store."""
76
  if self._vector_store is None:
 
315
  "collection_name": self.collection_name,
316
  "persist_directory": self.persist_directory,
317
  "document_count": count,
318
+ "embedding_model": config.rag.embedding_model
319
  }
320
 
321
  logger.info(f"Collection info: {info}")
 
372
  logger.error(f"Error searching with metadata filter: {e}")
373
  return []
374
 
375
+ def reset_vector_store(self) -> bool:
376
+ """
377
+ Reset the vector store completely.
378
+ This will clear all documents and recreate the collection.
379
+
380
+ Returns:
381
+ True if successful, False otherwise
382
+ """
383
+ try:
384
+ logger.info("Resetting vector store...")
385
+
386
+ # Clear all documents and reset the vector store
387
+ success = self.clear_all_documents()
388
+
389
+ if success:
390
+ # Also delete the collection to ensure clean state
391
+ if self._vector_store is not None:
392
+ self._vector_store.delete_collection()
393
+ self._vector_store = None
394
+
395
+ logger.info("Vector store reset successfully")
396
+ return True
397
+ else:
398
+ logger.error("Failed to reset vector store")
399
+ return False
400
+
401
+ except Exception as e:
402
+ logger.error(f"Error resetting vector store: {e}")
403
+ return False
404
+
405
  def clear_all_documents(self) -> bool:
406
  """
407
  Clear all documents from the vector store collection.