Spaces:
Sleeping
Sleeping
Update ki_gen/data_retriever.py
Browse files- ki_gen/data_retriever.py +12 -4
ki_gen/data_retriever.py
CHANGED
@@ -166,7 +166,11 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
166 |
"""
|
167 |
The node where the cypher is generated
|
168 |
"""
|
169 |
-
graph = config["configurable"].get("graph")
|
|
|
|
|
|
|
|
|
170 |
question = state['query']
|
171 |
related_concepts = get_related_concepts(graph, question)
|
172 |
cyphers = []
|
@@ -209,7 +213,7 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
209 |
else:
|
210 |
print("Warning: Cypher validation skipped, graph or schema unavailable.")
|
211 |
|
212 |
-
|
213 |
return {"cyphers" : cyphers}
|
214 |
|
215 |
# ... (generate_cypher_from_topic, get_docs remain the same)
|
@@ -233,7 +237,11 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
|
233 |
"""
|
234 |
This node retrieves docs from the graph using the generated cypher
|
235 |
"""
|
236 |
-
graph = config["configurable"].get("graph")
|
|
|
|
|
|
|
|
|
237 |
output = []
|
238 |
if graph is not None and state.get("cyphers"): # Check if cyphers exist
|
239 |
for cypher in state["cyphers"]:
|
@@ -281,7 +289,7 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
|
281 |
filtered_docs.append(doc)
|
282 |
seen_docs.add(doc_str)
|
283 |
|
284 |
-
|
285 |
return {"docs": filtered_docs}
|
286 |
|
287 |
|
|
|
166 |
"""
|
167 |
The node where the cypher is generated
|
168 |
"""
|
169 |
+
#graph = config["configurable"].get("graph")
|
170 |
+
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
|
171 |
+
NEO4J_USERNAME = "neo4j"
|
172 |
+
NEO4J_PASSWORD = os.getenv("neo4j_password")
|
173 |
+
graph = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
174 |
question = state['query']
|
175 |
related_concepts = get_related_concepts(graph, question)
|
176 |
cyphers = []
|
|
|
213 |
else:
|
214 |
print("Warning: Cypher validation skipped, graph or schema unavailable.")
|
215 |
|
216 |
+
graph.close()
|
217 |
return {"cyphers" : cyphers}
|
218 |
|
219 |
# ... (generate_cypher_from_topic, get_docs remain the same)
|
|
|
237 |
"""
|
238 |
This node retrieves docs from the graph using the generated cypher
|
239 |
"""
|
240 |
+
#graph = config["configurable"].get("graph")
|
241 |
+
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
|
242 |
+
NEO4J_USERNAME = "neo4j"
|
243 |
+
NEO4J_PASSWORD = os.getenv("neo4j_password")
|
244 |
+
graph = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
245 |
output = []
|
246 |
if graph is not None and state.get("cyphers"): # Check if cyphers exist
|
247 |
for cypher in state["cyphers"]:
|
|
|
289 |
filtered_docs.append(doc)
|
290 |
seen_docs.add(doc_str)
|
291 |
|
292 |
+
graph.close()
|
293 |
return {"docs": filtered_docs}
|
294 |
|
295 |
|