adrienbrdne commited on
Commit
dbe6919
·
verified ·
1 Parent(s): 79c56bc

Update ki_gen/data_retriever.py

Browse files
Files changed (1) hide show
  1. ki_gen/data_retriever.py +44 -36
ki_gen/data_retriever.py CHANGED
@@ -7,9 +7,9 @@ from random import shuffle, sample
7
  from langgraph.checkpoint.sqlite import SqliteSaver
8
 
9
  # Remove ChatGroq import
10
- # from langchain_groq import ChatGroq
11
  # Add ChatGoogleGenerativeAI import
12
- from langchain_google_genai import ChatGoogleGenerativeAI
13
  import os # Add os import
14
 
15
  from langchain_openai import ChatOpenAI
@@ -21,21 +21,20 @@ from langchain_core.prompts import ChatPromptTemplate
21
  from langchain_core.pydantic_v1 import Field
22
  from pydantic import BaseModel
23
 
24
- from neo4j import GraphDatabase
25
 
26
  from langgraph.graph import StateGraph
27
 
28
  from llmlingua import PromptCompressor
29
 
30
  from ki_gen.prompts import (
31
- CYPHER_GENERATION_PROMPT,
32
  CONCEPT_SELECTION_PROMPT,
33
  BINARY_GRADER_PROMPT,
34
  SCORE_GRADER_PROMPT,
35
  RELEVANT_CONCEPTS_PROMPT,
36
  )
37
  # Import get_model which now handles Gemini
38
- from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
39
 
40
 
41
  # ... (extract_cypher remains the same)
@@ -99,7 +98,7 @@ def get_concepts(graph: Neo4jGraph):
99
  def get_related_concepts(graph: Neo4jGraph, question: str):
100
  concepts = get_concepts(graph)
101
  # Use get_model
102
- llm = get_model()
103
  print(f"this is the llm variable : {llm}")
104
  def parse_answer(llm_answer : str):
105
  try:
@@ -113,7 +112,7 @@ def get_related_concepts(graph: Neo4jGraph, question: str):
113
 
114
  print(f"This is the question of the user : {question}")
115
  print(f"This is the concepts of the user : {concepts}")
116
-
117
  # Remove specific Groq error handling block
118
  try:
119
  related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
@@ -148,7 +147,7 @@ def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
148
  MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
149
  """
150
  concept_description = graph.query(concept_description_query)[0]['c.description']
151
- concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
152
  return concept_string
153
 
154
  def get_global_concepts(graph: Neo4jGraph):
@@ -167,12 +166,20 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
167
  """
168
  The node where the cypher is generated
169
  """
170
- #graph = config["configurable"].get("graph")
171
- NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
172
- NEO4J_USERNAME = "neo4j"
173
- NEO4J_PASSWORD = os.getenv("neo4j_password")
174
- graph = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
175
- question = state['query']
 
 
 
 
 
 
 
 
176
  related_concepts = get_related_concepts(graph, question)
177
  cyphers = []
178
 
@@ -183,15 +190,18 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
183
  "question": question,
184
  "concepts": related_concepts
185
  })
186
-
187
  # Remove specific Groq error handling block
188
  try:
189
  if config["configurable"].get("cypher_gen_method") == 'guided':
190
  concept_selection_chain = get_concept_selection_chain()
191
  print(f"Concept selection chain is : {concept_selection_chain}")
 
192
  selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
193
  print(f"Selected topic are : {selected_topic}")
194
- cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
 
 
195
  print(f"Cyphers are : {cyphers}")
196
 
197
  except Exception as e:
@@ -205,7 +215,7 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
205
  corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships", [])]
206
  cypher_corrector = CypherQueryCorrector(corrector_schema)
207
  # Apply corrector only if cyphers were generated
208
- if cyphers:
209
  try:
210
  cyphers = [cypher_corrector(cypher) for cypher in cyphers]
211
  except Exception as corr_e:
@@ -214,9 +224,10 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
214
  else:
215
  print("Warning: Cypher validation skipped, graph or schema unavailable.")
216
 
217
- graph.close()
218
  return {"cyphers" : cyphers}
219
 
 
220
  # ... (generate_cypher_from_topic, get_docs remain the same)
221
  def generate_cypher_from_topic(selected_concept: str, plan_step: int):
222
  """
@@ -232,25 +243,21 @@ def generate_cypher_from_topic(selected_concept: str, plan_step: int):
232
  cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
233
  case 2:
234
  cypher_el = "(ki:KeyIssue) RETURN ki.description"
235
- return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
236
 
237
  def get_docs(state:DocRetrieverState, config:ConfigSchema):
238
  """
239
  This node retrieves docs from the graph using the generated cypher
240
  """
241
- #graph = config["configurable"].get("graph")
242
- NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
243
- NEO4J_USERNAME = "neo4j"
244
- NEO4J_PASSWORD = os.getenv("neo4j_password")
245
- graph = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
246
  output = []
247
  if graph is not None and state.get("cyphers"): # Check if cyphers exist
248
  for cypher in state["cyphers"]:
249
  try:
250
  output = graph.query(cypher)
251
  # Assuming the first successful query is sufficient
252
- if output:
253
- break
254
  except Exception as e:
255
  print(f"Failed to retrieve docs with cypher '{cypher}': {e}")
256
  # Continue to try next cypher if one fails
@@ -264,13 +271,13 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
264
  for key in doc:
265
  if isinstance(doc[key], dict):
266
  # If a value is a dict, treat it as a separate document
267
- all_docs.append(doc[key])
268
  else:
269
  unwinded_doc.update({key: doc[key]})
270
  # Add the unwinded parts if any keys were not dictionaries
271
- if unwinded_doc:
272
  all_docs.append(unwinded_doc)
273
-
274
  filtered_docs = []
275
  seen_docs = set() # Use a set for faster duplicate checking based on a unique identifier
276
 
@@ -278,7 +285,7 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
278
  # Create a tuple of items to check for duplicates, assuming dicts are hashable
279
  # If dicts contain unhashable types (like lists), convert them to strings or use a primary key
280
  try:
281
- doc_tuple = tuple(sorted(doc.items()))
282
  if doc_tuple not in seen_docs:
283
  filtered_docs.append(doc)
284
  seen_docs.add(doc_tuple)
@@ -290,7 +297,7 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
290
  filtered_docs.append(doc)
291
  seen_docs.add(doc_str)
292
 
293
- graph.close()
294
  return {"docs": filtered_docs}
295
 
296
 
@@ -385,13 +392,13 @@ def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="gemini-2.0-
385
  # Update default model
386
  def eval_docs(state: DocRetrieverState, config: ConfigSchema):
387
  """
388
- This node performs evaluation of the retrieved docs and
389
  """
390
 
391
  eval_method = config["configurable"].get("eval_method") or "binary"
392
  MAX_DOCS = config["configurable"].get("max_docs") or 15
393
  # Update default model name
394
- eval_model_name = config["configurable"].get("eval_model") or "gemini-2.0-flash"
395
  valid_doc_scores = []
396
 
397
  # Ensure 'docs' exists and is a list
@@ -419,7 +426,7 @@ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
419
 
420
  score = eval_doc(
421
  doc=formatted_doc_str,
422
- query=state["query"],
423
  method=eval_method,
424
  threshold=config["configurable"].get("eval_threshold") or 0.7,
425
  eval_model=eval_model_name # Pass the eval_model name
@@ -431,7 +438,7 @@ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
431
  else:
432
  print(f"Warning: Received non-numeric score ({score}) for doc {doc}, skipping.")
433
 
434
-
435
  if eval_method == 'score':
436
  # Get at most MAX_DOCS items with the highest score if score method was used
437
  valid_docs_sorted = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True) # Sort descending
@@ -454,7 +461,7 @@ def build_data_retriever_graph(memory):
454
  """
455
  Builds the data_retriever graph
456
  """
457
- #with SqliteSaver.from_conn_string(":memory:") as memory :
458
 
459
  graph_builder_doc_retriever = StateGraph(DocRetrieverState)
460
 
@@ -469,6 +476,7 @@ def build_data_retriever_graph(memory):
469
  graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
470
 
471
  graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
 
472
  return graph_doc_retriever
473
 
474
  # Remove Groq specific error handling function
 
7
  from langgraph.checkpoint.sqlite import SqliteSaver
8
 
9
  # Remove ChatGroq import
10
+ # from langchain_groq import ChatGroq
11
  # Add ChatGoogleGenerativeAI import
12
+ from langchain_google_genai import ChatGoogleGenerativeAI
13
  import os # Add os import
14
 
15
  from langchain_openai import ChatOpenAI
 
21
  from langchain_core.pydantic_v1 import Field
22
  from pydantic import BaseModel
23
 
 
24
 
25
  from langgraph.graph import StateGraph
26
 
27
  from llmlingua import PromptCompressor
28
 
29
  from ki_gen.prompts import (
30
+ CYPHER_GENERATION_PROMPT,
31
  CONCEPT_SELECTION_PROMPT,
32
  BINARY_GRADER_PROMPT,
33
  SCORE_GRADER_PROMPT,
34
  RELEVANT_CONCEPTS_PROMPT,
35
  )
36
  # Import get_model which now handles Gemini
37
+ from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
38
 
39
 
40
  # ... (extract_cypher remains the same)
 
98
  def get_related_concepts(graph: Neo4jGraph, question: str):
99
  concepts = get_concepts(graph)
100
  # Use get_model
101
+ llm = get_model()
102
  print(f"this is the llm variable : {llm}")
103
  def parse_answer(llm_answer : str):
104
  try:
 
112
 
113
  print(f"This is the question of the user : {question}")
114
  print(f"This is the concepts of the user : {concepts}")
115
+
116
  # Remove specific Groq error handling block
117
  try:
118
  related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
 
147
  MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
148
  """
149
  concept_description = graph.query(concept_description_query)[0]['c.description']
150
+ concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
151
  return concept_string
152
 
153
  def get_global_concepts(graph: Neo4jGraph):
 
166
  """
167
  The node where the cypher is generated
168
  """
169
+ graph = config["configurable"].get("graph")
170
+
171
+ # --- Correction Applied Here ---
172
+ # Use .get() for safer access to 'query'
173
+ question = state.get('query')
174
+ if not question:
175
+ # Handle the case where query is missing
176
+ print("Error: 'query' key not found in state for generate_cypher node.")
177
+ # Return an empty list or appropriate error state
178
+ # This prevents the KeyError and stops processing for this branch if query is missing
179
+ return {"cyphers": []}
180
+ # --- End of Correction ---
181
+
182
+
183
  related_concepts = get_related_concepts(graph, question)
184
  cyphers = []
185
 
 
190
  "question": question,
191
  "concepts": related_concepts
192
  })
193
+
194
  # Remove specific Groq error handling block
195
  try:
196
  if config["configurable"].get("cypher_gen_method") == 'guided':
197
  concept_selection_chain = get_concept_selection_chain()
198
  print(f"Concept selection chain is : {concept_selection_chain}")
199
+ # Ensure 'current_plan_step' is also safely accessed if needed here, though it's used later
200
  selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
201
  print(f"Selected topic are : {selected_topic}")
202
+ # Safely get 'current_plan_step', defaulting to 0 if not found
203
+ current_plan_step = state.get('current_plan_step', 0)
204
+ cyphers = [generate_cypher_from_topic(selected_topic, current_plan_step)]
205
  print(f"Cyphers are : {cyphers}")
206
 
207
  except Exception as e:
 
215
  corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships", [])]
216
  cypher_corrector = CypherQueryCorrector(corrector_schema)
217
  # Apply corrector only if cyphers were generated
218
+ if cyphers:
219
  try:
220
  cyphers = [cypher_corrector(cypher) for cypher in cyphers]
221
  except Exception as corr_e:
 
224
  else:
225
  print("Warning: Cypher validation skipped, graph or schema unavailable.")
226
 
227
+
228
  return {"cyphers" : cyphers}
229
 
230
+
231
  # ... (generate_cypher_from_topic, get_docs remain the same)
232
  def generate_cypher_from_topic(selected_concept: str, plan_step: int):
233
  """
 
243
  cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
244
  case 2:
245
  cypher_el = "(ki:KeyIssue) RETURN ki.description"
246
+ return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
247
 
248
  def get_docs(state:DocRetrieverState, config:ConfigSchema):
249
  """
250
  This node retrieves docs from the graph using the generated cypher
251
  """
252
+ graph = config["configurable"].get("graph")
 
 
 
 
253
  output = []
254
  if graph is not None and state.get("cyphers"): # Check if cyphers exist
255
  for cypher in state["cyphers"]:
256
  try:
257
  output = graph.query(cypher)
258
  # Assuming the first successful query is sufficient
259
+ if output:
260
+ break
261
  except Exception as e:
262
  print(f"Failed to retrieve docs with cypher '{cypher}': {e}")
263
  # Continue to try next cypher if one fails
 
271
  for key in doc:
272
  if isinstance(doc[key], dict):
273
  # If a value is a dict, treat it as a separate document
274
+ all_docs.append(doc[key])
275
  else:
276
  unwinded_doc.update({key: doc[key]})
277
  # Add the unwinded parts if any keys were not dictionaries
278
+ if unwinded_doc:
279
  all_docs.append(unwinded_doc)
280
+
281
  filtered_docs = []
282
  seen_docs = set() # Use a set for faster duplicate checking based on a unique identifier
283
 
 
285
  # Create a tuple of items to check for duplicates, assuming dicts are hashable
286
  # If dicts contain unhashable types (like lists), convert them to strings or use a primary key
287
  try:
288
+ doc_tuple = tuple(sorted(doc.items()))
289
  if doc_tuple not in seen_docs:
290
  filtered_docs.append(doc)
291
  seen_docs.add(doc_tuple)
 
297
  filtered_docs.append(doc)
298
  seen_docs.add(doc_str)
299
 
300
+
301
  return {"docs": filtered_docs}
302
 
303
 
 
392
  # Update default model
393
  def eval_docs(state: DocRetrieverState, config: ConfigSchema):
394
  """
395
+ This node performs evaluation of the retrieved docs and
396
  """
397
 
398
  eval_method = config["configurable"].get("eval_method") or "binary"
399
  MAX_DOCS = config["configurable"].get("max_docs") or 15
400
  # Update default model name
401
+ eval_model_name = config["configurable"].get("eval_model") or "gemini-2.0-flash"
402
  valid_doc_scores = []
403
 
404
  # Ensure 'docs' exists and is a list
 
426
 
427
  score = eval_doc(
428
  doc=formatted_doc_str,
429
+ query=state["query"], # This line assumes "query" exists in state
430
  method=eval_method,
431
  threshold=config["configurable"].get("eval_threshold") or 0.7,
432
  eval_model=eval_model_name # Pass the eval_model name
 
438
  else:
439
  print(f"Warning: Received non-numeric score ({score}) for doc {doc}, skipping.")
440
 
441
+
442
  if eval_method == 'score':
443
  # Get at most MAX_DOCS items with the highest score if score method was used
444
  valid_docs_sorted = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True) # Sort descending
 
461
  """
462
  Builds the data_retriever graph
463
  """
464
+ #with SqliteSaver.from_conn_string(":memory:") as memory :
465
 
466
  graph_builder_doc_retriever = StateGraph(DocRetrieverState)
467
 
 
476
  graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
477
 
478
  graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
479
+
480
  return graph_doc_retriever
481
 
482
  # Remove Groq specific error handling function