File size: 4,076 Bytes
b81ac13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import logging
from typing import Annotated, Optional, TypedDict

from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langchain_core.messages import AIMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from components.stage_mapping import get_stage_list, get_next_stage
from llm_utils import call_llm_api, is_stage_complete

# Define the agent state
class AgentState(TypedDict):
    """Graph state shared by every stage node."""

    # Conversation history; node updates are merged by the ``add_messages`` reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    # Stage currently in progress; set to None once the last stage is completed.
    current_stage: Optional[str]
    # Names of completed stages, appended in the order they were completed.
    completed_stages: list[str]

# Stage names from components.stage_mapping. One graph node is registered per
# entry, and the first entry is used as the graph's entry point.
stage_list = get_stage_list()

# Maps LangChain message ``type`` values to the chat roles sent to the LLM API.
_ROLE_BY_TYPE = {"system": "system", "human": "user", "ai": "assistant"}


def _to_llm_messages(messages):
    """Convert LangChain messages into ``{"role": ..., "content": ...}`` dicts.

    Only system, human and AI messages are kept. Any other message type, or an
    object without a ``type`` attribute, is skipped.
    """
    converted = []
    for msg in messages:
        role = _ROLE_BY_TYPE.get(getattr(msg, "type", None))
        if role is not None:
            converted.append({"role": role, "content": msg.content})
    return converted


def _stage_context_prompt(current_stage, completed_stages):
    """Build the stage-management system prompt prepended to every LLM call."""
    completed = (
        ", ".join(str(stage) for stage in completed_stages)
        if completed_stages
        else "None"
    )
    return (
        "[Stage Management]\n"
        f"Current stage: {current_stage}\n"
        f"Completed stages: {completed}\n"
        "You must always check if the current stage is complete. You must look at evidence in <self-notes> to determine if you have enough logical information and reasoning to conclude the stage is complete. "
        "If it is, clearly state that the stage is complete and suggest moving to the next stage. "
        "If not, ask clarifying questions or provide guidance for the current stage. "
        "Never forget to consider the current stage and completed stages in your reasoning."
    )


def make_stage_node(stage_name):
    """Return a LangGraph node function for one stage of the conversation.

    The returned ``stage_node(state)`` acts only when the last entry in
    ``state["messages"]`` is a human message. It then:

    1. prepends a stage-management system prompt to the converted history;
    2. calls ``call_llm_api``;
    3. appends the reply as an ``AIMessage``.

    If ``is_stage_complete(reply)`` is truthy and a stage is in progress, that
    stage is appended to ``completed_stages``. ``current_stage`` then moves to
    ``get_next_stage(...)``, or to ``None`` when there is no next stage.

    When there are no messages, or the last one is not a human message, the
    state is returned unchanged. So once one node has replied, later nodes in
    the same run pass through.

    NOTE(review): ``stage_name`` is not read. Every node acts on
    ``state["current_stage"]`` rather than on the stage it is registered as.

    Args:
        stage_name: Name of the stage this node is registered under.

    Returns:
        The node callable, mapping an ``AgentState`` to a state update dict.
    """
    def stage_node(state: AgentState):
        history = state["messages"]
        last_type = getattr(history[-1], "type", None) if history else None
        if last_type != "human":
            # Nothing new from the user: wait for input.
            return state

        current_stage = state["current_stage"]
        completed_stages = list(state["completed_stages"])

        llm_messages = [
            {"role": "system", "content": _stage_context_prompt(current_stage, completed_stages)}
        ] + _to_llm_messages(history)
        assistant_reply = call_llm_api(llm_messages)

        # Only advance when there is a stage to complete. After the final stage,
        # current_stage is None; appending it would make the next prompt's
        # ", ".join fail.
        if current_stage is not None and is_stage_complete(assistant_reply):
            completed_stages.append(current_stage)
            current_stage = get_next_stage(current_stage) or None

        return {
            "messages": history + [AIMessage(content=assistant_reply)],
            "current_stage": current_stage,
            "completed_stages": completed_stages,
        }
    return stage_node

# Build the graph: one node per stage, wired in the order given by stage_list.
if not stage_list:
    # Fail with a clear message instead of an IndexError on stage_list[0] below.
    raise ValueError("get_stage_list() returned no stages; cannot build the stage graph")

builder = StateGraph(AgentState)

# Add a node for each stage
for stage in stage_list:
    builder.add_node(stage, make_stage_node(stage))

# Sequential progression only: each stage flows into get_next_stage(stage).
# A stage without a successor is wired to END so termination is explicit.
## Modal and Nebius do not support conditional tool edges yet, so no
## tools_condition / ToolNode edges are added here.
builder.add_edge(START, stage_list[0])
for stage in stage_list:
    next_stage = get_next_stage(stage)
    builder.add_edge(stage, next_stage if next_stage else END)

# Compile the graph
stage_graph = builder.compile()

# Render a diagram of the graph as a debugging aid. Rendering can fail (by
# default draw_mermaid_png uses a remote Mermaid rendering API), so log and
# continue instead of making the module fail to import. Render before opening
# the file so a failure does not leave an empty PNG behind.
try:
    graph_png = stage_graph.get_graph().draw_mermaid_png()
except Exception:
    logging.getLogger(__name__).warning(
        "Could not render the stage graph to graph_output.png", exc_info=True
    )
else:
    with open("graph_output.png", "wb") as f:
        f.write(graph_png)