Update ki_gen/data_retriever.py

Replaces the ChatGroq-backed chains with Gemini: model construction now goes through the get_model helper from ki_gen.utils (default "gemini-2.0-flash", via langchain_google_genai), the Groq rate-limit workaround error_concept_groq is removed, and the retrieval, deduplication and grading paths gain defensive error handling.

ki_gen/data_retriever.py  (CHANGED, +182 −121)
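Every factory in the new version (get_cypher_gen_chain, get_concept_selection_chain, get_binary_grader, get_score_grader, get_related_concepts) obtains its LLM from get_model, imported from ki_gen.utils, with "gemini-2.0-flash" as the default model name. That helper is not part of this diff; the sketch below only illustrates the kind of dispatch it is assumed to perform (the signature, the environment-variable check and the OpenAI fallback are assumptions, not the repository's actual code):

    # Hypothetical sketch of ki_gen/utils.py::get_model -- assumed behaviour, not the actual implementation.
    import os

    from langchain_google_genai import ChatGoogleGenerativeAI
    from langchain_openai import ChatOpenAI

    def get_model(model: str = "gemini-2.0-flash", temperature: float = 0):
        """Map a model name to a LangChain chat model instance."""
        if model.startswith("gemini"):
            # ChatGoogleGenerativeAI reads GOOGLE_API_KEY from the environment,
            # which is presumably why `import os` was added alongside it.
            if "GOOGLE_API_KEY" not in os.environ:
                raise RuntimeError("GOOGLE_API_KEY must be set to use Gemini models")
            return ChatGoogleGenerativeAI(model=model, temperature=temperature)
        # Anything else falls back to an OpenAI-compatible chat model.
        return ChatOpenAI(model=model, temperature=temperature)

With such a helper in place, the per-function "if model == 'openai': ... else: ChatGroq(...)" branches removed by this diff collapse to a single get_model(model) call.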
@@ -6,7 +6,12 @@ import time
 from random import shuffle, sample
 from langgraph.checkpoint.sqlite import SqliteSaver

-from langchain_groq import ChatGroq
+# Remove ChatGroq import
+# from langchain_groq import ChatGroq
+# Add ChatGoogleGenerativeAI import
+from langchain_google_genai import ChatGoogleGenerativeAI
+import os  # Add os import
+
 from langchain_openai import ChatOpenAI
 from langchain_core.messages import HumanMessage
 from langchain_community.graphs import Neo4jGraph
@@ -15,7 +20,7 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.pydantic_v1 import Field
 from pydantic import BaseModel
-
+

 from langgraph.graph import StateGraph

@@ -28,11 +33,11 @@ from ki_gen.prompts import (
     SCORE_GRADER_PROMPT,
     RELEVANT_CONCEPTS_PROMPT,
 )
-
-
-
+# Import get_model which now handles Gemini
+from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc


+# ... (extract_cypher remains the same)
 def extract_cypher(text: str) -> str:
     """Extract Cypher code from a text.

@@ -55,33 +60,28 @@ def extract_cypher(text: str) -> str:
         text
     ]

-
+# Update default model and use get_model
+def get_cypher_gen_chain(model: str = "gemini-2.0-flash"):
     """
     Returns cypher gen chain using specified model for generation
     This is used when the 'auto' cypher generation method has been configured
     """
-
-    if model=="openai":
-        llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
-    else:
-        llm_cypher_gen = ChatGroq(model = "deepseek-r1-distill-llama-70b")
+    llm_cypher_gen = get_model(model)
     cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
     return cypher_gen_chain

-
+# Update default model and use get_model
+def get_concept_selection_chain(model: str = "gemini-2.0-flash"):
     """
     Returns a chain to select the most relevant topic using specified model for generation.
     This is used when the 'guided' cypher generation method has been configured
     """
-
-    if model == "openai":
-        llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
-    else:
-        llm_topic_selection = ChatGroq(model="deepseek-r1-distill-llama-70b")
+    llm_topic_selection = get_model(model)
     print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
     topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
     return topic_selection_chain

+# ... (get_concepts remains the same)
 def get_concepts(graph: Neo4jGraph):
     concept_cypher = "MATCH (c:Concept) return c"
     if isinstance(graph, Neo4jGraph):
@@ -93,37 +93,34 @@ def get_concepts(graph: Neo4jGraph):
     concepts_name = [concept['c']['name'] for concept in concepts]
     return concepts_name

+
+# Update to use get_model, remove Groq error handling
 def get_related_concepts(graph: Neo4jGraph, question: str):
     concepts = get_concepts(graph)
-
+    # Use get_model
+    llm = get_model()
     print(f"this is the llm variable : {llm}")
     def parse_answer(llm_answer : str):
         try:
             print(f"This the llm_answer : {llm_answer}")
+            # Adjust parsing if Gemini output format differs
             return re.split("\n(?:\d)+\.\s", llm_answer.split("Concepts:")[1])[1:]
-        except:
-
+        except Exception as e:
+            print(f"Error parsing LLM concept answer: {e}")
+            return []  # Return empty list on parsing error
     related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer

     print(f"This is the question of the user : {question}")
     print(f"This is the concepts of the user : {concepts}")

-
-    #groq.APIStatusError: Error code: 413 - {'error': {'message': 'Request too large for model `deepseek-r1-distill-llama-70b` in organization `org_01j6xywkndffv96m3wgh81jm49` on tokens per minute
-    # (TPM): Limit 5000, Requested 17099, please reduce your message size and try again. Visit https://console.groq.com/docs/rate-limits for more information.',
-    # 'type': 'tokens', 'code': 'rate_limit_exceeded'}}
-
+    # Remove specific Groq error handling block
     try:
         related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
         print(f"related_concepts_raw : {related_concepts_raw}")
     except Exception as e:
-
-
-
-        print(type(question))
-        error_question = ["user_query", question]
-        related_concepts_raw = error_concept_groq(msg,concepts,related_concepts_chain,error_question)
-        pass
+        # Add generic error handling/logging for Gemini if needed
+        print(f"Error invoking related concepts chain: {e}")
+        related_concepts_raw = []  # Assign empty list on error

     # We clean up the list we received from the LLM in case there were some hallucinations
     related_concepts_cleaned = []
@@ -142,6 +139,7 @@ def get_related_concepts(graph: Neo4jGraph, question: str):
     # TODO : Add concepts found via similarity search
     return related_concepts_cleaned

+# ... (build_concept_string, get_global_concepts remain the same)
 def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
     concept_string = ""
     for concept in concept_list:
@@ -163,6 +161,7 @@ def get_global_concepts(graph: Neo4jGraph):
     concepts_name = [concept['gc']['name'] for concept in concepts]
     return concepts_name

+# Update concept selection error handling
 def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
     """
     The node where the cypher is generated
@@ -180,30 +179,40 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
         "concepts": related_concepts
     })

-
-
+    # Remove specific Groq error handling block
+    try:
         if config["configurable"].get("cypher_gen_method") == 'guided':
             concept_selection_chain = get_concept_selection_chain()
             print(f"Concept selection chain is : {concept_selection_chain}")
             selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
             print(f"Selected topic are : {selected_topic}")
-
-        except Exception as e:
-            error_question = ["question", question]
-            selected_topic = error_concept_groq(e.body["error"]["message"],get_concepts(graph),concept_selection_chain,error_question)
-            pass
-
-    if config["configurable"].get("cypher_gen_method") == 'guided':
             cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
             print(f"Cyphers are : {cyphers}")

+    except Exception as e:
+        # Add generic error handling/logging for Gemini if needed
+        print(f"Error during guided cypher generation: {e}")
+        cyphers = []  # Assign empty list on error
+
     if config["configurable"].get("validate_cypher"):
-
-
-
+        # Ensure graph schema is correctly fetched if needed
+        if graph and hasattr(graph, 'structured_schema'):
+            corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships", [])]
+            cypher_corrector = CypherQueryCorrector(corrector_schema)
+            # Apply corrector only if cyphers were generated
+            if cyphers:
+                try:
+                    cyphers = [cypher_corrector(cypher) for cypher in cyphers]
+                except Exception as corr_e:
+                    print(f"Error during cypher correction: {corr_e}")
+                    # Decide how to handle correction errors, maybe keep original cyphers
+        else:
+            print("Warning: Cypher validation skipped, graph or schema unavailable.")
+

     return {"cyphers" : cyphers}

+# ... (generate_cypher_from_topic, get_docs remain the same)
 def generate_cypher_from_topic(selected_concept: str, plan_step: int):
     """
     Helper function used when the 'guided' cypher generation method has been configured
@@ -226,36 +235,54 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
     """
     graph = config["configurable"].get("graph")
     output = []
-    if graph is not None:
+    if graph is not None and state.get("cyphers"):  # Check if cyphers exist
         for cypher in state["cyphers"]:
             try:
                 output = graph.query(cypher)
-
+                # Assuming the first successful query is sufficient
+                if output:
+                    break
             except Exception as e:
-                print("Failed to retrieve docs : {e}")
+                print(f"Failed to retrieve docs with cypher '{cypher}': {e}")
+                # Continue to try next cypher if one fails

     # Clean up the docs we received as there may be duplicates depending on the cypher query
     all_docs = []
     for doc in output:
         unwinded_doc = {}
-
-
-
-
-
-
+        # Ensure doc is a dictionary before iterating
+        if isinstance(doc, dict):
+            for key in doc:
+                if isinstance(doc[key], dict):
+                    # If a value is a dict, treat it as a separate document
+                    all_docs.append(doc[key])
+                else:
+                    unwinded_doc.update({key: doc[key]})
+            # Add the unwinded parts if any keys were not dictionaries
+            if unwinded_doc:
                 all_docs.append(unwinded_doc)
-

     filtered_docs = []
-    for
-        if doc not in filtered_docs:
-            filtered_docs.append(doc)
-
-    return {"docs": filtered_docs}
+    seen_docs = set()  # Use a set for faster duplicate checking based on a unique identifier

+    for doc in all_docs:
+        # Create a tuple of items to check for duplicates, assuming dicts are hashable
+        # If dicts contain unhashable types (like lists), convert them to strings or use a primary key
+        try:
+            doc_tuple = tuple(sorted(doc.items()))
+            if doc_tuple not in seen_docs:
+                filtered_docs.append(doc)
+                seen_docs.add(doc_tuple)
+        except TypeError:
+            # Handle cases where doc items are not hashable (e.g., contain lists/dicts)
+            # Fallback: convert doc to string for uniqueness check (less reliable)
+            doc_str = str(sorted(doc.items()))
+            if doc_str not in seen_docs:
+                filtered_docs.append(doc)
+                seen_docs.add(doc_str)


+    return {"docs": filtered_docs}


 # Data model
@@ -266,22 +293,25 @@ class GradeDocumentsBinary(BaseModel):
         description="Documents are relevant to the question, 'yes' or 'no'"
     )

-#
-
-
-def get_binary_grader(model="deepseek-r1-distill-llama-70b"):
+# Update default model and use get_model
+def get_binary_grader(model="gemini-2.0-flash"):
     """
     Returns a binary grader to evaluate relevance of documents using specified model for generation
     This is used when the 'binary' evaluation method has been configured
     """
-
-    if model == "gpt-4o":
-        llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
-    else:
-        llm_grader_binary = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
-    structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
-    retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
+    llm_grader_binary = get_model(model)
+    # Check if the model supports structured output, otherwise use standard invocation
+    try:
+        # Attempt to get structured output
+        structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
+        retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
+    except NotImplementedError:
+        print(f"Warning: Model {model} may not support structured output directly for binary grading. Falling back.")
+        # Fallback: parse the string output if structured output fails
+        from langchain_core.output_parsers import SimpleJsonOutputParser
+        # You might need to adjust the prompt to explicitly ask for JSON
+        retrieval_grader_binary = BINARY_GRADER_PROMPT | llm_grader_binary | SimpleJsonOutputParser()  # Or StrOutputParser and manual parsing

     return retrieval_grader_binary

@@ -292,41 +322,58 @@ class GradeDocumentsScore(BaseModel):
         description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
     )

-
+# Update default model and use get_model
+def get_score_grader(model="gemini-2.0-flash"):
     """
     Returns a score grader to evaluate relevance of documents using specified model for generation
     This is used when the 'score' evaluation method has been configured
     """
-
-
-
-
-
-
-
+    llm_grader_score = get_model(model)
+    # Check if the model supports structured output
+    try:
+        structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
+        retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
+    except NotImplementedError:
+        print(f"Warning: Model {model} may not support structured output directly for score grading. Falling back.")
+        # Fallback: parse the string output if structured output fails
+        from langchain_core.output_parsers import SimpleJsonOutputParser
+        # Adjust prompt if needed
+        retrieval_grader_score = SCORE_GRADER_PROMPT | llm_grader_score | SimpleJsonOutputParser()  # Or StrOutputParser and manual parsing

+    return retrieval_grader_score

-
+# Update default model
+def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="gemini-2.0-flash"):
     '''
     doc : the document to evaluate
     query : the query to which to doc shoud be relevant
     method : "binary" or "score"
     threshold : for "score" method, score above which a doc is considered relevant
     '''
-
-
-
-
-
-
-
-
+    try:
+        if method == "binary":
+            retrieval_grader_binary = get_binary_grader(model=eval_model)
+            result = retrieval_grader_binary.invoke({"question": query, "document":doc})
+            # Handle both structured and parsed output
+            binary_score = result.binary_score if isinstance(result, GradeDocumentsBinary) else result.get("binary_score", "no")
+            return 1 if (binary_score.lower() == 'yes') else 0
+        elif method == "score":
+            retrieval_grader_score = get_score_grader(model=eval_model)
+            result = retrieval_grader_score.invoke({"query": query, "document":doc})
+            # Handle both structured and parsed output
+            score = result.score if isinstance(result, GradeDocumentsScore) else result.get("score")
+            if score is not None:
+                return score if float(score) >= threshold else 0
+            else:
+                print("Warning: Couldn't parse score, marking document as relevant by default.")
+                return 1  # Default to relevant if score parsing fails
         else:
-
-
-
-
+            raise ValueError("Invalid method")
+    except Exception as e:
+        print(f"Error evaluating document: {e}")
+        return 0  # Default to irrelevant on error

+# Update default model
 def eval_docs(state: DocRetrieverState, config: ConfigSchema):
     """
     This node performs evaluation of the retrieved docs and
@@ -334,29 +381,63 @@ def eval_docs(state: DocRetrieverState, config: ConfigSchema):

     eval_method = config["configurable"].get("eval_method") or "binary"
     MAX_DOCS = config["configurable"].get("max_docs") or 15
+    # Update default model name
+    eval_model_name = config["configurable"].get("eval_model") or "gemini-2.0-flash"
     valid_doc_scores = []

-
+    # Ensure 'docs' exists and is a list
+    docs_to_evaluate = state.get("docs", [])
+    if not isinstance(docs_to_evaluate, list):
+        print("Warning: 'docs' is not a list, skipping evaluation.")
+        docs_to_evaluate = []
+
+    # Sample safely
+    sample_size = min(25, len(docs_to_evaluate))
+    sampled_docs = sample(docs_to_evaluate, sample_size) if sample_size > 0 else []
+
+
+    for doc in sampled_docs:
+        # Ensure doc is not None before formatting
+        if doc is None:
+            print("Warning: Encountered None document during evaluation, skipping.")
+            continue
+
+        formatted_doc_str = format_doc(doc)
+        # Add basic check for empty formatted doc
+        if not formatted_doc_str.strip():
+            print(f"Warning: Skipping empty formatted document: {doc}")
+            continue
+
         score = eval_doc(
-            doc=
+            doc=formatted_doc_str,
             query=state["query"],
             method=eval_method,
            threshold=config["configurable"].get("eval_threshold") or 0.7,
-            eval_model
+            eval_model=eval_model_name  # Pass the eval_model name
         )
-
-
+        # Ensure score is numeric before appending
+        if isinstance(score, (int, float)):
+            if score > 0:  # Only add if relevant (score > 0 or binary score == 1)
+                valid_doc_scores.append((doc, score))
+        else:
+            print(f"Warning: Received non-numeric score ({score}) for doc {doc}, skipping.")
+

     if eval_method == 'score':
         # Get at most MAX_DOCS items with the highest score if score method was used
-
-        valid_docs = [valid_doc[0] for valid_doc in
+        valid_docs_sorted = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True)  # Sort descending
+        valid_docs = [valid_doc[0] for valid_doc in valid_docs_sorted[:MAX_DOCS]]
     else:
         # Get at mots MAX_DOCS items at random if binary method was used
         shuffle(valid_doc_scores)
         valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]

-
+    # Ensure existing valid_docs is a list before concatenating
+    existing_valid_docs = state.get("valid_docs", [])
+    if not isinstance(existing_valid_docs, list):
+        existing_valid_docs = []
+
+    return {"valid_docs": valid_docs + existing_valid_docs}



@@ -382,25 +463,5 @@ def build_data_retriever_graph(memory):

     return graph_doc_retriever

-
-
-        start = msg.find("Requested") + len("Requested ")
-        end = msg.find(",", start)
-        rate_limit = int(msg[start:end])
-        related_concepts = []
-        i = 0
-        start = 0
-        end = len(concepts) // (rate_limit // 5000 + (1 if rate_limit%4500 != 0 else 0))
-        while (i < rate_limit // 5000):
-            smaller_concepts = concepts[start:end]
-            start = end
-            end = end + len(concepts) // (rate_limit//5000 + (1 if rate_limit%4500 != 0 else 0))
-            res = groq.invoke({question[0] : question[1], "concepts" : '\n'.join(smaller_concepts)})
-            for r in res:
-                related_concepts.append(r)
-            i+=1
-        return related_concepts
-    except Exception as e:
-        if e.status_code == 419:
-            time.sleep(65)
-            error_concept_groq(msg,concepts,groq,question)
+# Remove Groq specific error handling function
+# def error_concept_groq(msg,concepts,groq,question): ...
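Note on the new deduplication in get_docs: each document dict is keyed by tuple(sorted(doc.items())), and the code falls back to a stringified key when hashing fails. The standalone snippet below (sample data invented purely for illustration) shows why the TypeError fallback is needed when a value is a list:

    # Illustration of the dedup strategy used in get_docs (example data only).
    doc_a = {"title": "Doc 1", "score": 3}
    doc_b = {"title": "Doc 1", "score": 3}          # exact duplicate of doc_a
    doc_c = {"title": "Doc 2", "tags": ["a", "b"]}  # list value -> tuple key is unhashable

    seen, filtered = set(), []
    for doc in [doc_a, doc_b, doc_c]:
        try:
            key = tuple(sorted(doc.items()))
            hash(key)  # hashing raises TypeError when a value (e.g. a list) is unhashable
        except TypeError:
            key = str(sorted(doc.items()))  # fallback: stringified items
        if key not in seen:
            filtered.append(doc)
            seen.add(key)

    print(len(filtered))  # 2 -- the duplicate dict was dropped, the list-valued one kept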