File size: 4,076 Bytes
b81ac13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import logging
from typing import Annotated, Optional, TypedDict

from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langchain_core.messages import AIMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from components.stage_mapping import get_stage_list, get_next_stage
from llm_utils import call_llm_api, is_stage_complete
# Define the agent state
class AgentState(TypedDict):
    """LangGraph state shared by every stage node.

    ``messages`` is annotated with the ``add_messages`` reducer, so message
    updates returned by nodes are merged into the existing history instead of
    replacing it.
    """

    # Conversation history (system / human / AI messages).
    messages: Annotated[list[AnyMessage], add_messages]
    # Stage the conversation is currently in. The stage nodes set this to None
    # once get_next_stage() reports that no further stage exists, hence Optional.
    current_stage: Optional[str]
    # Names of the stages already marked complete, in completion order.
    completed_stages: list[str]
# Stage names from the project's stage mapping, fetched once at import time.
# The graph below creates one node per entry and uses the first entry as the
# entry point (presumably the list is in progression order — confirm in
# components.stage_mapping).
stage_list = get_stage_list()
# Map a LangChain message ``type`` to the chat role expected by call_llm_api.
# Messages of any other type (e.g. tool messages) are not forwarded.
_ROLE_BY_MESSAGE_TYPE = {"system": "system", "human": "user", "ai": "assistant"}


def _to_llm_messages(messages):
    """Convert LangChain messages into ``{"role": ..., "content": ...}`` dicts."""
    converted = []
    for msg in messages:
        role = _ROLE_BY_MESSAGE_TYPE.get(getattr(msg, "type", None))
        if role is not None:
            converted.append({"role": role, "content": msg.content})
    return converted


def _build_stage_prompt(current_stage, completed_stages):
    """Build the system prompt describing the current and completed stages."""
    # str() guards against non-string entries (e.g. a stray None persisted by
    # an older run) so the join can never raise TypeError.
    completed = (
        ", ".join(str(stage) for stage in completed_stages)
        if completed_stages
        else "None"
    )
    return (
        "[Stage Management]\n"
        f"Current stage: {current_stage}\n"
        f"Completed stages: {completed}\n"
        "You must always check if the current stage is complete. You must look at evidence in <self-notes> to determine if you have enough logical information and reasoning to conclude the stage is complete. "
        "If it is, clearly state that the stage is complete and suggest moving to the next stage. "
        "If not, ask clarifying questions or provide guidance for the current stage. "
        "Never forget to consider the current stage and completed stages in your reasoning."
    )


def make_stage_node(stage_name):
    """Return the LangGraph node function registered under ``stage_name``.

    The returned node does not consult ``stage_name``; it reads the active
    stage from ``state["current_stage"]``. When the newest message is from the
    user, it sends the history (prefixed with a stage-management system
    prompt) to ``call_llm_api`` and adds the reply as an ``AIMessage``. If
    ``is_stage_complete`` accepts the reply and a stage is active, that stage
    is appended to ``completed_stages`` and ``current_stage`` advances to
    ``get_next_stage``'s result, or to ``None`` when there is no next stage.
    Otherwise the node leaves the state unchanged.
    """

    def stage_node(state: AgentState):
        history = state["messages"]
        # Only call the LLM when the user spoke last. A node reached after the
        # AI reply was already added in this run is a pass-through. Returning
        # the unchanged state is a no-op update: add_messages merges messages
        # by id and the other fields are rewritten with their current values.
        if not history or getattr(history[-1], "type", None) != "human":
            return state

        current_stage = state["current_stage"]
        completed_stages = list(state["completed_stages"])

        llm_messages = [
            {
                "role": "system",
                "content": _build_stage_prompt(current_stage, completed_stages),
            },
            *_to_llm_messages(history),
        ]
        assistant_reply = call_llm_api(llm_messages)

        # Once every stage is done, current_stage is None. Advancing again
        # would append None to completed_stages and break the prompt's join.
        if current_stage is not None and is_stage_complete(assistant_reply):
            completed_stages.append(current_stage)
            next_stage = get_next_stage(current_stage)
            current_stage = next_stage if next_stage else None

        return {
            # Only the new message: the add_messages reducer appends it to the
            # existing history.
            "messages": [AIMessage(content=assistant_reply)],
            "current_stage": current_stage,
            "completed_stages": completed_stages,
        }

    return stage_node
# Build the graph: one node per stage, wired in the order of stage_list.
if not stage_list:
    raise ValueError("get_stage_list() returned no stages; cannot build the stage graph")

builder = StateGraph(AgentState)

# Add a node for each stage.
for stage in stage_list:
    builder.add_node(stage, make_stage_node(stage))

# Wire the stages sequentially: START -> first stage -> ... -> last stage -> END.
# Only plain edges are used; there are no conditional tool edges, because
# Modal and Nebius do not support conditional tool edges yet.
builder.add_edge(START, stage_list[0])
for stage in stage_list:
    next_stage = get_next_stage(stage)
    builder.add_edge(stage, next_stage if next_stage else END)

# Compile the graph.
stage_graph = builder.compile()

# Best-effort visualization. Render before opening the file, so a failed render
# (by default it calls a remote Mermaid service) neither truncates an existing
# graph_output.png nor breaks importing this module.
try:
    _graph_png = stage_graph.get_graph().draw_mermaid_png()
except Exception:
    logging.getLogger(__name__).warning(
        "Could not render the stage graph; graph_output.png was not written",
        exc_info=True,
    )
else:
    with open("graph_output.png", "wb") as png_file:
        png_file.write(_graph_png)
|