from typing import Annotated, Optional, TypedDict

from langchain_core.messages import AIMessage, AnyMessage
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages

from components.stage_mapping import get_next_stage, get_stage_list
from llm_utils import call_llm_api, is_stage_complete
|
|
|
|
|
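# Shared graph state: the running chat history plus stage-tracking fields.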
class AgentState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    # current_stage becomes None once every stage has been completed.
    current_stage: Optional[str]
    completed_stages: list[str]
|
|
|
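# Ordered list of stage names; the graph gets one node and one edge per stage.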
stage_list = get_stage_list()
|
|
|
def make_stage_node(stage_name):
    """Build the node function for a single stage of the workflow."""

    def stage_node(state: AgentState):
        last_msg = state["messages"][-1]

        # Only respond when the user spoke last; once one node has replied,
        # downstream stage nodes see an AI message and pass through.
        if hasattr(last_msg, "type") and last_msg.type == "human":
            # Convert the LangChain message objects into the plain
            # role/content dicts expected by call_llm_api.
            messages = []
            for msg in state["messages"]:
                if hasattr(msg, "type") and msg.type == "system":
                    messages.append({"role": "system", "content": msg.content})
                elif hasattr(msg, "type") and msg.type == "human":
                    messages.append({"role": "user", "content": msg.content})
                elif hasattr(msg, "type") and msg.type == "ai":
                    messages.append({"role": "assistant", "content": msg.content})

            # System prompt that keeps the model aware of stage progress.
            stage_context_prompt = (
                "[Stage Management]\n"
                f"Current stage: {state['current_stage']}\n"
                f"Completed stages: {', '.join(state['completed_stages']) if state['completed_stages'] else 'None'}\n"
                "Always check whether the current stage is complete. Use the evidence in <self-notes> "
                "to decide whether there is enough information and reasoning to conclude that the stage is complete. "
                "If it is, clearly state that the stage is complete and suggest moving to the next stage. "
                "If not, ask clarifying questions or provide guidance for the current stage. "
                "Always take the current stage and the completed stages into account in your reasoning."
            )
            messages = [{"role": "system", "content": stage_context_prompt}] + messages

            assistant_reply = call_llm_api(messages)
            # add_messages merges by message ID, so re-sending the history plus
            # the new reply effectively appends the reply.
            new_messages = state["messages"] + [AIMessage(content=assistant_reply)]
            completed_stages = state["completed_stages"].copy()
            current_stage = state["current_stage"]

            # Advance to the next stage when the reply signals completion;
            # current_stage becomes None after the final stage.
            if is_stage_complete(assistant_reply):
                completed_stages.append(current_stage)
                next_stage = get_next_stage(current_stage)
                current_stage = next_stage if next_stage else None

            return {
                "messages": new_messages,
                "current_stage": current_stage,
                "completed_stages": completed_stages,
            }
        else:
            # No new user input for this node: pass the state through unchanged.
            return state

    return stage_node
|
|
|
|
|
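# Assemble the LangGraph workflow over the shared AgentState.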
builder = StateGraph(AgentState)
|
|
|
|
|
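# One node per stage, all built from the same factory.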
for stage in stage_list:
    builder.add_node(stage, make_stage_node(stage))
|
|
|
|
|
|
|
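# Wire START to the first stage and each stage to its successor.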
builder.add_edge(START, stage_list[0])
for stage in stage_list:
    next_stage = get_next_stage(stage)
    if next_stage:
        builder.add_edge(stage, next_stage)
|
|
|
|
|
|
|
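# Compile the workflow into a runnable graph.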
stage_graph = builder.compile()
|
|
|
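# Write a Mermaid PNG rendering of the graph topology for inspection.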
with open("graph_output.png", "wb") as f:
    f.write(stage_graph.get_graph().draw_mermaid_png())
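# Minimal usage sketch (illustrative; assumes stage_list is non-empty and that
# llm_utils is configured with a working model backend):
#
#     from langchain_core.messages import HumanMessage
#
#     result = stage_graph.invoke(
#         {
#             "messages": [HumanMessage(content="Let's get started.")],
#             "current_stage": stage_list[0],
#             "completed_stages": [],
#         }
#     )
#     print(result["current_stage"])
#     print(result["messages"][-1].content)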
|
|
|
|