adrienbrdne committed · verified
Commit aacc458 · 1 Parent(s): 42c00ab

Update ki_gen/data_retriever.py

Files changed (1):
  1. ki_gen/data_retriever.py +182 -121
ki_gen/data_retriever.py CHANGED
@@ -6,7 +6,12 @@ import time
 from random import shuffle, sample
 from langgraph.checkpoint.sqlite import SqliteSaver
 
-from langchain_groq import ChatGroq
+# Remove ChatGroq import
+# from langchain_groq import ChatGroq
+# Add ChatGoogleGenerativeAI import
+from langchain_google_genai import ChatGoogleGenerativeAI
+import os # Add os import
+
 from langchain_openai import ChatOpenAI
 from langchain_core.messages import HumanMessage
 from langchain_community.graphs import Neo4jGraph
@@ -15,7 +20,7 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.pydantic_v1 import Field
 from pydantic import BaseModel
-from langchain_groq import ChatGroq
+
 
 from langgraph.graph import StateGraph
 
@@ -28,11 +33,11 @@ from ki_gen.prompts import (
     SCORE_GRADER_PROMPT,
     RELEVANT_CONCEPTS_PROMPT,
 )
-from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
-
-
+# Import get_model which now handles Gemini
+from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
 
 
+# ... (extract_cypher remains the same)
 def extract_cypher(text: str) -> str:
     """Extract Cypher code from a text.
 
@@ -55,33 +60,28 @@ def extract_cypher(text: str) -> str:
         text
     ]
 
-def get_cypher_gen_chain(model: str = "deepseek-r1-distill-llama-70b"):
+# Update default model and use get_model
+def get_cypher_gen_chain(model: str = "gemini-2.0-flash"):
     """
     Returns cypher gen chain using specified model for generation
     This is used when the 'auto' cypher generation method has been configured
     """
-
-    if model=="openai":
-        llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
-    else:
-        llm_cypher_gen = ChatGroq(model = "deepseek-r1-distill-llama-70b")
+    llm_cypher_gen = get_model(model)
     cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
     return cypher_gen_chain
 
-def get_concept_selection_chain(model: str = "deepseek-r1-distill-llama-70b"):
+# Update default model and use get_model
+def get_concept_selection_chain(model: str = "gemini-2.0-flash"):
     """
     Returns a chain to select the most relevant topic using specified model for generation.
     This is used when the 'guided' cypher generation method has been configured
     """
-
-    if model == "openai":
-        llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
-    else:
-        llm_topic_selection = ChatGroq(model="deepseek-r1-distill-llama-70b")
+    llm_topic_selection = get_model(model)
     print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
     topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
     return topic_selection_chain
 
+# ... (get_concepts remains the same)
 def get_concepts(graph: Neo4jGraph):
     concept_cypher = "MATCH (c:Concept) return c"
     if isinstance(graph, Neo4jGraph):
@@ -93,37 +93,34 @@ def get_concepts(graph: Neo4jGraph):
     concepts_name = [concept['c']['name'] for concept in concepts]
     return concepts_name
 
+
+# Update to use get_model, remove Groq error handling
 def get_related_concepts(graph: Neo4jGraph, question: str):
     concepts = get_concepts(graph)
-    llm = get_model()
+    # Use get_model
+    llm = get_model()
     print(f"this is the llm variable : {llm}")
     def parse_answer(llm_answer : str):
         try:
             print(f"This the llm_answer : {llm_answer}")
+            # Adjust parsing if Gemini output format differs
             return re.split("\n(?:\d)+\.\s", llm_answer.split("Concepts:")[1])[1:]
-        except:
-            return "No concept"
+        except Exception as e:
+            print(f"Error parsing LLM concept answer: {e}")
+            return [] # Return empty list on parsing error
     related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer
 
     print(f"This is the question of the user : {question}")
     print(f"This is the concepts of the user : {concepts}")
 
-
-    #groq.APIStatusError: Error code: 413 - {'error': {'message': 'Request too large for model `deepseek-r1-distill-llama-70b` in organization `org_01j6xywkndffv96m3wgh81jm49` on tokens per minute
-    # (TPM): Limit 5000, Requested 17099, please reduce your message size and try again. Visit https://console.groq.com/docs/rate-limits for more information.',
-    # 'type': 'tokens', 'code': 'rate_limit_exceeded'}}
-
+    # Remove specific Groq error handling block
     try:
         related_concepts_raw = related_concepts_chain.invoke({"user_query" : question, "concepts" : '\n'.join(concepts)})
         print(f"related_concepts_raw : {related_concepts_raw}")
     except Exception as e:
-        if e.status_code == 413:
-            msg = e.body["error"]["message"]
-            print(f"question is : {question}")
-            print(type(question))
-            error_question = ["user_query", question]
-            related_concepts_raw = error_concept_groq(msg,concepts,related_concepts_chain,error_question)
-        pass
+        # Add generic error handling/logging for Gemini if needed
+        print(f"Error invoking related concepts chain: {e}")
+        related_concepts_raw = [] # Assign empty list on error
 
     # We clean up the list we received from the LLM in case there were some hallucinations
     related_concepts_cleaned = []
@@ -142,6 +139,7 @@ def get_related_concepts(graph: Neo4jGraph, question: str):
     # TODO : Add concepts found via similarity search
     return related_concepts_cleaned
 
+# ... (build_concept_string, get_global_concepts remain the same)
 def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
     concept_string = ""
     for concept in concept_list:
@@ -163,6 +161,7 @@ def get_global_concepts(graph: Neo4jGraph):
     concepts_name = [concept['gc']['name'] for concept in concepts]
     return concepts_name
 
+# Update concept selection error handling
 def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
     """
     The node where the cypher is generated
@@ -180,30 +179,40 @@ def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
         "concepts": related_concepts
     })
 
-    try :
-
+    # Remove specific Groq error handling block
+    try:
         if config["configurable"].get("cypher_gen_method") == 'guided':
             concept_selection_chain = get_concept_selection_chain()
             print(f"Concept selection chain is : {concept_selection_chain}")
             selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
             print(f"Selected topic are : {selected_topic}")
-
-    except Exception as e:
-        error_question = ["question", question]
-        selected_topic = error_concept_groq(e.body["error"]["message"],get_concepts(graph),concept_selection_chain,error_question)
-        pass
-
-    if config["configurable"].get("cypher_gen_method") == 'guided':
             cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
             print(f"Cyphers are : {cyphers}")
 
+    except Exception as e:
+        # Add generic error handling/logging for Gemini if needed
+        print(f"Error during guided cypher generation: {e}")
+        cyphers = [] # Assign empty list on error
+
     if config["configurable"].get("validate_cypher"):
-        corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
-        cypher_corrector = CypherQueryCorrector(corrector_schema)
-        cyphers = [cypher_corrector(cypher) for cypher in cyphers]
+        # Ensure graph schema is correctly fetched if needed
+        if graph and hasattr(graph, 'structured_schema'):
+            corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships", [])]
+            cypher_corrector = CypherQueryCorrector(corrector_schema)
+            # Apply corrector only if cyphers were generated
+            if cyphers:
+                try:
+                    cyphers = [cypher_corrector(cypher) for cypher in cyphers]
+                except Exception as corr_e:
+                    print(f"Error during cypher correction: {corr_e}")
+                    # Decide how to handle correction errors, maybe keep original cyphers
+        else:
+            print("Warning: Cypher validation skipped, graph or schema unavailable.")
 
     return {"cyphers" : cyphers}
 
+# ... (generate_cypher_from_topic, get_docs remain the same)
 def generate_cypher_from_topic(selected_concept: str, plan_step: int):
     """
     Helper function used when the 'guided' cypher generation method has been configured
@@ -226,36 +235,54 @@ def get_docs(state:DocRetrieverState, config:ConfigSchema):
     """
     graph = config["configurable"].get("graph")
     output = []
-    if graph is not None:
+    if graph is not None and state.get("cyphers"): # Check if cyphers exist
         for cypher in state["cyphers"]:
             try:
                 output = graph.query(cypher)
-                break
+                # Assuming the first successful query is sufficient
+                if output:
+                    break
             except Exception as e:
-                print("Failed to retrieve docs : {e}")
+                print(f"Failed to retrieve docs with cypher '{cypher}': {e}")
+                # Continue to try next cypher if one fails
 
     # Clean up the docs we received as there may be duplicates depending on the cypher query
     all_docs = []
     for doc in output:
         unwinded_doc = {}
-        for key in doc:
-            if isinstance(doc[key], dict):
-                all_docs.append(doc[key])
-            else:
-                unwinded_doc.update({key: doc[key]})
-        if unwinded_doc:
-            all_docs.append(unwinded_doc)
-
+        # Ensure doc is a dictionary before iterating
+        if isinstance(doc, dict):
+            for key in doc:
+                if isinstance(doc[key], dict):
+                    # If a value is a dict, treat it as a separate document
+                    all_docs.append(doc[key])
+                else:
+                    unwinded_doc.update({key: doc[key]})
+            # Add the unwinded parts if any keys were not dictionaries
+            if unwinded_doc:
+                all_docs.append(unwinded_doc)
 
     filtered_docs = []
-    for doc in all_docs:
-        if doc not in filtered_docs:
-            filtered_docs.append(doc)
-
-    return {"docs": filtered_docs}
+    seen_docs = set() # Use a set for faster duplicate checking based on a unique identifier
 
+    for doc in all_docs:
+        # Create a tuple of items to check for duplicates, assuming dicts are hashable
+        # If dicts contain unhashable types (like lists), convert them to strings or use a primary key
+        try:
+            doc_tuple = tuple(sorted(doc.items()))
+            if doc_tuple not in seen_docs:
+                filtered_docs.append(doc)
+                seen_docs.add(doc_tuple)
+        except TypeError:
+            # Handle cases where doc items are not hashable (e.g., contain lists/dicts)
+            # Fallback: convert doc to string for uniqueness check (less reliable)
+            doc_str = str(sorted(doc.items()))
+            if doc_str not in seen_docs:
+                filtered_docs.append(doc)
+                seen_docs.add(doc_str)
 
+    return {"docs": filtered_docs}
 
 
 # Data model
@@ -266,22 +293,25 @@ class GradeDocumentsBinary(BaseModel):
         description="Documents are relevant to the question, 'yes' or 'no'"
     )
 
-# LLM with function call
-# llm_grader_binary = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
-
-def get_binary_grader(model="deepseek-r1-distill-llama-70b"):
+# Update default model and use get_model
+def get_binary_grader(model="gemini-2.0-flash"):
     """
     Returns a binary grader to evaluate relevance of documents using specified model for generation
    This is used when the 'binary' evaluation method has been configured
     """
-
-    if model == "gpt-4o":
-        llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
-    else:
-        llm_grader_binary = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
-    structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
-    retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
+    llm_grader_binary = get_model(model)
+    # Check if the model supports structured output, otherwise use standard invocation
+    try:
+        # Attempt to get structured output
+        structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
+        retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
+    except NotImplementedError:
+        print(f"Warning: Model {model} may not support structured output directly for binary grading. Falling back.")
+        # Fallback: parse the string output if structured output fails
+        from langchain_core.output_parsers import SimpleJsonOutputParser
+        # You might need to adjust the prompt to explicitly ask for JSON
+        retrieval_grader_binary = BINARY_GRADER_PROMPT | llm_grader_binary | SimpleJsonOutputParser() # Or StrOutputParser and manual parsing
+
     return retrieval_grader_binary
 
 
@@ -292,41 +322,58 @@ class GradeDocumentsScore(BaseModel):
         description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
     )
 
-def get_score_grader(model="deepseek-r1-distill-llama-70b"):
+# Update default model and use get_model
+def get_score_grader(model="gemini-2.0-flash"):
     """
     Returns a score grader to evaluate relevance of documents using specified model for generation
     This is used when the 'score' evaluation method has been configured
     """
-    if model == "gpt-4o":
-        llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
-    else:
-        llm_grader_score = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature = 0)
-    structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
-    retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
-    return retrieval_grader_score
+    llm_grader_score = get_model(model)
+    # Check if the model supports structured output
+    try:
+        structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
+        retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
+    except NotImplementedError:
+        print(f"Warning: Model {model} may not support structured output directly for score grading. Falling back.")
+        # Fallback: parse the string output if structured output fails
+        from langchain_core.output_parsers import SimpleJsonOutputParser
+        # Adjust prompt if needed
+        retrieval_grader_score = SCORE_GRADER_PROMPT | llm_grader_score | SimpleJsonOutputParser() # Or StrOutputParser and manual parsing
 
+    return retrieval_grader_score
 
-def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="deepseek-r1-distill-llama-70b"):
+# Update default model
+def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="gemini-2.0-flash"):
     '''
     doc : the document to evaluate
     query : the query to which to doc shoud be relevant
     method : "binary" or "score"
     threshold : for "score" method, score above which a doc is considered relevant
     '''
-    if method == "binary":
-        retrieval_grader_binary = get_binary_grader(model=eval_model)
-        return 1 if (retrieval_grader_binary.invoke({"question": query, "document":doc}).binary_score == 'yes') else 0
-    elif method == "score":
-        retrieval_grader_score = get_score_grader(model=eval_model)
-        score = retrieval_grader_score.invoke({"query": query, "document":doc}).score or None
-        if score is not None:
-            return score if score >= threshold else 0
-        else:
-            # Couldn't parse score, marking document as relevant by default
-            return 1
-    else:
-        raise ValueError("Invalid method")
+    try:
+        if method == "binary":
+            retrieval_grader_binary = get_binary_grader(model=eval_model)
+            result = retrieval_grader_binary.invoke({"question": query, "document":doc})
+            # Handle both structured and parsed output
+            binary_score = result.binary_score if isinstance(result, GradeDocumentsBinary) else result.get("binary_score", "no")
+            return 1 if (binary_score.lower() == 'yes') else 0
+        elif method == "score":
+            retrieval_grader_score = get_score_grader(model=eval_model)
+            result = retrieval_grader_score.invoke({"query": query, "document":doc})
+            # Handle both structured and parsed output
+            score = result.score if isinstance(result, GradeDocumentsScore) else result.get("score")
+            if score is not None:
+                return score if float(score) >= threshold else 0
+            else:
+                print("Warning: Couldn't parse score, marking document as relevant by default.")
+                return 1 # Default to relevant if score parsing fails
+        else:
+            raise ValueError("Invalid method")
+    except Exception as e:
+        print(f"Error evaluating document: {e}")
+        return 0 # Default to irrelevant on error
 
+# Update default model
 def eval_docs(state: DocRetrieverState, config: ConfigSchema):
     """
     This node performs evaluation of the retrieved docs and
@@ -334,29 +381,63 @@ def eval_docs(state: DocRetrieverState, config: ConfigSchema):
 
     eval_method = config["configurable"].get("eval_method") or "binary"
     MAX_DOCS = config["configurable"].get("max_docs") or 15
+    # Update default model name
+    eval_model_name = config["configurable"].get("eval_model") or "gemini-2.0-flash"
     valid_doc_scores = []
 
-    for doc in sample(state["docs"], min(25, len(state["docs"]))):
+    # Ensure 'docs' exists and is a list
+    docs_to_evaluate = state.get("docs", [])
+    if not isinstance(docs_to_evaluate, list):
+        print("Warning: 'docs' is not a list, skipping evaluation.")
+        docs_to_evaluate = []
+
+    # Sample safely
+    sample_size = min(25, len(docs_to_evaluate))
+    sampled_docs = sample(docs_to_evaluate, sample_size) if sample_size > 0 else []
+
+    for doc in sampled_docs:
+        # Ensure doc is not None before formatting
+        if doc is None:
+            print("Warning: Encountered None document during evaluation, skipping.")
+            continue
+
+        formatted_doc_str = format_doc(doc)
+        # Add basic check for empty formatted doc
+        if not formatted_doc_str.strip():
+            print(f"Warning: Skipping empty formatted document: {doc}")
+            continue
+
         score = eval_doc(
-            doc=format_doc(doc),
+            doc=formatted_doc_str,
             query=state["query"],
             method=eval_method,
            threshold=config["configurable"].get("eval_threshold") or 0.7,
-            eval_model = config["configurable"].get("eval_model") or "deepseek-r1-distill-llama-70b"
+            eval_model=eval_model_name # Pass the eval_model name
        )
-        if score:
-            valid_doc_scores.append((doc, score))
+        # Ensure score is numeric before appending
+        if isinstance(score, (int, float)):
+            if score > 0: # Only add if relevant (score > 0 or binary score == 1)
+                valid_doc_scores.append((doc, score))
+        else:
+            print(f"Warning: Received non-numeric score ({score}) for doc {doc}, skipping.")
 
     if eval_method == 'score':
         # Get at most MAX_DOCS items with the highest score if score method was used
-        valid_docs = sorted(valid_doc_scores, key=lambda x: x[1])
-        valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
+        valid_docs_sorted = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True) # Sort descending
+        valid_docs = [valid_doc[0] for valid_doc in valid_docs_sorted[:MAX_DOCS]]
    else:
        # Get at mots MAX_DOCS items at random if binary method was used
        shuffle(valid_doc_scores)
        valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]
 
-    return {"valid_docs": valid_docs + (state["valid_docs"] or [])}
+    # Ensure existing valid_docs is a list before concatenating
+    existing_valid_docs = state.get("valid_docs", [])
+    if not isinstance(existing_valid_docs, list):
+        existing_valid_docs = []
+
+    return {"valid_docs": valid_docs + existing_valid_docs}
 
 
 
@@ -382,25 +463,5 @@ def build_data_retriever_graph(memory):
 
     return graph_doc_retriever
 
-def error_concept_groq(msg,concepts,groq,question):
-    try:
-        start = msg.find("Requested") + len("Requested ")
-        end = msg.find(",", start)
-        rate_limit = int(msg[start:end])
-        related_concepts = []
-        i = 0
-        start = 0
-        end = len(concepts) // (rate_limit // 5000 + (1 if rate_limit%4500 != 0 else 0))
-        while (i < rate_limit // 5000):
-            smaller_concepts = concepts[start:end]
-            start = end
-            end = end + len(concepts) // (rate_limit//5000 + (1 if rate_limit%4500 != 0 else 0))
-            res = groq.invoke({question[0] : question[1], "concepts" : '\n'.join(smaller_concepts)})
-            for r in res:
-                related_concepts.append(r)
-            i+=1
-        return related_concepts
-    except Exception as e:
-        if e.status_code == 419:
-            time.sleep(65)
-            error_concept_groq(msg,concepts,groq,question)
+# Remove Groq specific error handling function
+# def error_concept_groq(msg,concepts,groq,question): ...
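
For context, the diff above assumes that ki_gen/utils.get_model maps a model name to a chat model instance and "now handles Gemini" (per the new import comment). That helper lives outside this file and is not part of this commit; the following is a minimal sketch of what it might look like, with the name-based dispatch, the temperature default, and the OpenAI fallback all being assumptions rather than code from this repository.

import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

def get_model(model: str = "gemini-2.0-flash", temperature: float = 0):
    """Hypothetical helper: return a chat model instance for the given model name."""
    if model.startswith("gemini"):
        # ChatGoogleGenerativeAI reads the GOOGLE_API_KEY environment variable by default
        return ChatGoogleGenerativeAI(model=model, temperature=temperature)
    # Fallback for OpenAI models, mirroring the code paths this commit removes
    return ChatOpenAI(model=model, temperature=temperature)

Under that assumption, the updated chains in data_retriever.py only need the Google credentials to be present in the environment before they are built, e.g. setting GOOGLE_API_KEY and then calling get_cypher_gen_chain(), which now defaults to "gemini-2.0-flash".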