# Spaces: Running on CPU Upgrade
| import os | |
| from functools import lru_cache | |
| from typing import Literal | |
| from langchain_core.vectorstores import VectorStoreRetriever | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode | |
# Silence gRPC's own log output; this module talks to Qdrant with
# prefer_grpc=True.
# NOTE(review): gRPC may only read GRPC_VERBOSITY when it initializes. Because
# langchain_qdrant is imported above this line, confirm the setting still takes
# effect, or move it ahead of the third-party imports.
os.environ.update({"GRPC_VERBOSITY": "NONE"})
class RetrieversConfig:
    """Build hybrid (dense + sparse) Qdrant retrievers from environment settings.

    The Qdrant connection is read from ``QDRANT_URL`` and ``QDRANT_API_KEY``.
    ``OPENAI_API_KEY`` is also required; it is validated here, and the
    ``OpenAIEmbeddings`` client presumably picks it up from the environment.

    Raises:
        EnvironmentError: on construction, if any variable listed in
            ``REQUIRED_ENV_VARS`` is unset or blank.
    """

    # Environment variables that must be set to a non-blank value.
    REQUIRED_ENV_VARS = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ):
        """Validate the environment and record connection and model settings.

        Args:
            dense_model_name: OpenAI embedding model used for dense vectors.
            sparse_model_name: FastEmbed model used for sparse vectors.

        Raises:
            EnvironmentError: if a required environment variable is missing
                or blank.
        """
        self._validate_environment()
        self.qdrant_url = os.getenv("QDRANT_URL")
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY")
        self.dense_model_name = dense_model_name
        self.sparse_model_name = sparse_model_name
        # Embedding clients are built on first use and then reused by every
        # retriever this config creates (FastEmbedSparse presumably loads its
        # model on construction, so re-creating it per retriever is wasteful).
        self._dense_embeddings = None
        self._sparse_embeddings = None

    @classmethod
    def _validate_environment(cls):
        """Ensure every name in ``cls.REQUIRED_ENV_VARS`` is set and non-blank.

        A classmethod, so subclasses can override ``REQUIRED_ENV_VARS``. It
        can still be called as ``self._validate_environment()`` or
        ``RetrieversConfig._validate_environment()``.

        Raises:
            EnvironmentError: naming every missing or whitespace-only variable.
        """
        missing_vars = [
            var for var in cls.REQUIRED_ENV_VARS if not os.getenv(var, "").strip()
        ]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )

    def dense_embeddings(self):
        """Return the (cached) OpenAI dense-embedding client for this config."""
        if self._dense_embeddings is None:
            self._dense_embeddings = OpenAIEmbeddings(model=self.dense_model_name)
        return self._dense_embeddings

    def sparse_embeddings(self):
        """Return the (cached) FastEmbed sparse-embedding client for this config."""
        if self._sparse_embeddings is None:
            self._sparse_embeddings = FastEmbedSparse(model_name=self.sparse_model_name)
        return self._sparse_embeddings

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> VectorStoreRetriever:
        """Return a hybrid-search retriever over an existing Qdrant collection.

        Args:
            collection_name: Name of the existing Qdrant collection.
            dense_vector_name: Name of the collection's dense vector field.
            sparse_vector_name: Name of the collection's sparse vector field.
            k: Forwarded as ``search_kwargs["k"]``, the number of documents
                requested per query.

        Returns:
            A ``VectorStoreRetriever`` using ``RetrievalMode.HYBRID``.
        """
        qdrantdb = QdrantVectorStore.from_existing_collection(
            # The store needs embedding *objects*; passing the factory methods
            # uncalled would hand it bound methods instead of clients.
            embedding=self.dense_embeddings(),
            sparse_embedding=self.sparse_embeddings(),
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
            prefer_grpc=True,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.HYBRID,
            vector_name=dense_vector_name,
            sparse_vector_name=sparse_vector_name,
        )
        return qdrantdb.as_retriever(search_kwargs={"k": k})

    def get_practitioners_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a hybrid retriever over the ``practitioners_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db",
            dense_vector_name="practitioners_dense_vectors",
            sparse_vector_name="practitioners_sparse_vectors",
            k=k,
        )

    def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a hybrid retriever over the ``docs_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db",
            dense_vector_name="docs_dense_vectors",
            sparse_vector_name="docs_sparse_vectors",
            k=k,
        )