MatteoMass committed on
Commit e05c99a · 1 Parent(s): 36383cc

add human in the loop functions

.gitignore CHANGED
@@ -9,3 +9,4 @@ wheels/
 
 # Virtual environments
 .venv
+.env
main.py CHANGED
@@ -10,7 +10,7 @@ from langgraph.prebuilt import ToolNode
 from langgraph.graph import MessagesState, END, StateGraph
 from langchain_core.messages import HumanMessage
 from langgraph.checkpoint.memory import MemorySaver
-from langgraph.types import Command, interrupt
+from langgraph.types import Command
 
 
 from pmcp.agents.executor import ExecutorAgent
@@ -18,6 +18,9 @@ from pmcp.agents.trello_agent import TrelloAgent
 from pmcp.agents.github_agent import GithubAgent
 from pmcp.agents.planner import PlannerAgent
 
+from pmcp.nodes.human_interrupt_node import HumanInterruptNode
+from pmcp.nodes.human_resume_node import HumanResumeNode
+
 from pmcp.models.state import PlanningState
 
 load_dotenv()
@@ -28,36 +31,6 @@ async def call_llm(llm_with_tools: ChatOpenAI, state: MessagesState):
     return {"messages": [response]}
 
 
-def human_review_node(state) -> Command[Literal["PLANNER_AGENT", "tool"]]:
-    last_message = state["messages"][-1]
-    tool_call = last_message.tool_calls[-1]
-    if tool_call.get("name", "").startswith("get_"):
-        return Command(goto="tool")
-
-    human_review = interrupt(
-        {
-            "question": "Is this correct?",
-            # Surface tool calls for review
-            "tool_call": tool_call,
-        }
-    )
-
-    review_action = human_review["action"]
-    review_data = human_review.get("data")
-
-    if review_action == "continue":
-        return Command(goto="tool")
-
-    else:
-        tool_message = {
-            "role": "tool",
-            "content": review_data,
-            "name": tool_call["name"],
-            "tool_call_id": tool_call["id"],
-        }
-        return Command(goto="PLANNER_AGENT", update={"messages": [tool_message]})
-
-
 async def main():
     mcp_client_trello = MultiServerMCPClient(
         {
@@ -104,25 +77,30 @@ async def main():
     )
     executor_agent = ExecutorAgent(llm=llm)
 
+    human_interrupt_node = HumanInterruptNode(
+        llm=llm,
+    )
+    human_resume_node = HumanResumeNode(llm=llm)
+
     graph = StateGraph(MessagesState)
     graph.add_node(planner_agent.agent.agent_name, planner_agent.acall_planner_agent)
     graph.add_node(trello_agent.agent.agent_name, trello_agent.acall_trello_agent)
     graph.add_node(github_agent.agent.agent_name, github_agent.acall_github_agent)
     graph.add_node(executor_agent.agent.agent_name, executor_agent.acall_executor_agent)
     graph.add_node("tool", tool_node)
-    graph.add_node("human_review", human_review_node)
+    graph.add_node("human_interrupt", human_interrupt_node.call_human_interrupt_agent)
     graph.set_entry_point(planner_agent.agent.agent_name)
 
     def should_continue(state: PlanningState):
         last_message = state.messages[-1]
         if last_message.tool_calls:
-            return "human_review"
+            return "human_interrupt"
         return executor_agent.agent.agent_name
 
     def execute_agent(state: PlanningState):
         if state.current_step:
             return state.current_step.agent
-
+
         return END
 
     graph.add_conditional_edges(trello_agent.agent.agent_name, should_continue)
@@ -136,24 +114,36 @@ async def main():
     app = graph.compile(checkpointer=memory)
     app.get_graph(xray=True).draw_mermaid()
 
-
     user_input = input("user >")
     config = {
         "configurable": {"thread_id": f"{str(uuid.uuid4())}"},
         "recursion_limit": 100,
     }
 
+    is_message_command = False
     while user_input.lower() != "q":
-        async for res in app.astream(
-            {
+
+        if is_message_command:
+            app_input = human_resume_node.call_human_interrupt_agent(
+                user_input
+            )
+            is_message_command = False
+        else:
+            app_input = {
                 "messages": [
                     HumanMessage(content=user_input),
                 ]
-            },
+            }
+        async for res in app.astream(
+            app_input,
             config=config,
             stream_mode="values",
         ):
-            pprint.pprint(res["messages"][-1])
+            if "messages" in res:
+                pprint.pprint(res["messages"][-1])
+            else:
+                pprint.pprint(res["__interrupt__"][0])
+                is_message_command = True
            pprint.pprint("-------------------------------------")
            user_input = input("user >")
 
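Note: the new driver loop above relies on LangGraph's interrupt/resume contract. The standalone sketch below (not part of the commit; the node name ask_human and the demo messages are illustrative) shows that contract in isolation: a node that calls interrupt() surfaces its payload under "__interrupt__" in the stream, and a later call on the same thread with Command(resume=...) makes interrupt() return the resumed value.

from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.types import Command, interrupt

def ask_human(state: MessagesState):
    # Pauses the run; the payload shows up under "__interrupt__" in the caller's stream.
    answer = interrupt("Proceed with the pending tool call?")
    # On resume, interrupt() returns whatever was passed via Command(resume=...).
    return {"messages": [AIMessage(content=f"human said: {answer}")]}

builder = StateGraph(MessagesState)
builder.add_node("ask_human", ask_human)
builder.add_edge(START, "ask_human")
builder.add_edge("ask_human", END)
app = builder.compile(checkpointer=MemorySaver())

cfg = {"configurable": {"thread_id": "demo"}}
for chunk in app.stream({"messages": [("user", "hi")]}, config=cfg, stream_mode="values"):
    if "__interrupt__" in chunk:
        print(chunk["__interrupt__"][0].value)   # the question raised by interrupt()

for chunk in app.stream(Command(resume="continue"), config=cfg, stream_mode="values"):
    if "messages" in chunk:
        print(chunk["messages"][-1].content)
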
pmcp/models/resume_trigger.py ADDED
@@ -0,0 +1,9 @@
+
+from typing import Optional
+from pydantic import BaseModel
+from typing_extensions import Literal
+
+
+class ResumeTrigger(BaseModel):
+    action: Literal["continue", "edit"]
+    changes: Optional[str] = None
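For context (not in the commit): ResumeTrigger is the schema the resume node asks the LLM to fill from the user's free-form reply. A rough sketch of that mapping using plain ChatOpenAI structured output, with an illustrative model name:

from langchain_openai import ChatOpenAI
from pmcp.models.resume_trigger import ResumeTrigger

llm = ChatOpenAI(model="gpt-4o-mini")  # illustrative model choice
trigger = llm.with_structured_output(ResumeTrigger).invoke("yes, go ahead")
# A confirmation like this would typically parse to action="continue" with changes=None;
# a reply asking for modifications would instead yield action="edit" with the requested
# changes captured in the `changes` field.
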
pmcp/nodes/__init__.py ADDED
File without changes
pmcp/nodes/human_interrupt_node.py ADDED
@@ -0,0 +1,58 @@
+from typing import List, Optional
+
+from pmcp.agents.agent_base import AgentBlueprint
+from langchain_core.tools import BaseTool
+from langchain_core.messages import SystemMessage, AIMessage
+from langchain_openai import ChatOpenAI
+from langgraph.types import Command, interrupt
+
+from pmcp.models.state import PlanningState
+
+SYSTEM_PROMPT = """
+You are a Human Reviewer Agent responsible for confirming the execution of tasks planned by the Planner Agent. Your role is to:
+- Ask the user for confirmation before a tool call is performed.
+"""
+
+
+class HumanInterruptNode:
+    def __init__(self, llm: ChatOpenAI, tools: Optional[List[BaseTool]] = None):
+        self.agent = AgentBlueprint(
+            agent_name="HUMAN_REVIEWER_AGENT",
+            description="The agent asks for human confirmation",
+            tools=tools,
+            system_prompt=SYSTEM_PROMPT.strip(),
+            llm=llm,
+        )
+
+    def call_human_interrupt_agent(self, state: PlanningState):
+        last_message = state.messages[-1]
+
+        try:
+            tool_call = last_message.tool_calls[-1]
+        except Exception:
+            last_message = state.messages[-2]
+            tool_call = last_message.tool_calls[-1]
+
+
+
+        if tool_call.get("name", "").startswith("get_"):
+            return Command(goto="tool")
+
+        response = self.agent.call_agent(
+            messages=[SystemMessage(content=self.agent.system_prompt), AIMessage(content=str(tool_call))],
+        )
+        human_review = interrupt(response.content)
+
+        review_action = human_review.action
+        review_changes = human_review.changes
+        if review_action == "continue":
+            return Command(goto="tool")
+
+        else:
+            tool_message = {
+                "role": "tool",
+                "content": review_changes,
+                "name": tool_call["name"],
+                "tool_call_id": tool_call["id"],
+            }
+            return Command(goto="PLANNER_AGENT", update={"messages": [tool_message]})
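For reference, the tool-call entries this node inspects are plain dicts on AIMessage.tool_calls; the sketch below shows the assumed shape (all values illustrative, not taken from the repo):

# Illustrative example of an AIMessage.tool_calls entry as consumed by call_human_interrupt_agent.
example_tool_call = {
    "name": "create_card",           # a "get_"-prefixed name would bypass the human review
    "args": {"title": "Set up CI"},  # arguments proposed by the upstream agent
    "id": "call_abc123",             # echoed back as tool_call_id when the user requests edits
    "type": "tool_call",
}
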
pmcp/nodes/human_resume_node.py ADDED
@@ -0,0 +1,33 @@
+from typing import List, Optional
+
+from pmcp.agents.agent_base import AgentBlueprint
+from langchain_core.tools import BaseTool
+from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_openai import ChatOpenAI
+from langgraph.types import Command
+
+from pmcp.models.resume_trigger import ResumeTrigger
+
+
+SYSTEM_PROMPT = """
+You are a Human Resumer Agent responsible for understanding the user's indication of whether or not to proceed with an action.
+"""
+
+
+class HumanResumeNode:
+    def __init__(self, llm: ChatOpenAI, tools: Optional[List[BaseTool]] = None):
+        self.agent = AgentBlueprint(
+            agent_name="HUMAN_REVIEWER_AGENT",
+            description="The agent asks for human confirmation",
+            tools=tools,
+            system_prompt=SYSTEM_PROMPT.strip(),
+            llm=llm,
+        )
+
+    def call_human_interrupt_agent(self, user_message: str):
+        response = self.agent.call_agent_structured(
+            [SystemMessage(content=self.agent.system_prompt), HumanMessage(content=user_message)],
+            clazz=ResumeTrigger,
+        )
+
+        return Command(resume=response)
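A minimal manual equivalent of what this node produces (a sketch, not in the commit), handy for exercising the paused graph without an LLM round trip:

from langgraph.types import Command
from pmcp.models.resume_trigger import ResumeTrigger

resume_input = Command(resume=ResumeTrigger(action="continue"))
# Passed to app.astream(resume_input, config=config) on the same thread_id, this resumes the
# pending interrupt() inside HumanInterruptNode, which then routes the run to the "tool" node;
# the resumed value is the same ResumeTrigger object, which is why the interrupt node reads
# human_review.action and human_review.changes as attributes.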