heymenn commited on
Commit
c8b6c15
·
verified ·
1 Parent(s): 4229b61

Update ki_gen/data_retriever.py

Browse files
Files changed (1) hide show
  1. ki_gen/data_retriever.py +12 -4
ki_gen/data_retriever.py CHANGED
@@ -166,7 +166,11 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
166
  """
167
  The node where the cypher is generated
168
  """
169
- graph = config["configurable"].get("graph")
 
 
 
 
170
  question = state['query']
171
  related_concepts = get_related_concepts(graph, question)
172
  cyphers = []
@@ -209,7 +213,7 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
209
  else:
210
  print("Warning: Cypher validation skipped, graph or schema unavailable.")
211
 
212
-
213
  return {"cyphers" : cyphers}
214
 
215
  # ... (generate_cypher_from_topic, get_docs remain the same)
@@ -233,7 +237,11 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
233
  """
234
  This node retrieves docs from the graph using the generated cypher
235
  """
236
- graph = config["configurable"].get("graph")
 
 
 
 
237
  output = []
238
  if graph is not None and state.get("cyphers"): # Check if cyphers exist
239
  for cypher in state["cyphers"]:
@@ -281,7 +289,7 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
281
  filtered_docs.append(doc)
282
  seen_docs.add(doc_str)
283
 
284
-
285
  return {"docs": filtered_docs}
286
 
287
 
 
166
  """
167
  The node where the cypher is generated
168
  """
169
+ #graph = config["configurable"].get("graph")
170
+ NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
171
+ NEO4J_USERNAME = "neo4j"
172
+ NEO4J_PASSWORD = os.getenv("neo4j_password")
173
+ graph = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
174
  question = state['query']
175
  related_concepts = get_related_concepts(graph, question)
176
  cyphers = []
 
213
  else:
214
  print("Warning: Cypher validation skipped, graph or schema unavailable.")
215
 
216
+ graph.close()
217
  return {"cyphers" : cyphers}
218
 
219
  # ... (generate_cypher_from_topic, get_docs remain the same)
 
237
  """
238
  This node retrieves docs from the graph using the generated cypher
239
  """
240
+ #graph = config["configurable"].get("graph")
241
+ NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
242
+ NEO4J_USERNAME = "neo4j"
243
+ NEO4J_PASSWORD = os.getenv("neo4j_password")
244
+ graph = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
245
  output = []
246
  if graph is not None and state.get("cyphers"): # Check if cyphers exist
247
  for cypher in state["cyphers"]:
 
289
  filtered_docs.append(doc)
290
  seen_docs.add(doc_str)
291
 
292
+ graph.close()
293
  return {"docs": filtered_docs}
294
 
295