import os
import re
from typing import Annotated

from typing_extensions import TypedDict

# Remove ChatGroq import
# from langchain_groq import ChatGroq
# Add ChatGoogleGenerativeAI import
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_community.graphs import Neo4jGraph

from langgraph.graph import StateGraph
from langgraph.graph import add_messages
from langgraph.checkpoint.sqlite import SqliteSaver

from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT
from ki_gen.data_retriever import build_data_retriever_graph
from ki_gen.data_processor import build_data_processor_graph
# Import get_model which now handles Gemini
from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState, get_model


##########################################################################
######                    NODES DEFINITION                          ######
##########################################################################
def validate_node(state: State):
    """
    This node inserts the plan validation prompt.
    """
    prompt = """System : You only need to focus on Key Issues, no need to focus on solutions or stakeholders yet and your plan should be concise.
If needed, give me an updated plan to follow this instruction. If your plan already follows the instruction just say "My plan is correct"."""
    output = HumanMessage(content=prompt)
    return {"messages": [output]}
# Remove Groq-specific error handler
# def error_chatbot_groq(error, model_name, query): ...

# Wrapper to call the LLM on the state's messages field.
# Simplified: use a single generic chatbot function backed by get_model.
def chatbot_node(state: State, config: ConfigSchema):
    """Generic chatbot node using the main_llm from config."""
    model_name = config["configurable"].get("main_llm") or "gemini-2.0-flash"
    llm = get_model(model_name)
    try:
        # Check that messages exist and are not empty
        if "messages" in state and state["messages"]:
            response = llm.invoke(state["messages"])
            return {"messages": [response]}
        else:
            print("Warning: No messages found in state for chatbot_node.")
            # Return the state unchanged (or an empty message list)
            return {}  # Or {"messages": []}
    except Exception as e:
        print(f"Error invoking model {model_name}: {e}")
        # On failure, surface the error as a message rather than crashing the graph
        return {"messages": [SystemMessage(content=f"Error during generation: {e}")]}

# The old per-model chatbot functions (chatbot_llama, chatbot_mixtral, chatbot_openai)
# and the chatbots dictionary are replaced by direct calls to this generic function,
# which simplifies planner.py: model selection now relies on utils.get_model and the config.
def parse_plan(state: State):
    """
    This node parses the generated plan and writes it to the 'store_plan' field of the state.
    """
    # Find the AI message likely containing the plan (often the second to last if validate_node was used)
    plan_message_content = ""
    if "messages" in state and len(state["messages"]) >= 1:
        # Search backwards for the plan, as its position might vary
        for msg in reversed(state["messages"]):
            if hasattr(msg, 'content') and "Plan:" in msg.content and "<END_OF_PLAN>" in msg.content:
                plan_message_content = msg.content
                break  # Found the plan

    if not plan_message_content:
        print("Error: Could not find plan message in state.")
        # Handle error: return the current state unchanged (or raise an exception)
        return state

    store_plan = []
    try:
        # Improved parsing: handle potential variations in formatting
        plan_section = plan_message_content.split("Plan:")[1].split("<END_OF_PLAN>")[0]
        # Split by numbered steps, removing empty entries
        store_plan = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_section) if step.strip()]
    except Exception as e:
        print(f"Error while parsing plan: {e}")
        # On a parsing error, keep store_plan empty
        store_plan = []
    return {"store_plan": store_plan}
# Update get_detailed_query to use get_model and the default model
def get_detailed_query(context: list, model: str = "gemini-2.0-flash"):
    """
    Simple helper function for the detail_step node.
    """
    llm = get_model(model)  # Use get_model
    try:
        return llm.invoke(context)
    except Exception as e:
        print(f"Error in get_detailed_query with model {model}: {e}")
        # Return a default error message instead of raising
        return SystemMessage(content=f"Error generating detailed query: {e}")
def detail_step(state: State, config: ConfigSchema):
    """
    This node updates the value of the 'current_plan_step' field and defines the query to be used by the data_retriever.
    """
    print("Entering detail_step")  # Debug print
    print(f"Current state keys: {state.keys()}")  # Debug print

    # Initialize current_plan_step if not present
    current_plan_step = state.get("current_plan_step", -1) + 1

    # Ensure store_plan exists and has enough steps
    store_plan = state.get("store_plan", [])
    if not store_plan or current_plan_step >= len(store_plan):
        print(f"Warning: Plan step {current_plan_step} out of bounds or plan is empty.")
        # Decide how to handle this: end the graph, or return an error state?
        # For now, prevent the index error and signal the issue; returning an
        # error query may halt progress or cause issues downstream.
        return {"current_plan_step": current_plan_step, 'query': "Error: Plan step unavailable.", "valid_docs": []}

    plan_step_description = store_plan[current_plan_step]

    if config["configurable"].get("use_detailed_query"):
        prompt = HumanMessage(f"""Specify what additional information you need to proceed with the next step of your plan :
Step {current_plan_step + 1} : {plan_step_description}""")
        # Ensure messages exist before appending
        current_messages = state.get("messages", [])
        query_message = get_detailed_query(context=current_messages + [prompt], model=config["configurable"].get("main_llm", "gemini-2.0-flash"))
        query_content = query_message.content if hasattr(query_message, 'content') else "Error: Could not get detailed query content."
        return {"messages": [prompt, query_message], "current_plan_step": current_plan_step, 'query': query_content, "valid_docs": state.get("valid_docs", [])}  # Preserve valid_docs

    # If not using a detailed query, use the plan step description directly
    return {"current_plan_step": current_plan_step, 'query': plan_step_description, "valid_docs": state.get("valid_docs", [])}  # Preserve valid_docs
def concatenate_data(state: State):
    """
    This node concatenates all the data that was processed by the data_processor and inserts it into the state's messages.
    """
    # Ensure valid_docs exists and current_plan_step is valid
    valid_docs_content = state.get("valid_docs", "No processed documents available.")
    current_plan_step = state.get("current_plan_step", -1)
    store_plan = state.get("store_plan", [])

    if current_plan_step < 0 or current_plan_step >= len(store_plan):
        print(f"Warning: Invalid current_plan_step ({current_plan_step}) in concatenate_data.")
        # Handle the error with an explicit error description
        step_description = "Error: Current plan step invalid."
    else:
        step_description = store_plan[current_plan_step]

    prompt = f"""#########TECHNICAL INFORMATION ############
{str(valid_docs_content)}
########END OF TECHNICAL INFORMATION#######

Using the information provided above, proceed with step {current_plan_step + 1} of your plan :
{step_description}
"""
    return {"messages": [HumanMessage(content=prompt)]}
def human_validation(state: HumanValidationState) -> HumanValidationState:
    """
    Dummy node to interrupt before processing; can be used for manual validation later.
    """
    # Default to no processing steps unless specified elsewhere
    return {'process_steps': state.get('process_steps', [])}
def generate_ki(state: State):
    """
    This node inserts the prompt to begin Key Issues generation.
    """
    print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state.get('current_plan_step')}")
    current_plan_step = state.get("current_plan_step", -1)
    store_plan = state.get("store_plan", [])

    # Check that the next step exists in the plan
    next_step_index = current_plan_step + 1
    if next_step_index < 0 or next_step_index >= len(store_plan):
        print(f"Warning: Invalid next plan step ({next_step_index}) for KI generation.")
        step_description = "Error: Plan step for KI generation unavailable."
    else:
        step_description = store_plan[next_step_index]

    prompt = f"""Using the information provided above, proceed with step {next_step_index + 1} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
{step_description}"""
    return {"messages": [HumanMessage(content=prompt)]}
def detail_ki(state: State):
    """
    This node inserts the last prompt to detail the generated Key Issues.
    """
    current_plan_step = state.get("current_plan_step", -1)
    store_plan = state.get("store_plan", [])

    # Check that the step after next exists in the plan
    detail_step_index = current_plan_step + 2
    if detail_step_index < 0 or detail_step_index >= len(store_plan):
        print(f"Warning: Invalid plan step ({detail_step_index}) for KI detailing.")
        step_description = "Error: Plan step for KI detailing unavailable."
    else:
        step_description = store_plan[detail_step_index]

    prompt = f"""Using the information provided above, proceed with step {detail_step_index + 1} of your plan to provide the user with NEW and INNOVATIVE Key Issues :
{step_description}"""
    return {"messages": [HumanMessage(content=prompt)]}
##########################################################################
######                CONDITIONAL EDGE FUNCTIONS                    ######
##########################################################################
def validate_plan(state: State):
    """
    Decide whether to regenerate the plan or to parse it.
    """
    # Check the last message for "My plan is correct"
    if "messages" in state and state["messages"]:
        last_message = state["messages"][-1]
        if hasattr(last_message, 'content') and "My plan is correct" in last_message.content:
            return "parse"
    # Default to validate (regenerate) if the condition is not met or messages are missing
    return "validate"
def next_plan_step(state: State, config: ConfigSchema):
    """
    Proceed to the next plan step (either generate KI or retrieve more data).
    """
    current_plan_step = state.get("current_plan_step", -1)
    store_plan_len = len(state.get("store_plan", []))
    # Simplified logic: go to KI generation if this is the last step based on plan length
    if current_plan_step >= store_plan_len - 1:
        return "generate_key_issues"
    else:
        return "detail_step"
def detail_or_data_retriever(state: State, config: ConfigSchema):
    """
    Decide whether to detail the query or go straight to data retrieval.
    """
    # Check the configuration to see if a detailed query is needed
    if config["configurable"].get("use_detailed_query"):
        # Invoke the LLM to get the detailed query
        return "chatbot_detail"
    else:
        # Use the plan step directly as the query
        return "data_retriever"
def retrieve_or_process(state: State):
    """
    Process the retrieved docs or keep retrieving, based on the 'human_validated' flag.
    """
    # The 'human_validated' flag needs to be set externally (e.g. by the Streamlit UI
    # or another mechanism) before this node is reached after data_retriever.
    if state.get('human_validated'):
        return "process"
    else:
        # If not validated, loop back to retrieve more (or wait for validation).
        # This assumes data_retriever may be called again or that the graph waits;
        # in the Streamlit app, the human_validation node allows setting this flag.
        return "retrieve"
def build_planner_graph(memory, config):
    """
    Builds the planner graph.
    """
    graph_builder = StateGraph(State)

    graph_doc_retriever = build_data_retriever_graph(memory)
    graph_doc_processor = build_data_processor_graph(memory)

    # Use the generic chatbot node function for every LLM call
    graph_builder.add_node("chatbot_planner", lambda state: chatbot_node(state, config))
    graph_builder.add_node("validate", validate_node)
    # Chatbot interaction when a detailed query is needed
    graph_builder.add_node("chatbot_detail", lambda state: chatbot_node(state, config))
    graph_builder.add_node("parse", parse_plan)
    # Pass config to detail_step as it needs it now
    graph_builder.add_node("detail_step", lambda state: detail_step(state, config))
    graph_builder.add_node("data_retriever", graph_doc_retriever)  # Input mapping happens automatically if state keys match
    graph_builder.add_node("human_validation", human_validation)  # Needs input mapping if HumanValidationState differs significantly
    graph_builder.add_node("data_processor", graph_doc_processor)  # Needs input mapping if DocProcessorState differs significantly
    graph_builder.add_node("concatenate_data", concatenate_data)
    graph_builder.add_node("chatbot_exec_step", lambda state: chatbot_node(state, config))
    graph_builder.add_node("generate_ki", generate_ki)
    graph_builder.add_node("chatbot_ki", lambda state: chatbot_node(state, config))
    graph_builder.add_node("detail_ki", detail_ki)
    graph_builder.add_node("chatbot_final", lambda state: chatbot_node(state, config))

    # Define edges
    graph_builder.add_edge("validate", "chatbot_planner")
    graph_builder.add_edge("parse", "detail_step")
    # Edge from chatbot_detail (after getting the detailed query) to data_retriever
    graph_builder.add_edge("chatbot_detail", "data_retriever")
    graph_builder.add_edge("data_retriever", "human_validation")
    graph_builder.add_edge("data_processor", "concatenate_data")
    graph_builder.add_edge("concatenate_data", "chatbot_exec_step")
    graph_builder.add_edge("generate_ki", "chatbot_ki")
    graph_builder.add_edge("chatbot_ki", "detail_ki")
    graph_builder.add_edge("detail_ki", "chatbot_final")
    graph_builder.add_edge("chatbot_final", "__end__")

    # Define conditional edges
    graph_builder.add_conditional_edges(
        "detail_step",
        # Pass config to the conditional function
        lambda state: detail_or_data_retriever(state, config),
        {"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"},
    )
    graph_builder.add_conditional_edges(
        "human_validation",
        retrieve_or_process,
        # Map 'retrieve' back to the data_retriever node, 'process' to data_processor
        {"retrieve": "data_retriever", "process": "data_processor"},
    )
    graph_builder.add_conditional_edges(
        "chatbot_planner",
        validate_plan,
        {"parse": "parse", "validate": "validate"},
    )
    graph_builder.add_conditional_edges(
        "chatbot_exec_step",
        # Pass config to the conditional function
        lambda state: next_plan_step(state, config),
        {"generate_key_issues": "generate_ki", "detail_step": "detail_step"},
    )

    # Set the entry point
    graph_builder.set_entry_point("chatbot_planner")

    # Compile the graph with interrupt points for human interaction and debugging
    graph = graph_builder.compile(
        checkpointer=memory,
        interrupt_after=["human_validation", "chatbot_final"],
    )
    return graph
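

# A minimal usage sketch, not part of the original module: build the planner graph with
# a SqliteSaver checkpointer and stream an initial request. The config keys main_llm and
# use_detailed_query come from this file; the model name, thread_id, and the request text
# are illustrative assumptions, and the run will pause at the configured interrupt points.
if __name__ == "__main__":
    import sqlite3

    # In-memory checkpointer for the demo; a file path could be used instead.
    memory = SqliteSaver(sqlite3.connect(":memory:", check_same_thread=False))
    config = {
        "configurable": {
            "thread_id": "demo",
            "main_llm": "gemini-2.0-flash",
            "use_detailed_query": False,
        }
    }
    graph = build_planner_graph(memory, config)

    initial_state = {"messages": [HumanMessage(content="Generate Key Issues for 5G network slicing.")]}
    for event in graph.stream(initial_state, config):
        print(event)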