shamik commited on
Commit
a8c86eb
·
unverified ·
1 Parent(s): b5c18b6

fix: code fix.

Browse files
src/agent_hackathon/create_vector_db.py CHANGED
@@ -138,12 +138,12 @@ class VectorDBCreator:
138
  logger.info("Pipeline finished.")
139
 
140
 
141
- if __name__ == "__main__":
142
- logger.info("Script started.")
143
- # Optionally load environment variables if needed
144
- _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=True))
145
- creator = VectorDBCreator(
146
- data_path=f"{PROJECT_ROOT_DIR}/data/cs_data_arxiv.json", db_uri="arxiv_docs.db"
147
- )
148
- creator.run()
149
- logger.info("Script finished.")
 
138
  logger.info("Pipeline finished.")
139
 
140
 
141
+ # if __name__ == "__main__":
142
+ # logger.info("Script started.")
143
+ # # Optionally load environment variables if needed
144
+ # _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=True))
145
+ # creator = VectorDBCreator(
146
+ # data_path=f"{PROJECT_ROOT_DIR}/data/cs_data_arxiv.json", db_uri="arxiv_docs.db"
147
+ # )
148
+ # creator.run()
149
+ # logger.info("Script finished.")
src/agent_hackathon/generate_arxiv_responses.py CHANGED
@@ -21,6 +21,7 @@ class ArxivResponseGenerator:
21
  """Initializes the ArxivResponseGenerator."""
22
  self.vector_store_path = vector_store_path
23
  self.client = self._initialise_client()
 
24
  logger.info("ArxivResponseGenerator initialized.")
25
 
26
  def _initialise_retriever(self) -> Any:
@@ -40,7 +41,7 @@ class ArxivResponseGenerator:
40
  )
41
  retriever = retriever_class.build_retriever_engine()
42
  logger.info("Retriever engine initialized.")
43
- return retriever, retriever_class
44
 
45
  def _initialise_client(self) -> InferenceClient:
46
  """
@@ -68,11 +69,15 @@ class ArxivResponseGenerator:
68
  str: Formatted response from the LLM.
69
  """
70
  logger.info(f"Retrieving arXiv papers for query: {query}")
71
- retriever, retriever_class = self._initialise_retriever()
72
- retrieved_content = json.dumps(
73
- obj=[(i.get_content(), i.metadata) for i in retriever.retrieve(query)]
74
- )
75
- logger.info("Retrieved content from vector DB.")
 
 
 
 
76
  completion = self.client.chat.completions.create(
77
  model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
78
  temperature=0.1,
@@ -89,17 +94,15 @@ class ArxivResponseGenerator:
89
  ],
90
  )
91
  logger.info("Received completion from LLM.")
92
- retriever_class.vector_store.client.close()
93
- logger.info("Closed vector store client.")
94
  return completion.choices[0].message.content
95
 
96
 
97
- if __name__ == "__main__":
98
- logger.info("Script started.")
99
- generator = ArxivResponseGenerator(
100
- vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
101
- )
102
- query = "deep learning for NLP" # Example query, replace as needed
103
- result = generator.retrieve_arxiv_papers(query=query)
104
- print(result)
105
- logger.info("Script finished.")
 
21
  """Initializes the ArxivResponseGenerator."""
22
  self.vector_store_path = vector_store_path
23
  self.client = self._initialise_client()
24
+ self.retriever = self._initialise_retriever()
25
  logger.info("ArxivResponseGenerator initialized.")
26
 
27
  def _initialise_retriever(self) -> Any:
 
41
  )
42
  retriever = retriever_class.build_retriever_engine()
43
  logger.info("Retriever engine initialized.")
44
+ return retriever
45
 
46
  def _initialise_client(self) -> InferenceClient:
47
  """
 
69
  str: Formatted response from the LLM.
70
  """
71
  logger.info(f"Retrieving arXiv papers for query: {query}")
72
+
73
+ try:
74
+ retrieved_content = json.dumps(
75
+ obj=[(i.get_content(), i.metadata) for i in self.retriever.retrieve(query)]
76
+ )
77
+ logger.info("Retrieved content from vector DB.")
78
+ except Exception as err:
79
+ logger.error(f"Error retrieving from vector DB: {err}")
80
+ raise
81
  completion = self.client.chat.completions.create(
82
  model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
83
  temperature=0.1,
 
94
  ],
95
  )
96
  logger.info("Received completion from LLM.")
 
 
97
  return completion.choices[0].message.content
98
 
99
 
100
+ # if __name__ == "__main__":
101
+ # logger.info("Script started.")
102
+ # generator = ArxivResponseGenerator(
103
+ # vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
104
+ # )
105
+ # query = "deep learning for NLP" # Example query, replace as needed
106
+ # result = generator.retrieve_arxiv_papers(query=query)
107
+ # print(result)
108
+ # logger.info("Script finished.")
src/agent_hackathon/query_vector_db.py CHANGED
@@ -5,7 +5,6 @@ from dotenv import find_dotenv, load_dotenv
5
  from huggingface_hub import login
6
  from llama_index.core import VectorStoreIndex
7
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
8
- from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
9
  from llama_index.vector_stores.milvus import MilvusVectorStore
10
 
11
  from src.agent_hackathon.consts import PROJECT_ROOT_DIR
@@ -23,7 +22,6 @@ class RetrieverEngineBuilder:
23
  self,
24
  hf_token_env: str = "HF_TOKEN",
25
  embedding_model: str = "Qwen/Qwen3-Embedding-0.6B",
26
- llm_model: str = "meta-llama/Llama-4-Scout-17B-16E-Instruct",
27
  vector_store: MilvusVectorStore = None,
28
  device: str = "cpu",
29
  ) -> None:
@@ -33,27 +31,21 @@ class RetrieverEngineBuilder:
33
  Args:
34
  hf_token_env: Environment variable name for HuggingFace token.
35
  embedding_model: Name of the embedding model.
36
- llm_model: Name of the LLM model.
37
  vector_store: An instance of MilvusVectorStore.
38
  device: Device to run the embedding model on.
39
  """
40
  self.hf_token_env = hf_token_env
41
  self.embedding_model = embedding_model
42
- self.llm_model = llm_model
43
  self.vector_store = vector_store
44
  self.device = device
45
 
46
  logger.info("Initializing RetrieverEngineBuilder.")
47
- self._login_huggingface()
48
- self._load_env()
49
 
50
  self.embed_model = HuggingFaceEmbedding(
51
  model_name=self.embedding_model, device=self.device
52
  )
53
- self.llm = HuggingFaceInferenceAPI(
54
- model=self.llm_model,
55
- provider="auto",
56
- )
57
  logger.info("RetrieverEngineBuilder initialized.")
58
 
59
  def _login_huggingface(self) -> None:
@@ -65,7 +57,7 @@ class RetrieverEngineBuilder:
65
  def _load_env(self) -> None:
66
  """Load environment variables from .env file."""
67
  logger.info("Loading environment variables.")
68
- _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=True))
69
  logger.info("Environment variables loaded.")
70
 
71
  def build_retriever_engine(self) -> Any:
 
5
  from huggingface_hub import login
6
  from llama_index.core import VectorStoreIndex
7
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 
8
  from llama_index.vector_stores.milvus import MilvusVectorStore
9
 
10
  from src.agent_hackathon.consts import PROJECT_ROOT_DIR
 
22
  self,
23
  hf_token_env: str = "HF_TOKEN",
24
  embedding_model: str = "Qwen/Qwen3-Embedding-0.6B",
 
25
  vector_store: MilvusVectorStore = None,
26
  device: str = "cpu",
27
  ) -> None:
 
31
  Args:
32
  hf_token_env: Environment variable name for HuggingFace token.
33
  embedding_model: Name of the embedding model.
 
34
  vector_store: An instance of MilvusVectorStore.
35
  device: Device to run the embedding model on.
36
  """
37
  self.hf_token_env = hf_token_env
38
  self.embedding_model = embedding_model
 
39
  self.vector_store = vector_store
40
  self.device = device
41
 
42
  logger.info("Initializing RetrieverEngineBuilder.")
43
+ # self._login_huggingface()
44
+ # self._load_env()
45
 
46
  self.embed_model = HuggingFaceEmbedding(
47
  model_name=self.embedding_model, device=self.device
48
  )
 
 
 
 
49
  logger.info("RetrieverEngineBuilder initialized.")
50
 
51
  def _login_huggingface(self) -> None:
 
57
  def _load_env(self) -> None:
58
  """Load environment variables from .env file."""
59
  logger.info("Loading environment variables.")
60
+ _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=False))
61
  logger.info("Environment variables loaded.")
62
 
63
  def build_retriever_engine(self) -> Any: