Spaces:
Sleeping
Sleeping
Update ki_gen/planner.py
Browse files- ki_gen/planner.py +199 -115
ki_gen/planner.py
CHANGED
@@ -4,7 +4,12 @@ import re
|
|
4 |
from typing import Annotated
|
5 |
from typing_extensions import TypedDict
|
6 |
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
8 |
from langchain_openai import ChatOpenAI
|
9 |
from langchain_core.messages import SystemMessage, HumanMessage
|
10 |
from langchain_community.graphs import Neo4jGraph
|
@@ -15,11 +20,11 @@ from langgraph.graph import add_messages
|
|
15 |
from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
|
16 |
from ki_gen.data_retriever import build_data_retriever_graph
|
17 |
from ki_gen.data_processor import build_data_processor_graph
|
18 |
-
|
|
|
19 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
20 |
|
21 |
|
22 |
-
|
23 |
##########################################################################
|
24 |
###### NODES DEFINITION ######
|
25 |
##########################################################################
|
@@ -33,108 +38,138 @@ If needed, give me an updated plan to follow this instruction. If your plan alre
|
|
33 |
output = HumanMessage(content=prompt)
|
34 |
return {"messages" : [output]}
|
35 |
|
36 |
-
|
37 |
-
def error_chatbot_groq(error, model_name, query):
|
38 |
-
# Switch API key logic...
|
39 |
-
if os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key"):
|
40 |
-
os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key2")
|
41 |
-
elif os.environ["GROQ_API_KEY"] == os.getenv("groq_api_key2"):
|
42 |
-
os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key3")
|
43 |
-
else:
|
44 |
-
os.environ["GROQ_API_KEY"] = os.getenv("groq_api_key")
|
45 |
-
|
46 |
-
# Re-initialize the model *after* switching the key
|
47 |
-
try:
|
48 |
-
# Use the model_name passed in
|
49 |
-
llm_groq_retry = ChatGroq(model=model_name)
|
50 |
-
# Pass the original query messages
|
51 |
-
return {"messages" : [llm_groq_retry.invoke(query)]}
|
52 |
-
except Exception as retry_error:
|
53 |
-
# Handle potential error during retry
|
54 |
-
print(f"Error during retry: {retry_error}")
|
55 |
-
# Decide what to return or raise here
|
56 |
-
return {"messages": [SystemMessage(content=f"Failed to process after retry: {retry_error}")]}
|
57 |
-
|
58 |
|
59 |
# Wrappers to call LLMs on the state messsages field
|
60 |
-
|
|
|
|
|
|
|
|
|
61 |
try:
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
#
|
73 |
-
|
74 |
-
|
75 |
-
return {"messages" : [llm_openai.invoke(state["messages"])]}
|
76 |
-
|
77 |
-
chatbots = {"gpt-4o" : chatbot_openai,
|
78 |
-
"deepseek-r1-distill-llama-70b" : chatbot_mixtral,
|
79 |
-
"llama3-70b-8192" : chatbot_llama
|
80 |
-
}
|
81 |
|
|
|
|
|
|
|
82 |
|
83 |
def parse_plan(state: State):
|
84 |
"""
|
85 |
This node parses the generated plan and writes in the 'store_plan' field of the state
|
86 |
"""
|
87 |
-
plan
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
try:
|
90 |
-
|
|
|
|
|
|
|
91 |
except Exception as e:
|
92 |
-
print(f"Error while
|
|
|
|
|
93 |
|
94 |
return {"store_plan" : store_plan}
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
def detail_step(state: State, config: ConfigSchema):
|
97 |
"""
|
98 |
This node updates the value of the 'current_plan_step' field and defines the query to be used for the data_retriever.
|
99 |
"""
|
100 |
-
print("
|
101 |
-
print(state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
-
if 'current_plan_step' in state.keys():
|
104 |
-
print("all good chief")
|
105 |
-
else:
|
106 |
-
state["current_plan_step"] = None
|
107 |
|
108 |
-
|
|
|
109 |
if config["configurable"].get("use_detailed_query"):
|
110 |
prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
|
111 |
-
Step {current_plan_step + 1} : {
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
def get_detailed_query(context : list, model : str = "deepseek-r1-distill-llama-70b"):
|
118 |
-
"""
|
119 |
-
Simple helper function for the detail_step node
|
120 |
-
"""
|
121 |
-
if model == 'gpt-4o':
|
122 |
-
llm = ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
|
123 |
-
else:
|
124 |
-
llm = ChatGroq(model=model)
|
125 |
-
return llm.invoke(context)
|
126 |
|
127 |
def concatenate_data(state: State):
|
128 |
"""
|
129 |
This node concatenates all the data that was processed by the data_processor and inserts it in the state's messages
|
130 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
prompt = f"""#########TECHNICAL INFORMATION ############
|
132 |
-
{str(
|
133 |
|
134 |
########END OF TECHNICAL INFORMATION#######
|
135 |
|
136 |
-
Using the information provided above, proceed with step {
|
137 |
-
{
|
138 |
"""
|
139 |
|
140 |
return {"messages": [HumanMessage(content=prompt)]}
|
@@ -142,18 +177,32 @@ Using the information provided above, proceed with step {state['current_plan_ste
|
|
142 |
|
143 |
def human_validation(state: HumanValidationState) -> HumanValidationState:
|
144 |
"""
|
145 |
-
Dummy node to interrupt before
|
146 |
"""
|
147 |
-
|
|
|
|
|
148 |
|
149 |
def generate_ki(state: State):
|
150 |
"""
|
151 |
This node inserts the prompt to begin Key Issues generation
|
152 |
"""
|
153 |
-
print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
-
|
156 |
-
{
|
|
|
157 |
|
158 |
return {"messages" : [HumanMessage(content=prompt)]}
|
159 |
|
@@ -161,8 +210,19 @@ def detail_ki(state: State):
|
|
161 |
"""
|
162 |
This node inserts the last prompt to detail the generated Key Issues
|
163 |
"""
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
return {"messages" : [HumanMessage(content=prompt)]}
|
168 |
|
@@ -174,44 +234,54 @@ def validate_plan(state: State):
|
|
174 |
"""
|
175 |
Whether to regenerate the plan or to parse it
|
176 |
"""
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
179 |
return "validate"
|
180 |
|
181 |
def next_plan_step(state: State, config: ConfigSchema):
|
182 |
"""
|
183 |
Proceed to next plan step (either generate KI or retrieve more data)
|
184 |
"""
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
188 |
return "generate_key_issues"
|
189 |
else:
|
190 |
return "detail_step"
|
|
|
191 |
|
192 |
def detail_or_data_retriever(state: State, config: ConfigSchema):
|
193 |
"""
|
194 |
-
|
195 |
"""
|
|
|
196 |
if config["configurable"].get("use_detailed_query"):
|
197 |
-
|
|
|
198 |
else:
|
|
|
199 |
return "data_retriever"
|
200 |
|
201 |
def retrieve_or_process(state: State):
|
202 |
"""
|
203 |
-
Process the retrieved docs or keep retrieving
|
204 |
"""
|
205 |
-
|
|
|
|
|
|
|
206 |
return "process"
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
# if not user_input or user_input.lower() == "n":
|
213 |
-
# return "process"
|
214 |
-
# print("Please answer with 'y' or 'n'.\n")
|
215 |
|
216 |
|
217 |
def build_planner_graph(memory, config):
|
@@ -222,30 +292,36 @@ def build_planner_graph(memory, config):
|
|
222 |
|
223 |
graph_doc_retriever = build_data_retriever_graph(memory)
|
224 |
graph_doc_processor = build_data_processor_graph(memory)
|
225 |
-
|
|
|
|
|
226 |
graph_builder.add_node("validate", validate_node)
|
227 |
-
|
|
|
228 |
graph_builder.add_node("parse", parse_plan)
|
229 |
-
|
230 |
-
graph_builder.add_node("
|
231 |
-
graph_builder.add_node("
|
232 |
-
graph_builder.add_node("
|
|
|
233 |
graph_builder.add_node("concatenate_data", concatenate_data)
|
234 |
-
|
|
|
235 |
graph_builder.add_node("generate_ki", generate_ki)
|
236 |
-
|
|
|
237 |
graph_builder.add_node("detail_ki", detail_ki)
|
238 |
-
|
|
|
239 |
|
|
|
240 |
graph_builder.add_edge("validate", "chatbot_planner")
|
241 |
graph_builder.add_edge("parse", "detail_step")
|
242 |
|
|
|
|
|
243 |
|
244 |
-
# graph_builder.add_edge("detail_step", "chatbot2")
|
245 |
-
graph_builder.add_edge("chatbot_detail", "data_retriever")
|
246 |
graph_builder.add_edge("data_retriever", "human_validation")
|
247 |
-
|
248 |
-
|
249 |
graph_builder.add_edge("data_processor", "concatenate_data")
|
250 |
graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
|
251 |
graph_builder.add_edge("generate_ki", "chatbot_ki")
|
@@ -253,15 +329,18 @@ def build_planner_graph(memory, config):
|
|
253 |
graph_builder.add_edge("detail_ki", "chatbot_final")
|
254 |
graph_builder.add_edge("chatbot_final", "__end__")
|
255 |
|
|
|
256 |
graph_builder.add_conditional_edges(
|
257 |
"detail_step",
|
258 |
-
|
|
|
259 |
{"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"}
|
260 |
)
|
261 |
graph_builder.add_conditional_edges(
|
262 |
"human_validation",
|
263 |
retrieve_or_process,
|
264 |
-
|
|
|
265 |
)
|
266 |
graph_builder.add_conditional_edges(
|
267 |
"chatbot_planner",
|
@@ -270,13 +349,18 @@ def build_planner_graph(memory, config):
|
|
270 |
)
|
271 |
graph_builder.add_conditional_edges(
|
272 |
"chatbot_exec_step",
|
273 |
-
|
|
|
274 |
{"generate_key_issues" : "generate_ki", "detail_step": "detail_step"}
|
275 |
)
|
276 |
|
277 |
-
|
|
|
|
|
|
|
278 |
graph = graph_builder.compile(
|
279 |
checkpointer=memory,
|
280 |
-
|
|
|
281 |
)
|
282 |
return graph
|
|
|
4 |
from typing import Annotated
|
5 |
from typing_extensions import TypedDict
|
6 |
|
7 |
+
# Remove ChatGroq import
|
8 |
+
# from langchain_groq import ChatGroq
|
9 |
+
# Add ChatGoogleGenerativeAI import
|
10 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
11 |
+
import os # Add os import
|
12 |
+
|
13 |
from langchain_openai import ChatOpenAI
|
14 |
from langchain_core.messages import SystemMessage, HumanMessage
|
15 |
from langchain_community.graphs import Neo4jGraph
|
|
|
20 |
from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
|
21 |
from ki_gen.data_retriever import build_data_retriever_graph
|
22 |
from ki_gen.data_processor import build_data_processor_graph
|
23 |
+
# Import get_model which now handles Gemini
|
24 |
+
from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState, get_model
|
25 |
from langgraph.checkpoint.sqlite import SqliteSaver
|
26 |
|
27 |
|
|
|
28 |
##########################################################################
|
29 |
###### NODES DEFINITION ######
|
30 |
##########################################################################
|
|
|
38 |
output = HumanMessage(content=prompt)
|
39 |
return {"messages" : [output]}
|
40 |
|
41 |
+
# Remove Groq-specific error handler
|
42 |
+
# def error_chatbot_groq(error, model_name, query): ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
# Wrappers to call LLMs on the state messsages field
|
45 |
+
# Simplify: Use get_model directly or a single chatbot function
|
46 |
+
def chatbot_node(state: State, config: ConfigSchema):
|
47 |
+
"""Generic chatbot node using the main_llm from config."""
|
48 |
+
model_name = config["configurable"].get("main_llm") or "gemini-2.0-flash"
|
49 |
+
llm = get_model(model_name)
|
50 |
try:
|
51 |
+
# Check if messages exist and are not empty
|
52 |
+
if "messages" in state and state["messages"]:
|
53 |
+
response = llm.invoke(state["messages"])
|
54 |
+
return {"messages": [response]}
|
55 |
+
else:
|
56 |
+
print("Warning: No messages found in state for chatbot_node.")
|
57 |
+
# Return state unchanged or an empty message list?
|
58 |
+
return {} # Or {"messages": []}
|
59 |
+
except Exception as e:
|
60 |
+
print(f"Error invoking model {model_name}: {e}")
|
61 |
+
# Handle error, maybe return an error message or empty dict
|
62 |
+
return {"messages": [SystemMessage(content=f"Error during generation: {e}")]}
|
63 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
# Remove old chatbot functions (chatbot_llama, chatbot_mixtral, chatbot_openai)
|
66 |
+
# Replace the chatbots dictionary with direct calls to the generic function or specific models via get_model
|
67 |
+
# This simplifies planner.py, relying on utils.py and config for model selection.
|
68 |
|
69 |
def parse_plan(state: State):
|
70 |
"""
|
71 |
This node parses the generated plan and writes in the 'store_plan' field of the state
|
72 |
"""
|
73 |
+
# Find the AI message likely containing the plan (often the second to last if validate_node was used)
|
74 |
+
plan_message_content = ""
|
75 |
+
if "messages" in state and len(state["messages"]) >= 1:
|
76 |
+
# Search backwards for the plan, as its position might vary
|
77 |
+
for msg in reversed(state["messages"]):
|
78 |
+
if hasattr(msg, 'content') and "Plan:" in msg.content and "<END_OF_PLAN>" in msg.content:
|
79 |
+
plan_message_content = msg.content
|
80 |
+
break # Found the plan
|
81 |
+
|
82 |
+
if not plan_message_content:
|
83 |
+
print("Error: Could not find plan message in state.")
|
84 |
+
# Handle error: maybe return current state or raise an exception
|
85 |
+
return state # Return unchanged state if plan not found
|
86 |
+
|
87 |
+
store_plan = []
|
88 |
try:
|
89 |
+
# Improved parsing: handle potential variations in formatting
|
90 |
+
plan_section = plan_message_content.split("Plan:")[1].split("<END_OF_PLAN>")[0]
|
91 |
+
# Split by numbered steps, removing empty entries
|
92 |
+
store_plan = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_section) if step.strip()]
|
93 |
except Exception as e:
|
94 |
+
print(f"Error while parsing plan: {e}")
|
95 |
+
# Handle parsing error, potentially keep store_plan empty or log the error
|
96 |
+
store_plan = [] # Reset plan on error
|
97 |
|
98 |
return {"store_plan" : store_plan}
|
99 |
|
100 |
+
# Update get_detailed_query to use get_model and default model
|
101 |
+
def get_detailed_query(context : list, model : str = "gemini-2.0-flash"):
|
102 |
+
"""
|
103 |
+
Simple helper function for the detail_step node
|
104 |
+
"""
|
105 |
+
llm = get_model(model) # Use get_model
|
106 |
+
try:
|
107 |
+
return llm.invoke(context)
|
108 |
+
except Exception as e:
|
109 |
+
print(f"Error in get_detailed_query with model {model}: {e}")
|
110 |
+
# Return a default message or raise error
|
111 |
+
return SystemMessage(content=f"Error generating detailed query: {e}")
|
112 |
+
|
113 |
+
|
114 |
def detail_step(state: State, config: ConfigSchema):
|
115 |
"""
|
116 |
This node updates the value of the 'current_plan_step' field and defines the query to be used for the data_retriever.
|
117 |
"""
|
118 |
+
print("Entering detail_step") # Debug print
|
119 |
+
print(f"Current state keys: {state.keys()}") # Debug print
|
120 |
+
|
121 |
+
# Initialize current_plan_step if not present
|
122 |
+
current_plan_step = state.get("current_plan_step", -1) + 1
|
123 |
+
|
124 |
+
# Ensure store_plan exists and has enough steps
|
125 |
+
store_plan = state.get("store_plan", [])
|
126 |
+
if not store_plan or current_plan_step >= len(store_plan):
|
127 |
+
print(f"Warning: Plan step {current_plan_step} out of bounds or plan is empty.")
|
128 |
+
# Decide how to handle: end graph, return error state?
|
129 |
+
# For now, let's prevent index error and maybe signal an issue
|
130 |
+
# Returning an empty query might halt progress or cause issues downstream
|
131 |
+
return {"current_plan_step": current_plan_step, 'query' : "Error: Plan step unavailable.", "valid_docs" : []}
|
132 |
|
|
|
|
|
|
|
|
|
133 |
|
134 |
+
plan_step_description = store_plan[current_plan_step]
|
135 |
+
|
136 |
if config["configurable"].get("use_detailed_query"):
|
137 |
prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
|
138 |
+
Step {current_plan_step + 1} : {plan_step_description}""")
|
139 |
+
# Ensure messages exist before appending
|
140 |
+
current_messages = state.get("messages", [])
|
141 |
+
query_message = get_detailed_query(context = current_messages + [prompt], model=config["configurable"].get("main_llm", "gemini-2.0-flash"))
|
142 |
+
query_content = query_message.content if hasattr(query_message, 'content') else "Error: Could not get detailed query content."
|
143 |
+
return {"messages" : [prompt, query_message], "current_plan_step": current_plan_step, 'query' : query_content, "valid_docs": state.get("valid_docs", [])} # Ensure valid_docs is preserved
|
144 |
+
|
145 |
+
# If not using detailed query, use the plan step description directly
|
146 |
+
return {"current_plan_step": current_plan_step, 'query' : plan_step_description, "valid_docs" : state.get("valid_docs", [])} # Ensure valid_docs is preserved
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
def concatenate_data(state: State):
|
150 |
"""
|
151 |
This node concatenates all the data that was processed by the data_processor and inserts it in the state's messages
|
152 |
"""
|
153 |
+
# Ensure valid_docs exists and current_plan_step is valid
|
154 |
+
valid_docs_content = state.get("valid_docs", "No processed documents available.")
|
155 |
+
current_plan_step = state.get("current_plan_step", -1)
|
156 |
+
store_plan = state.get("store_plan", [])
|
157 |
+
|
158 |
+
if current_plan_step < 0 or current_plan_step >= len(store_plan):
|
159 |
+
print(f"Warning: Invalid current_plan_step ({current_plan_step}) in concatenate_data.")
|
160 |
+
# Handle error - maybe return an error message
|
161 |
+
step_description = "Error: Current plan step invalid."
|
162 |
+
else:
|
163 |
+
step_description = store_plan[current_plan_step]
|
164 |
+
|
165 |
+
|
166 |
prompt = f"""#########TECHNICAL INFORMATION ############
|
167 |
+
{str(valid_docs_content)}
|
168 |
|
169 |
########END OF TECHNICAL INFORMATION#######
|
170 |
|
171 |
+
Using the information provided above, proceed with step {current_plan_step + 1} of your plan :
|
172 |
+
{step_description}
|
173 |
"""
|
174 |
|
175 |
return {"messages": [HumanMessage(content=prompt)]}
|
|
|
177 |
|
178 |
def human_validation(state: HumanValidationState) -> HumanValidationState:
|
179 |
"""
|
180 |
+
Dummy node to interrupt before processing, can be used for manual validation later.
|
181 |
"""
|
182 |
+
# Defaulting to no processing steps needed unless specified elsewhere
|
183 |
+
return {'process_steps' : state.get('process_steps', [])}
|
184 |
+
|
185 |
|
186 |
def generate_ki(state: State):
|
187 |
"""
|
188 |
This node inserts the prompt to begin Key Issues generation
|
189 |
"""
|
190 |
+
print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state.get('current_plan_step')}")
|
191 |
+
|
192 |
+
current_plan_step = state.get("current_plan_step", -1)
|
193 |
+
store_plan = state.get("store_plan", [])
|
194 |
+
|
195 |
+
# Check if the next step exists in the plan
|
196 |
+
next_step_index = current_plan_step + 1
|
197 |
+
if next_step_index < 0 or next_step_index >= len(store_plan):
|
198 |
+
print(f"Warning: Invalid next plan step ({next_step_index}) for KI generation.")
|
199 |
+
step_description = "Error: Plan step for KI generation unavailable."
|
200 |
+
else:
|
201 |
+
step_description = store_plan[next_step_index]
|
202 |
|
203 |
+
|
204 |
+
prompt = f"""Using the information provided above, proceed with step {next_step_index + 1} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
|
205 |
+
{step_description}"""
|
206 |
|
207 |
return {"messages" : [HumanMessage(content=prompt)]}
|
208 |
|
|
|
210 |
"""
|
211 |
This node inserts the last prompt to detail the generated Key Issues
|
212 |
"""
|
213 |
+
current_plan_step = state.get("current_plan_step", -1)
|
214 |
+
store_plan = state.get("store_plan", [])
|
215 |
+
|
216 |
+
# Check if the step after next exists in the plan
|
217 |
+
detail_step_index = current_plan_step + 2
|
218 |
+
if detail_step_index < 0 or detail_step_index >= len(store_plan):
|
219 |
+
print(f"Warning: Invalid plan step ({detail_step_index}) for KI detailing.")
|
220 |
+
step_description = "Error: Plan step for KI detailing unavailable."
|
221 |
+
else:
|
222 |
+
step_description = store_plan[detail_step_index]
|
223 |
+
|
224 |
+
prompt = f"""Using the information provided above, proceed with step {detail_step_index + 1} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
|
225 |
+
{step_description}"""
|
226 |
|
227 |
return {"messages" : [HumanMessage(content=prompt)]}
|
228 |
|
|
|
234 |
"""
|
235 |
Whether to regenerate the plan or to parse it
|
236 |
"""
|
237 |
+
# Check the last message for "My plan is correct"
|
238 |
+
if "messages" in state and state["messages"]:
|
239 |
+
last_message = state["messages"][-1]
|
240 |
+
if hasattr(last_message, 'content') and "My plan is correct" in last_message.content:
|
241 |
+
return "parse"
|
242 |
+
# Default to validate (regenerate) if condition not met or messages are missing
|
243 |
return "validate"
|
244 |
|
245 |
def next_plan_step(state: State, config: ConfigSchema):
|
246 |
"""
|
247 |
Proceed to next plan step (either generate KI or retrieve more data)
|
248 |
"""
|
249 |
+
current_plan_step = state.get("current_plan_step", -1)
|
250 |
+
store_plan_len = len(state.get("store_plan", []))
|
251 |
+
|
252 |
+
# Simplified logic: go to KI generation if it's the last step based on plan length
|
253 |
+
if current_plan_step >= store_plan_len - 1:
|
254 |
return "generate_key_issues"
|
255 |
else:
|
256 |
return "detail_step"
|
257 |
+
|
258 |
|
259 |
def detail_or_data_retriever(state: State, config: ConfigSchema):
|
260 |
"""
|
261 |
+
Decide whether to detail the query or go straight to data retrieval.
|
262 |
"""
|
263 |
+
# Check configuration if detailed query is needed
|
264 |
if config["configurable"].get("use_detailed_query"):
|
265 |
+
# Need to invoke the LLM to get the detailed query
|
266 |
+
return "chatbot_detail"
|
267 |
else:
|
268 |
+
# Use the plan step directly as the query
|
269 |
return "data_retriever"
|
270 |
|
271 |
def retrieve_or_process(state: State):
|
272 |
"""
|
273 |
+
Process the retrieved docs or keep retrieving (based on human_validated flag).
|
274 |
"""
|
275 |
+
# Check the 'human_validated' flag in the state
|
276 |
+
# This flag needs to be set externally (e.g., by Streamlit UI or another mechanism)
|
277 |
+
# before this node is reached after data_retriever.
|
278 |
+
if state.get('human_validated'):
|
279 |
return "process"
|
280 |
+
else:
|
281 |
+
# If not validated, loop back to retrieve more (or wait for validation)
|
282 |
+
# This assumes data_retriever might be called again or the graph waits.
|
283 |
+
# In the Streamlit app, the human_validation node allows setting this flag.
|
284 |
+
return "retrieve"
|
|
|
|
|
|
|
285 |
|
286 |
|
287 |
def build_planner_graph(memory, config):
|
|
|
292 |
|
293 |
graph_doc_retriever = build_data_retriever_graph(memory)
|
294 |
graph_doc_processor = build_data_processor_graph(memory)
|
295 |
+
|
296 |
+
# Use the generic chatbot node function
|
297 |
+
graph_builder.add_node("chatbot_planner", lambda state: chatbot_node(state, config))
|
298 |
graph_builder.add_node("validate", validate_node)
|
299 |
+
# Add node for chatbot interaction when detailed query is needed
|
300 |
+
graph_builder.add_node("chatbot_detail", lambda state: chatbot_node(state, config))
|
301 |
graph_builder.add_node("parse", parse_plan)
|
302 |
+
# Pass config to detail_step as it needs it now
|
303 |
+
graph_builder.add_node("detail_step", lambda state: detail_step(state, config))
|
304 |
+
graph_builder.add_node("data_retriever", graph_doc_retriever) # Input mapping happens automatically if state keys match
|
305 |
+
graph_builder.add_node("human_validation", human_validation) # Needs input mapping if HumanValidationState differs significantly
|
306 |
+
graph_builder.add_node("data_processor", graph_doc_processor) # Needs input mapping if DocProcessorState differs significantly
|
307 |
graph_builder.add_node("concatenate_data", concatenate_data)
|
308 |
+
# Use the generic chatbot node function
|
309 |
+
graph_builder.add_node("chatbot_exec_step", lambda state: chatbot_node(state, config))
|
310 |
graph_builder.add_node("generate_ki", generate_ki)
|
311 |
+
# Use the generic chatbot node function
|
312 |
+
graph_builder.add_node("chatbot_ki", lambda state: chatbot_node(state, config))
|
313 |
graph_builder.add_node("detail_ki", detail_ki)
|
314 |
+
# Use the generic chatbot node function
|
315 |
+
graph_builder.add_node("chatbot_final", lambda state: chatbot_node(state, config))
|
316 |
|
317 |
+
# Define edges
|
318 |
graph_builder.add_edge("validate", "chatbot_planner")
|
319 |
graph_builder.add_edge("parse", "detail_step")
|
320 |
|
321 |
+
# Edge from chatbot_detail (after getting detailed query) to data_retriever
|
322 |
+
graph_builder.add_edge("chatbot_detail", "data_retriever")
|
323 |
|
|
|
|
|
324 |
graph_builder.add_edge("data_retriever", "human_validation")
|
|
|
|
|
325 |
graph_builder.add_edge("data_processor", "concatenate_data")
|
326 |
graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
|
327 |
graph_builder.add_edge("generate_ki", "chatbot_ki")
|
|
|
329 |
graph_builder.add_edge("detail_ki", "chatbot_final")
|
330 |
graph_builder.add_edge("chatbot_final", "__end__")
|
331 |
|
332 |
+
# Define conditional edges
|
333 |
graph_builder.add_conditional_edges(
|
334 |
"detail_step",
|
335 |
+
# Pass config to the conditional function
|
336 |
+
lambda state: detail_or_data_retriever(state, config),
|
337 |
{"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"}
|
338 |
)
|
339 |
graph_builder.add_conditional_edges(
|
340 |
"human_validation",
|
341 |
retrieve_or_process,
|
342 |
+
# Map 'retrieve' back to 'data_retriever' node, 'process' to 'data_processor'
|
343 |
+
{"retrieve" : "data_retriever", "process" : "data_processor"}
|
344 |
)
|
345 |
graph_builder.add_conditional_edges(
|
346 |
"chatbot_planner",
|
|
|
349 |
)
|
350 |
graph_builder.add_conditional_edges(
|
351 |
"chatbot_exec_step",
|
352 |
+
# Pass config to the conditional function
|
353 |
+
lambda state: next_plan_step(state, config),
|
354 |
{"generate_key_issues" : "generate_ki", "detail_step": "detail_step"}
|
355 |
)
|
356 |
|
357 |
+
# Set entry point
|
358 |
+
graph_builder.set_entry_point("chatbot_planner")
|
359 |
+
|
360 |
+
# Compile the graph
|
361 |
graph = graph_builder.compile(
|
362 |
checkpointer=memory,
|
363 |
+
# Define interrupt points if needed for human interaction or debugging
|
364 |
+
interrupt_after=["human_validation", "chatbot_final"],
|
365 |
)
|
366 |
return graph
|