adrienbrdne committed · Commit 77b16df · verified · Parent(s): aacc458

Update ki_gen/planner.py

Files changed (1):
  1. ki_gen/planner.py +199 -115
ki_gen/planner.py CHANGED
@@ -4,7 +4,12 @@ import re
from typing import Annotated
from typing_extensions import TypedDict

- from langchain_groq import ChatGroq
+ # Remove ChatGroq import
+ # from langchain_groq import ChatGroq
+ # Add ChatGoogleGenerativeAI import
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ import os  # Add os import
+ 
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_community.graphs import Neo4jGraph
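
Editor's note: `ChatGoogleGenerativeAI` reads its key from the `GOOGLE_API_KEY` environment variable. A minimal smoke test of the new import, assuming `langchain-google-genai` is installed and using the same default model name this file adopts below:

```python
import os
from langchain_google_genai import ChatGoogleGenerativeAI

# Assumes GOOGLE_API_KEY is set in the environment (required by the client).
os.environ.setdefault("GOOGLE_API_KEY", "<your-key>")

llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
print(llm.invoke("Reply with 'pong'.").content)
```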
@@ -15,11 +20,11 @@ from langgraph.graph import add_messages
from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
from ki_gen.data_retriever import build_data_retriever_graph
from ki_gen.data_processor import build_data_processor_graph
- from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState
+ # Import get_model which now handles Gemini
+ from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState, get_model
from langgraph.checkpoint.sqlite import SqliteSaver


- 
##########################################################################
###### NODES DEFINITION ######
##########################################################################
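
Editor's note: the new import relies on `get_model` from `ki_gen.utils`, which this commit does not show. A plausible sketch of such a helper, assuming it dispatches on the model name; the dispatch rule and defaults here are guesses, not the repository's actual code:

```python
# Hypothetical ki_gen/utils.py helper; the real implementation is not part of this diff.
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

def get_model(model_name: str = "gemini-2.0-flash"):
    """Return a chat model for the given name, defaulting to Gemini."""
    if model_name.startswith("gpt-"):
        return ChatOpenAI(model=model_name)
    return ChatGoogleGenerativeAI(model=model_name)
```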
@@ -33,108 +38,138 @@ If needed, give me an updated plan to follow this instruction. If your plan alre
    output = HumanMessage(content=prompt)
    return {"messages" : [output]}

- 
- def error_chatbot_groq(error, model_name, query):  # Pass model_name instead of llm_groq object
-     # Switch API key logic...
-     if os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key"):
-         os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key2")
-     elif os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key2"):
-         os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key3")
-     else:
-         os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key")
- 
-     # Re-initialize the model *after* switching the key
-     try:
-         # Use the model_name passed in
-         llm_groq_retry = ChatGroq(model=model_name)
-         # Pass the original query messages
-         return {"messages" : [llm_groq_retry.invoke(query)]}
-     except Exception as retry_error:
-         # Handle potential error during retry
-         print(f"Error during retry: {retry_error}")
-         # Decide what to return or raise here
-         return {"messages": [SystemMessage(content=f"Failed to process after retry: {retry_error}")]}
- 
+ # Remove Groq-specific error handler
+ # def error_chatbot_groq(error, model_name, query): ...

# Wrappers to call LLMs on the state messages field
- def chatbot_llama(state: State):
+ # Simplify: use get_model directly via a single generic chatbot function
+ def chatbot_node(state: State, config: ConfigSchema):
+     """Generic chatbot node using the main_llm from config."""
+     model_name = config["configurable"].get("main_llm") or "gemini-2.0-flash"
+     llm = get_model(model_name)
    try:
-         llm_llama = ChatGroq(model="llama3-70b-8192")
-         return {"messages" : [llm_llama.invoke(state["messages"])]}
-     except Exception as error:
-         error_chatbot_groq(error, llm_llama, state["messages"])
- def chatbot_mixtral(state: State):
-     print(state)
-     llm_mixtral = ChatGroq(model="deepseek-r1-distill-llama-70b")
-     print(llm_mixtral)
-     return {"messages" : [llm_mixtral.invoke(state["messages"])]}
-     # except Exception as error:
-     #     error_chatbot_groq(error, llm_mixtral, state["messages"])
- def chatbot_openai(state: State):
-     llm_openai = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
-     return {"messages" : [llm_openai.invoke(state["messages"])]}
- 
- chatbots = {"gpt-4o" : chatbot_openai,
-             "deepseek-r1-distill-llama-70b" : chatbot_mixtral,
-             "llama3-70b-8192" : chatbot_llama
-             }
+         # Check if messages exist and are not empty
+         if "messages" in state and state["messages"]:
+             response = llm.invoke(state["messages"])
+             return {"messages": [response]}
+         else:
+             print("Warning: No messages found in state for chatbot_node.")
+             # Return state unchanged or an empty message list?
+             return {}  # Or {"messages": []}
+     except Exception as e:
+         print(f"Error invoking model {model_name}: {e}")
+         # Handle error, maybe return an error message or empty dict
+         return {"messages": [SystemMessage(content=f"Error during generation: {e}")]}
+ 

+ # Remove old chatbot functions (chatbot_llama, chatbot_mixtral, chatbot_openai)
+ # Replace the chatbots dictionary with direct calls to the generic function or specific models via get_model
+ # This simplifies planner.py, relying on utils.py and config for model selection.

def parse_plan(state: State):
    """
    This node parses the generated plan and writes in the 'store_plan' field of the state
    """
-     plan = state["messages"][-3].content
-     store_plan = re.split("\d\.", plan.split("Plan:\n")[1])[1:]
+     # Find the AI message likely containing the plan (often the second to last if validate_node was used)
+     plan_message_content = ""
+     if "messages" in state and len(state["messages"]) >= 1:
+         # Search backwards for the plan, as its position might vary
+         for msg in reversed(state["messages"]):
+             if hasattr(msg, 'content') and "Plan:" in msg.content and "<END_OF_PLAN>" in msg.content:
+                 plan_message_content = msg.content
+                 break  # Found the plan
+ 
+     if not plan_message_content:
+         print("Error: Could not find plan message in state.")
+         # Handle error: maybe return current state or raise an exception
+         return state  # Return unchanged state if plan not found
+ 
+     store_plan = []
    try:
-         store_plan[len(store_plan) - 1] = store_plan[len(store_plan) - 1].split("<END_OF_PLAN>")[0]
+         # Improved parsing: handle potential variations in formatting
+         plan_section = plan_message_content.split("Plan:")[1].split("<END_OF_PLAN>")[0]
+         # Split by numbered steps, removing empty entries
+         store_plan = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_section) if step.strip()]
    except Exception as e:
-         print(f"Error while removing <END_OF_PLAN> : {e}")
+         print(f"Error while parsing plan: {e}")
+         # Handle parsing error, potentially keep store_plan empty or log the error
+         store_plan = []  # Reset plan on error

    return {"store_plan" : store_plan}

+ # Update get_detailed_query to use get_model and default model
+ def get_detailed_query(context : list, model : str = "gemini-2.0-flash"):
+     """
+     Simple helper function for the detail_step node
+     """
+     llm = get_model(model)  # Use get_model
+     try:
+         return llm.invoke(context)
+     except Exception as e:
+         print(f"Error in get_detailed_query with model {model}: {e}")
+         # Return a default message or raise error
+         return SystemMessage(content=f"Error generating detailed query: {e}")
+ 
+ 
def detail_step(state: State, config: ConfigSchema):
    """
    This node updates the value of the 'current_plan_step' field and defines the query to be used for the data_retriever.
    """
-     print("test")
-     print(state)
+     print("Entering detail_step")  # Debug print
+     print(f"Current state keys: {state.keys()}")  # Debug print
+ 
+     # Initialize current_plan_step if not present
+     current_plan_step = state.get("current_plan_step", -1) + 1
+ 
+     # Ensure store_plan exists and has enough steps
+     store_plan = state.get("store_plan", [])
+     if not store_plan or current_plan_step >= len(store_plan):
+         print(f"Warning: Plan step {current_plan_step} out of bounds or plan is empty.")
+         # Decide how to handle: end graph, return error state?
+         # For now, let's prevent index error and maybe signal an issue
+         # Returning an empty query might halt progress or cause issues downstream
+         return {"current_plan_step": current_plan_step, 'query' : "Error: Plan step unavailable.", "valid_docs" : []}

-     if 'current_plan_step' in state.keys():
-         print("all good chief")
-     else:
-         state["current_plan_step"] = None

-     current_plan_step = state["current_plan_step"] + 1 if state["current_plan_step"] is not None else 0  # We just began a new step so we will increase current_plan_step at the end
+     plan_step_description = store_plan[current_plan_step]
+ 
    if config["configurable"].get("use_detailed_query"):
        prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
- Step {current_plan_step + 1} : {state['store_plan'][current_plan_step]}""")
-         query = get_detailed_query(context = state["messages"] + [prompt], model=config["configurable"].get("main_llm"))
-         return {"messages" : [prompt, query], "current_plan_step": current_plan_step, 'query' : query}
- 
-     return {"current_plan_step": current_plan_step, 'query' : state["store_plan"][current_plan_step], "valid_docs" : []}
+ Step {current_plan_step + 1} : {plan_step_description}""")
+         # Ensure messages exist before appending
+         current_messages = state.get("messages", [])
+         query_message = get_detailed_query(context = current_messages + [prompt], model=config["configurable"].get("main_llm", "gemini-2.0-flash"))
+         query_content = query_message.content if hasattr(query_message, 'content') else "Error: Could not get detailed query content."
+         return {"messages" : [prompt, query_message], "current_plan_step": current_plan_step, 'query' : query_content, "valid_docs": state.get("valid_docs", [])}  # Ensure valid_docs is preserved
+ 
+     # If not using detailed query, use the plan step description directly
+     return {"current_plan_step": current_plan_step, 'query' : plan_step_description, "valid_docs" : state.get("valid_docs", [])}  # Ensure valid_docs is preserved

- def get_detailed_query(context : list, model : str = "deepseek-r1-distill-llama-70b"):
-     """
-     Simple helper function for the detail_step node
-     """
-     if model == 'gpt-4o':
-         llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
-     else:
-         llm = ChatGroq(model=model)
-     return llm.invoke(context)

def concatenate_data(state: State):
    """
    This node concatenates all the data that was processed by the data_processor and inserts it in the state's messages
    """
+     # Ensure valid_docs exists and current_plan_step is valid
+     valid_docs_content = state.get("valid_docs", "No processed documents available.")
+     current_plan_step = state.get("current_plan_step", -1)
+     store_plan = state.get("store_plan", [])
+ 
+     if current_plan_step < 0 or current_plan_step >= len(store_plan):
+         print(f"Warning: Invalid current_plan_step ({current_plan_step}) in concatenate_data.")
+         # Handle error - maybe return an error message
+         step_description = "Error: Current plan step invalid."
+     else:
+         step_description = store_plan[current_plan_step]
+ 
+ 
    prompt = f"""#########TECHNICAL INFORMATION ############
- {str(state["valid_docs"])}
+ {str(valid_docs_content)}

########END OF TECHNICAL INFORMATION#######

- Using the information provided above, proceed with step {state['current_plan_step'] + 1} of your plan :
- {state['store_plan'][state['current_plan_step']]}
+ Using the information provided above, proceed with step {current_plan_step + 1} of your plan :
+ {step_description}
"""

    return {"messages": [HumanMessage(content=prompt)]}
@@ -142,18 +177,32 @@ Using the information provided above, proceed with step {state['current_plan_ste

def human_validation(state: HumanValidationState) -> HumanValidationState:
    """
-     Dummy node to interrupt before
+     Dummy node to interrupt before processing, can be used for manual validation later.
    """
-     return {'process_steps' : []}
+     # Defaulting to no processing steps needed unless specified elsewhere
+     return {'process_steps' : state.get('process_steps', [])}
+ 

def generate_ki(state: State):
    """
    This node inserts the prompt to begin Key Issues generation
    """
-     print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state}")
+     print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state.get('current_plan_step')}")
+ 
+     current_plan_step = state.get("current_plan_step", -1)
+     store_plan = state.get("store_plan", [])
+ 
+     # Check if the next step exists in the plan
+     next_step_index = current_plan_step + 1
+     if next_step_index < 0 or next_step_index >= len(store_plan):
+         print(f"Warning: Invalid next plan step ({next_step_index}) for KI generation.")
+         step_description = "Error: Plan step for KI generation unavailable."
+     else:
+         step_description = store_plan[next_step_index]

-     prompt = f"""Using the information provided above, proceed with step 4 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
-     {state['store_plan'][state['current_plan_step'] + 1]}"""
+ 
+     prompt = f"""Using the information provided above, proceed with step {next_step_index + 1} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
+     {step_description}"""

    return {"messages" : [HumanMessage(content=prompt)]}

@@ -161,8 +210,19 @@ def detail_ki(state: State):
    """
    This node inserts the last prompt to detail the generated Key Issues
    """
-     prompt = f"""Using the information provided above, proceed with step 5 of your plan to provide the user with NEW and INNOVATIVE Key Issues :
-     {state['store_plan'][state['current_plan_step'] + 2]}"""
+     current_plan_step = state.get("current_plan_step", -1)
+     store_plan = state.get("store_plan", [])
+ 
+     # Check if the step after next exists in the plan
+     detail_step_index = current_plan_step + 2
+     if detail_step_index < 0 or detail_step_index >= len(store_plan):
+         print(f"Warning: Invalid plan step ({detail_step_index}) for KI detailing.")
+         step_description = "Error: Plan step for KI detailing unavailable."
+     else:
+         step_description = store_plan[detail_step_index]
+ 
+     prompt = f"""Using the information provided above, proceed with step {detail_step_index + 1} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
+     {step_description}"""

    return {"messages" : [HumanMessage(content=prompt)]}

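Editor's note: the `+ 1` / `+ 2` offsets in `generate_ki` and `detail_ki` assume the plan keeps its last two steps for generating and then detailing Key Issues. A toy check of the indexing (plan contents invented):

```python
store_plan = ["retrieve A", "retrieve B", "retrieve C", "generate KIs", "detail KIs"]
current_plan_step = 2  # index of the last retrieval step that was executed

assert store_plan[current_plan_step + 1] == "generate KIs"  # consumed by generate_ki
assert store_plan[current_plan_step + 2] == "detail KIs"    # consumed by detail_ki
```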
@@ -174,44 +234,54 @@ def validate_plan(state: State):
    """
    Whether to regenerate the plan or to parse it
    """
-     if "messages" in state and "My plan is correct" in state["messages"][-1].content:
-         return "parse"
+     # Check the last message for "My plan is correct"
+     if "messages" in state and state["messages"]:
+         last_message = state["messages"][-1]
+         if hasattr(last_message, 'content') and "My plan is correct" in last_message.content:
+             return "parse"
+     # Default to validate (regenerate) if condition not met or messages are missing
    return "validate"

def next_plan_step(state: State, config: ConfigSchema):
    """
    Proceed to next plan step (either generate KI or retrieve more data)
    """
-     if (state["current_plan_step"] == 2) and (config["configurable"].get('plan_method') == "modification"):
-         return "generate_key_issues"
-     if state["current_plan_step"] == len(state["store_plan"]) - 1:
+     current_plan_step = state.get("current_plan_step", -1)
+     store_plan_len = len(state.get("store_plan", []))
+ 
+     # Simplified logic: go to KI generation if it's the last step based on plan length
+     if current_plan_step >= store_plan_len - 1:
        return "generate_key_issues"
    else:
        return "detail_step"
+ 

def detail_or_data_retriever(state: State, config: ConfigSchema):
    """
-     Detail the query to use for data retrieval or not
+     Decide whether to detail the query or go straight to data retrieval.
    """
+     # Check configuration if detailed query is needed
    if config["configurable"].get("use_detailed_query"):
-         return "chatbot_detail"
+         # Need to invoke the LLM to get the detailed query
+         return "chatbot_detail"
    else:
+         # Use the plan step directly as the query
        return "data_retriever"

def retrieve_or_process(state: State):
    """
-     Process the retrieved docs or keep retrieving
+     Process the retrieved docs or keep retrieving (based on human_validated flag).
    """
-     if state['human_validated']:
+     # Check the 'human_validated' flag in the state
+     # This flag needs to be set externally (e.g., by Streamlit UI or another mechanism)
+     # before this node is reached after data_retriever.
+     if state.get('human_validated'):
        return "process"
-     return "retrieve"
-     # while True:
-     #     user_input = input(f"{len(state['valid_docs'])} were retreived. Do you want more documents (y/[n]) : ")
-     #     if user_input.lower() == "y":
-     #         return "retrieve"
-     #     if not user_input or user_input.lower() == "n":
-     #         return "process"
-     #     print("Please answer with 'y' or 'n'.\n")
+     else:
+         # If not validated, loop back to retrieve more (or wait for validation)
+         # This assumes data_retriever might be called again or the graph waits.
+         # In the Streamlit app, the human_validation node allows setting this flag.
+         return "retrieve"


def build_planner_graph(memory, config):
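
Editor's note: the comments above say `human_validated` must be set externally while the graph is paused. With a LangGraph checkpointer that is typically done through `update_state` before resuming; a sketch, assuming a compiled `graph` with the interrupt configured after `human_validation` as in the `compile` call at the end of this diff:

```python
config = {"configurable": {"thread_id": "1"}}

# Run until the graph interrupts after "human_validation".
for _ in graph.stream({"messages": [("user", "...")]}, config):
    pass

# A human reviews the retrieved docs, then approves processing.
graph.update_state(config, {"human_validated": True})

# Resume from the checkpoint; retrieve_or_process now routes to "process".
for _ in graph.stream(None, config):
    pass
```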
@@ -222,30 +292,36 @@ def build_planner_graph(memory, config):

    graph_doc_retriever = build_data_retriever_graph(memory)
    graph_doc_processor = build_data_processor_graph(memory)
-     graph_builder.add_node("chatbot_planner", chatbots[config["main_llm"]])
+ 
+     # Use the generic chatbot node function
+     graph_builder.add_node("chatbot_planner", lambda state: chatbot_node(state, config))
    graph_builder.add_node("validate", validate_node)
-     graph_builder.add_node("chatbot_detail", chatbot_llama)
+     # Add node for chatbot interaction when detailed query is needed
+     graph_builder.add_node("chatbot_detail", lambda state: chatbot_node(state, config))
    graph_builder.add_node("parse", parse_plan)
-     graph_builder.add_node("detail_step", detail_step)
-     graph_builder.add_node("data_retriever", graph_doc_retriever, input=DocRetrieverState)
-     graph_builder.add_node("human_validation", human_validation)
-     graph_builder.add_node("data_processor", graph_doc_processor, input=DocProcessorState)
+     # Pass config to detail_step as it needs it now
+     graph_builder.add_node("detail_step", lambda state: detail_step(state, config))
+     graph_builder.add_node("data_retriever", graph_doc_retriever)  # Input mapping happens automatically if state keys match
+     graph_builder.add_node("human_validation", human_validation)  # Needs input mapping if HumanValidationState differs significantly
+     graph_builder.add_node("data_processor", graph_doc_processor)  # Needs input mapping if DocProcessorState differs significantly
    graph_builder.add_node("concatenate_data", concatenate_data)
-     graph_builder.add_node("chatbot_exec_step", chatbots[config["main_llm"]])
+     # Use the generic chatbot node function
+     graph_builder.add_node("chatbot_exec_step", lambda state: chatbot_node(state, config))
    graph_builder.add_node("generate_ki", generate_ki)
-     graph_builder.add_node("chatbot_ki", chatbots[config["main_llm"]])
+     # Use the generic chatbot node function
+     graph_builder.add_node("chatbot_ki", lambda state: chatbot_node(state, config))
    graph_builder.add_node("detail_ki", detail_ki)
-     graph_builder.add_node("chatbot_final", chatbots[config["main_llm"]])
+     # Use the generic chatbot node function
+     graph_builder.add_node("chatbot_final", lambda state: chatbot_node(state, config))

+     # Define edges
    graph_builder.add_edge("validate", "chatbot_planner")
    graph_builder.add_edge("parse", "detail_step")

-     # graph_builder.add_edge("detail_step", "chatbot2")
-     graph_builder.add_edge("chatbot_detail", "data_retriever")
+     # Edge from chatbot_detail (after getting detailed query) to data_retriever
+     graph_builder.add_edge("chatbot_detail", "data_retriever")

    graph_builder.add_edge("data_retriever", "human_validation")
-
-
    graph_builder.add_edge("data_processor", "concatenate_data")
    graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
    graph_builder.add_edge("generate_ki", "chatbot_ki")
@@ -253,15 +329,18 @@ def build_planner_graph(memory, config):
    graph_builder.add_edge("detail_ki", "chatbot_final")
    graph_builder.add_edge("chatbot_final", "__end__")

+     # Define conditional edges
    graph_builder.add_conditional_edges(
        "detail_step",
-         detail_or_data_retriever,
+         # Pass config to the conditional function
+         lambda state: detail_or_data_retriever(state, config),
        {"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"}
    )
    graph_builder.add_conditional_edges(
        "human_validation",
        retrieve_or_process,
-         {"retrieve" : "data_retriever", "process" : "data_processor"}
+         # Map 'retrieve' back to 'data_retriever' node, 'process' to 'data_processor'
+         {"retrieve" : "data_retriever", "process" : "data_processor"}
    )
    graph_builder.add_conditional_edges(
        "chatbot_planner",
@@ -270,13 +349,18 @@
    )
    graph_builder.add_conditional_edges(
        "chatbot_exec_step",
-         next_plan_step,
+         # Pass config to the conditional function
+         lambda state: next_plan_step(state, config),
        {"generate_key_issues" : "generate_ki", "detail_step": "detail_step"}
    )

-     graph_builder.set_entry_point("chatbot_planner")
+     # Set entry point
+     graph_builder.set_entry_point("chatbot_planner")
+ 
+     # Compile the graph
    graph = graph_builder.compile(
        checkpointer=memory,
-         interrupt_after=["parse", "chatbot_exec_step", "chatbot_final", "data_retriever"],
+         # Define interrupt points if needed for human interaction or debugging
+         interrupt_after=["human_validation", "chatbot_final"],
    )
    return graph
 
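Editor's note: end-to-end, the rebuilt graph might be driven like this; a sketch only, since the config shape (`configurable` keys) and the `SqliteSaver` API surface (context manager vs. direct constructor, which changed across langgraph versions) are assumptions:

```python
from langgraph.checkpoint.sqlite import SqliteSaver

# Assumed config shape, matching the keys this file reads.
config = {
    "configurable": {
        "main_llm": "gemini-2.0-flash",
        "use_detailed_query": False,
        "thread_id": "1",
    }
}

with SqliteSaver.from_conn_string(":memory:") as memory:
    graph = build_planner_graph(memory, config)
    for event in graph.stream({"messages": [("user", "Plan key issues for 5G security")]}, config):
        print(event)
```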