Spaces:
Sleeping
Sleeping
Update ki_gen/data_retriever.py
Browse files- ki_gen/data_retriever.py +44 -36
ki_gen/data_retriever.py
CHANGED
@@ -7,9 +7,9 @@ from random import shuffle, sample
|
|
7 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
8 |
|
9 |
# Remove ChatGroq import
|
10 |
-
# from langchain_groq import ChatGroq
|
11 |
# Add ChatGoogleGenerativeAI import
|
12 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
13 |
import os # Add os import
|
14 |
|
15 |
from langchain_openai import ChatOpenAI
|
@@ -21,21 +21,20 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
21 |
from langchain_core.pydantic_v1 import Field
|
22 |
from pydantic import BaseModel
|
23 |
|
24 |
-
from neo4j import GraphDatabase
|
25 |
|
26 |
from langgraph.graph import StateGraph
|
27 |
|
28 |
from llmlingua import PromptCompressor
|
29 |
|
30 |
from ki_gen.prompts import (
|
31 |
-
CYPHER_GENERATION_PROMPT,
|
32 |
CONCEPT_SELECTION_PROMPT,
|
33 |
BINARY_GRADER_PROMPT,
|
34 |
SCORE_GRADER_PROMPT,
|
35 |
RELEVANT_CONCEPTS_PROMPT,
|
36 |
)
|
37 |
# Import get_model which now handles Gemini
|
38 |
-
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
|
39 |
|
40 |
|
41 |
# ... (extract_cypher remains the same)
|
@@ -99,7 +98,7 @@ def get_concepts(graph: Neo4jGraph):
|
|
99 |
def get_related_concepts(graph: Neo4jGraph, question: str):
|
100 |
concepts = get_concepts(graph)
|
101 |
# Use get_model
|
102 |
-
llm = get_model()
|
103 |
print(f"this is the llm variable : {llm}")
|
104 |
def parse_answer(llm_answer : str):
|
105 |
try:
|
@@ -113,7 +112,7 @@ def get_related_concepts(graph: Neo4jGraph, question: str):
|
|
113 |
|
114 |
print(f"This is the question of the user : {question}")
|
115 |
print(f"This is the concepts of the user : {concepts}")
|
116 |
-
|
117 |
# Remove specific Groq error handling block
|
118 |
try:
|
119 |
related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
|
@@ -148,7 +147,7 @@ def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
|
|
148 |
MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
|
149 |
"""
|
150 |
concept_description = graph.query(concept_description_query)[0]['c.description']
|
151 |
-
concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
|
152 |
return concept_string
|
153 |
|
154 |
def get_global_concepts(graph: Neo4jGraph):
|
@@ -167,12 +166,20 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
167 |
"""
|
168 |
The node where the cypher is generated
|
169 |
"""
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
related_concepts = get_related_concepts(graph, question)
|
177 |
cyphers = []
|
178 |
|
@@ -183,15 +190,18 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
183 |
"question": question,
|
184 |
"concepts": related_concepts
|
185 |
})
|
186 |
-
|
187 |
# Remove specific Groq error handling block
|
188 |
try:
|
189 |
if config["configurable"].get("cypher_gen_method") == 'guided':
|
190 |
concept_selection_chain = get_concept_selection_chain()
|
191 |
print(f"Concept selection chain is : {concept_selection_chain}")
|
|
|
192 |
selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
|
193 |
print(f"Selected topic are : {selected_topic}")
|
194 |
-
|
|
|
|
|
195 |
print(f"Cyphers are : {cyphers}")
|
196 |
|
197 |
except Exception as e:
|
@@ -205,7 +215,7 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
205 |
corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships", [])]
|
206 |
cypher_corrector = CypherQueryCorrector(corrector_schema)
|
207 |
# Apply corrector only if cyphers were generated
|
208 |
-
if cyphers:
|
209 |
try:
|
210 |
cyphers = [cypher_corrector(cypher) for cypher in cyphers]
|
211 |
except Exception as corr_e:
|
@@ -214,9 +224,10 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
|
|
214 |
else:
|
215 |
print("Warning: Cypher validation skipped, graph or schema unavailable.")
|
216 |
|
217 |
-
|
218 |
return {"cyphers" : cyphers}
|
219 |
|
|
|
220 |
# ... (generate_cypher_from_topic, get_docs remain the same)
|
221 |
def generate_cypher_from_topic(selected_concept: str, plan_step: int):
|
222 |
"""
|
@@ -232,25 +243,21 @@ def generate_cypher_from_topic(selected_concept: str, plan_step: int):
|
|
232 |
cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
|
233 |
case 2:
|
234 |
cypher_el = "(ki:KeyIssue) RETURN ki.description"
|
235 |
-
return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
|
236 |
|
237 |
def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
238 |
"""
|
239 |
This node retrieves docs from the graph using the generated cypher
|
240 |
"""
|
241 |
-
|
242 |
-
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
|
243 |
-
NEO4J_USERNAME = "neo4j"
|
244 |
-
NEO4J_PASSWORD = os.getenv("neo4j_password")
|
245 |
-
graph = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
246 |
output = []
|
247 |
if graph is not None and state.get("cyphers"): # Check if cyphers exist
|
248 |
for cypher in state["cyphers"]:
|
249 |
try:
|
250 |
output = graph.query(cypher)
|
251 |
# Assuming the first successful query is sufficient
|
252 |
-
if output:
|
253 |
-
break
|
254 |
except Exception as e:
|
255 |
print(f"Failed to retrieve docs with cypher '{cypher}': {e}")
|
256 |
# Continue to try next cypher if one fails
|
@@ -264,13 +271,13 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
|
264 |
for key in doc:
|
265 |
if isinstance(doc[key], dict):
|
266 |
# If a value is a dict, treat it as a separate document
|
267 |
-
all_docs.append(doc[key])
|
268 |
else:
|
269 |
unwinded_doc.update({key: doc[key]})
|
270 |
# Add the unwinded parts if any keys were not dictionaries
|
271 |
-
if unwinded_doc:
|
272 |
all_docs.append(unwinded_doc)
|
273 |
-
|
274 |
filtered_docs = []
|
275 |
seen_docs = set() # Use a set for faster duplicate checking based on a unique identifier
|
276 |
|
@@ -278,7 +285,7 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
|
278 |
# Create a tuple of items to check for duplicates, assuming dicts are hashable
|
279 |
# If dicts contain unhashable types (like lists), convert them to strings or use a primary key
|
280 |
try:
|
281 |
-
doc_tuple = tuple(sorted(doc.items()))
|
282 |
if doc_tuple not in seen_docs:
|
283 |
filtered_docs.append(doc)
|
284 |
seen_docs.add(doc_tuple)
|
@@ -290,7 +297,7 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
|
290 |
filtered_docs.append(doc)
|
291 |
seen_docs.add(doc_str)
|
292 |
|
293 |
-
|
294 |
return {"docs": filtered_docs}
|
295 |
|
296 |
|
@@ -385,13 +392,13 @@ def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="gemini-2.0-
|
|
385 |
# Update default model
|
386 |
def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
387 |
"""
|
388 |
-
This node performs evaluation of the retrieved docs and
|
389 |
"""
|
390 |
|
391 |
eval_method = config["configurable"].get("eval_method") or "binary"
|
392 |
MAX_DOCS = config["configurable"].get("max_docs") or 15
|
393 |
# Update default model name
|
394 |
-
eval_model_name = config["configurable"].get("eval_model") or "gemini-2.0-flash"
|
395 |
valid_doc_scores = []
|
396 |
|
397 |
# Ensure 'docs' exists and is a list
|
@@ -419,7 +426,7 @@ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
|
419 |
|
420 |
score = eval_doc(
|
421 |
doc=formatted_doc_str,
|
422 |
-
query=state["query"],
|
423 |
method=eval_method,
|
424 |
threshold=config["configurable"].get("eval_threshold") or 0.7,
|
425 |
eval_model=eval_model_name # Pass the eval_model name
|
@@ -431,7 +438,7 @@ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
|
431 |
else:
|
432 |
print(f"Warning: Received non-numeric score ({score}) for doc {doc}, skipping.")
|
433 |
|
434 |
-
|
435 |
if eval_method == 'score':
|
436 |
# Get at most MAX_DOCS items with the highest score if score method was used
|
437 |
valid_docs_sorted = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True) # Sort descending
|
@@ -454,7 +461,7 @@ def build_data_retriever_graph(memory):
|
|
454 |
"""
|
455 |
Builds the data_retriever graph
|
456 |
"""
|
457 |
-
#with SqliteSaver.from_conn_string(":memory:") as memory :
|
458 |
|
459 |
graph_builder_doc_retriever = StateGraph(DocRetrieverState)
|
460 |
|
@@ -469,6 +476,7 @@ def build_data_retriever_graph(memory):
|
|
469 |
graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
|
470 |
|
471 |
graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
|
|
|
472 |
return graph_doc_retriever
|
473 |
|
474 |
# Remove Groq specific error handling function
|
|
|
7 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
8 |
|
9 |
# Remove ChatGroq import
|
10 |
+
# from langchain_groq import ChatGroq
|
11 |
# Add ChatGoogleGenerativeAI import
|
12 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
13 |
import os # Add os import
|
14 |
|
15 |
from langchain_openai import ChatOpenAI
|
|
|
21 |
from langchain_core.pydantic_v1 import Field
|
22 |
from pydantic import BaseModel
|
23 |
|
|
|
24 |
|
25 |
from langgraph.graph import StateGraph
|
26 |
|
27 |
from llmlingua import PromptCompressor
|
28 |
|
29 |
from ki_gen.prompts import (
|
30 |
+
CYPHER_GENERATION_PROMPT,
|
31 |
CONCEPT_SELECTION_PROMPT,
|
32 |
BINARY_GRADER_PROMPT,
|
33 |
SCORE_GRADER_PROMPT,
|
34 |
RELEVANT_CONCEPTS_PROMPT,
|
35 |
)
|
36 |
# Import get_model which now handles Gemini
|
37 |
+
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
|
38 |
|
39 |
|
40 |
# ... (extract_cypher remains the same)
|
|
|
98 |
def get_related_concepts(graph: Neo4jGraph, question: str):
|
99 |
concepts = get_concepts(graph)
|
100 |
# Use get_model
|
101 |
+
llm = get_model()
|
102 |
print(f"this is the llm variable : {llm}")
|
103 |
def parse_answer(llm_answer : str):
|
104 |
try:
|
|
|
112 |
|
113 |
print(f"This is the question of the user : {question}")
|
114 |
print(f"This is the concepts of the user : {concepts}")
|
115 |
+
|
116 |
# Remove specific Groq error handling block
|
117 |
try:
|
118 |
related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
|
|
|
147 |
MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
|
148 |
"""
|
149 |
concept_description = graph.query(concept_description_query)[0]['c.description']
|
150 |
+
concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
|
151 |
return concept_string
|
152 |
|
153 |
def get_global_concepts(graph: Neo4jGraph):
|
|
|
166 |
"""
|
167 |
The node where the cypher is generated
|
168 |
"""
|
169 |
+
graph = config["configurable"].get("graph")
|
170 |
+
|
171 |
+
# --- Correction Applied Here ---
|
172 |
+
# Use .get() for safer access to 'query'
|
173 |
+
question = state.get('query')
|
174 |
+
if not question:
|
175 |
+
# Handle the case where query is missing
|
176 |
+
print("Error: 'query' key not found in state for generate_cypher node.")
|
177 |
+
# Return an empty list or appropriate error state
|
178 |
+
# This prevents the KeyError and stops processing for this branch if query is missing
|
179 |
+
return {"cyphers": []}
|
180 |
+
# --- End of Correction ---
|
181 |
+
|
182 |
+
|
183 |
related_concepts = get_related_concepts(graph, question)
|
184 |
cyphers = []
|
185 |
|
|
|
190 |
"question": question,
|
191 |
"concepts": related_concepts
|
192 |
})
|
193 |
+
|
194 |
# Remove specific Groq error handling block
|
195 |
try:
|
196 |
if config["configurable"].get("cypher_gen_method") == 'guided':
|
197 |
concept_selection_chain = get_concept_selection_chain()
|
198 |
print(f"Concept selection chain is : {concept_selection_chain}")
|
199 |
+
# Ensure 'current_plan_step' is also safely accessed if needed here, though it's used later
|
200 |
selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
|
201 |
print(f"Selected topic are : {selected_topic}")
|
202 |
+
# Safely get 'current_plan_step', defaulting to 0 if not found
|
203 |
+
current_plan_step = state.get('current_plan_step', 0)
|
204 |
+
cyphers = [generate_cypher_from_topic(selected_topic, current_plan_step)]
|
205 |
print(f"Cyphers are : {cyphers}")
|
206 |
|
207 |
except Exception as e:
|
|
|
215 |
corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships", [])]
|
216 |
cypher_corrector = CypherQueryCorrector(corrector_schema)
|
217 |
# Apply corrector only if cyphers were generated
|
218 |
+
if cyphers:
|
219 |
try:
|
220 |
cyphers = [cypher_corrector(cypher) for cypher in cyphers]
|
221 |
except Exception as corr_e:
|
|
|
224 |
else:
|
225 |
print("Warning: Cypher validation skipped, graph or schema unavailable.")
|
226 |
|
227 |
+
|
228 |
return {"cyphers" : cyphers}
|
229 |
|
230 |
+
|
231 |
# ... (generate_cypher_from_topic, get_docs remain the same)
|
232 |
def generate_cypher_from_topic(selected_concept: str, plan_step: int):
|
233 |
"""
|
|
|
243 |
cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
|
244 |
case 2:
|
245 |
cypher_el = "(ki:KeyIssue) RETURN ki.description"
|
246 |
+
return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
|
247 |
|
248 |
def get_docs(state:DocRetrieverState, config:ConfigSchema):
|
249 |
"""
|
250 |
This node retrieves docs from the graph using the generated cypher
|
251 |
"""
|
252 |
+
graph = config["configurable"].get("graph")
|
|
|
|
|
|
|
|
|
253 |
output = []
|
254 |
if graph is not None and state.get("cyphers"): # Check if cyphers exist
|
255 |
for cypher in state["cyphers"]:
|
256 |
try:
|
257 |
output = graph.query(cypher)
|
258 |
# Assuming the first successful query is sufficient
|
259 |
+
if output:
|
260 |
+
break
|
261 |
except Exception as e:
|
262 |
print(f"Failed to retrieve docs with cypher '{cypher}': {e}")
|
263 |
# Continue to try next cypher if one fails
|
|
|
271 |
for key in doc:
|
272 |
if isinstance(doc[key], dict):
|
273 |
# If a value is a dict, treat it as a separate document
|
274 |
+
all_docs.append(doc[key])
|
275 |
else:
|
276 |
unwinded_doc.update({key: doc[key]})
|
277 |
# Add the unwinded parts if any keys were not dictionaries
|
278 |
+
if unwinded_doc:
|
279 |
all_docs.append(unwinded_doc)
|
280 |
+
|
281 |
filtered_docs = []
|
282 |
seen_docs = set() # Use a set for faster duplicate checking based on a unique identifier
|
283 |
|
|
|
285 |
# Create a tuple of items to check for duplicates, assuming dicts are hashable
|
286 |
# If dicts contain unhashable types (like lists), convert them to strings or use a primary key
|
287 |
try:
|
288 |
+
doc_tuple = tuple(sorted(doc.items()))
|
289 |
if doc_tuple not in seen_docs:
|
290 |
filtered_docs.append(doc)
|
291 |
seen_docs.add(doc_tuple)
|
|
|
297 |
filtered_docs.append(doc)
|
298 |
seen_docs.add(doc_str)
|
299 |
|
300 |
+
|
301 |
return {"docs": filtered_docs}
|
302 |
|
303 |
|
|
|
392 |
# Update default model
|
393 |
def eval_docs(state: DocRetrieverState, config: ConfigSchema):
|
394 |
"""
|
395 |
+
This node performs evaluation of the retrieved docs and
|
396 |
"""
|
397 |
|
398 |
eval_method = config["configurable"].get("eval_method") or "binary"
|
399 |
MAX_DOCS = config["configurable"].get("max_docs") or 15
|
400 |
# Update default model name
|
401 |
+
eval_model_name = config["configurable"].get("eval_model") or "gemini-2.0-flash"
|
402 |
valid_doc_scores = []
|
403 |
|
404 |
# Ensure 'docs' exists and is a list
|
|
|
426 |
|
427 |
score = eval_doc(
|
428 |
doc=formatted_doc_str,
|
429 |
+
query=state["query"], # This line assumes "query" exists in state
|
430 |
method=eval_method,
|
431 |
threshold=config["configurable"].get("eval_threshold") or 0.7,
|
432 |
eval_model=eval_model_name # Pass the eval_model name
|
|
|
438 |
else:
|
439 |
print(f"Warning: Received non-numeric score ({score}) for doc {doc}, skipping.")
|
440 |
|
441 |
+
|
442 |
if eval_method == 'score':
|
443 |
# Get at most MAX_DOCS items with the highest score if score method was used
|
444 |
valid_docs_sorted = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True) # Sort descending
|
|
|
461 |
"""
|
462 |
Builds the data_retriever graph
|
463 |
"""
|
464 |
+
#with SqliteSaver.from_conn_string(":memory:") as memory :
|
465 |
|
466 |
graph_builder_doc_retriever = StateGraph(DocRetrieverState)
|
467 |
|
|
|
476 |
graph_builder_doc_retriever.add_edge("eval_docs", "__end__")
|
477 |
|
478 |
graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
|
479 |
+
|
480 |
return graph_doc_retriever
|
481 |
|
482 |
# Remove Groq specific error handling function
|