Spaces:
Running
Running
Refactor LimitedEnsembleRetriever for improved compatibility and functionality
Browse files- Updated the `LimitedEnsembleRetriever` class to remove inheritance from `BaseRetriever`, simplifying its structure.
- Changed the method names to align with the new invoke interface, replacing deprecated methods with `invoke` and `ainvoke`.
- Added a compatibility method to handle both string input and other data types for the `invoke` method, enhancing usability.
- Improved documentation within the class to clarify the purpose and functionality of methods.
- src/rag/vector_store.py +13 -8
src/rag/vector_store.py
CHANGED
@@ -16,25 +16,30 @@ from src.core.logging_config import get_logger
|
|
16 |
logger = get_logger(__name__)
|
17 |
|
18 |
|
19 |
-
class LimitedEnsembleRetriever
|
20 |
-
"""
|
21 |
|
22 |
def __init__(self, ensemble_retriever: EnsembleRetriever, k: int):
|
23 |
-
super().__init__()
|
24 |
self.ensemble_retriever = ensemble_retriever
|
25 |
self.k = k
|
26 |
|
27 |
-
def
|
28 |
"""Get relevant documents, limited to k results."""
|
29 |
-
#
|
30 |
-
docs = self.ensemble_retriever.
|
31 |
# Limit to k results
|
32 |
return docs[:self.k]
|
33 |
|
34 |
-
async def
|
35 |
"""Async version of get_relevant_documents."""
|
36 |
-
docs = await self.ensemble_retriever.
|
37 |
return docs[:self.k]
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
|
40 |
class VectorStoreManager:
|
|
|
16 |
logger = get_logger(__name__)
|
17 |
|
18 |
|
19 |
+
class LimitedEnsembleRetriever:
|
20 |
+
"""Simple wrapper around EnsembleRetriever that limits total results to k."""
|
21 |
|
22 |
def __init__(self, ensemble_retriever: EnsembleRetriever, k: int):
|
|
|
23 |
self.ensemble_retriever = ensemble_retriever
|
24 |
self.k = k
|
25 |
|
26 |
+
def get_relevant_documents(self, query: str) -> List[Document]:
|
27 |
"""Get relevant documents, limited to k results."""
|
28 |
+
# Use invoke method instead of deprecated get_relevant_documents
|
29 |
+
docs = self.ensemble_retriever.invoke(query)
|
30 |
# Limit to k results
|
31 |
return docs[:self.k]
|
32 |
|
33 |
+
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
34 |
"""Async version of get_relevant_documents."""
|
35 |
+
docs = await self.ensemble_retriever.ainvoke(query)
|
36 |
return docs[:self.k]
|
37 |
+
|
38 |
+
def invoke(self, input_data, config=None, **kwargs):
|
39 |
+
"""Compatibility method for invoke interface."""
|
40 |
+
if isinstance(input_data, str):
|
41 |
+
return self.get_relevant_documents(input_data)
|
42 |
+
return self.get_relevant_documents(input_data)
|
43 |
|
44 |
|
45 |
class VectorStoreManager:
|