AnseMin commited on
Commit
5da24ca
·
1 Parent(s): 53ff214

Refactor LimitedEnsembleRetriever for improved compatibility and functionality

Browse files

- Updated the `LimitedEnsembleRetriever` class to remove inheritance from `BaseRetriever`, simplifying its structure.
- Changed the method names to align with the new invoke interface, replacing deprecated methods with `invoke` and `ainvoke`.
- Added a compatibility method to handle both string input and other data types for the `invoke` method, enhancing usability.
- Improved documentation within the class to clarify the purpose and functionality of methods.

Files changed (1) hide show
  1. src/rag/vector_store.py +13 -8
src/rag/vector_store.py CHANGED
@@ -16,25 +16,30 @@ from src.core.logging_config import get_logger
16
  logger = get_logger(__name__)
17
 
18
 
19
- class LimitedEnsembleRetriever(BaseRetriever):
20
- """Wrapper around EnsembleRetriever that limits total results to k."""
21
 
22
  def __init__(self, ensemble_retriever: EnsembleRetriever, k: int):
23
- super().__init__()
24
  self.ensemble_retriever = ensemble_retriever
25
  self.k = k
26
 
27
- def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
28
  """Get relevant documents, limited to k results."""
29
- # Get all results from ensemble retriever
30
- docs = self.ensemble_retriever.get_relevant_documents(query)
31
  # Limit to k results
32
  return docs[:self.k]
33
 
34
- async def _aget_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
35
  """Async version of get_relevant_documents."""
36
- docs = await self.ensemble_retriever.aget_relevant_documents(query)
37
  return docs[:self.k]
 
 
 
 
 
 
38
 
39
 
40
  class VectorStoreManager:
 
16
  logger = get_logger(__name__)
17
 
18
 
19
+ class LimitedEnsembleRetriever:
20
+ """Simple wrapper around EnsembleRetriever that limits total results to k."""
21
 
22
  def __init__(self, ensemble_retriever: EnsembleRetriever, k: int):
 
23
  self.ensemble_retriever = ensemble_retriever
24
  self.k = k
25
 
26
+ def get_relevant_documents(self, query: str) -> List[Document]:
27
  """Get relevant documents, limited to k results."""
28
+ # Use invoke method instead of deprecated get_relevant_documents
29
+ docs = self.ensemble_retriever.invoke(query)
30
  # Limit to k results
31
  return docs[:self.k]
32
 
33
+ async def aget_relevant_documents(self, query: str) -> List[Document]:
34
  """Async version of get_relevant_documents."""
35
+ docs = await self.ensemble_retriever.ainvoke(query)
36
  return docs[:self.k]
37
+
38
+ def invoke(self, input_data, config=None, **kwargs):
39
+ """Compatibility method for invoke interface."""
40
+ if isinstance(input_data, str):
41
+ return self.get_relevant_documents(input_data)
42
+ return self.get_relevant_documents(input_data)
43
 
44
 
45
  class VectorStoreManager: