File size: 4,724 Bytes
a438cae
 
 
36383cc
a438cae
 
 
 
 
 
 
36383cc
a438cae
0235be8
 
 
 
 
e05c99a
 
 
0235be8
a438cae
36383cc
 
a438cae
 
 
 
 
 
 
9893f68
a438cae
 
36383cc
9ad58d0
36383cc
a438cae
 
 
9893f68
 
 
36383cc
9ad58d0
36383cc
9893f68
 
 
 
a438cae
 
9893f68
 
36383cc
9893f68
 
a438cae
0235be8
a438cae
 
 
 
 
0235be8
 
a438cae
 
 
0235be8
 
 
 
 
 
a438cae
e05c99a
 
 
 
 
0235be8
 
 
 
 
a438cae
e05c99a
0235be8
a438cae
0235be8
 
a438cae
e05c99a
0235be8
 
 
 
 
e05c99a
a438cae
 
0235be8
 
 
 
 
 
 
 
a438cae
36383cc
 
a438cae
36383cc
 
 
 
a438cae
e05c99a
a438cae
e05c99a
 
 
 
 
 
 
 
a438cae
 
 
e05c99a
 
 
a438cae
 
 
e05c99a
 
 
 
 
a438cae
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import pprint
import uuid
from dotenv import load_dotenv

from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from langgraph.graph import MessagesState, END, StateGraph
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver


from pmcp.agents.executor import ExecutorAgent
from pmcp.agents.trello_agent import TrelloAgent
from pmcp.agents.github_agent import GithubAgent
from pmcp.agents.planner import PlannerAgent

from pmcp.nodes.human_interrupt_node import HumanInterruptNode
from pmcp.nodes.human_resume_node import HumanResumeNode

from pmcp.models.state import PlanningState

# Pull variables such as NEBIUS_API_KEY (read by main) from a local .env file
# into the process environment; variables already set in the environment win.
load_dotenv(override=False)


async def call_llm(llm_with_tools: ChatOpenAI, state: MessagesState):
    """Run one chat-model turn over the conversation so far.

    Args:
        llm_with_tools: Chat model (optionally with tools bound) to query.
        state: Graph state whose ``"messages"`` list is sent to the model.

    Returns:
        A partial state update ``{"messages": [response]}`` containing only
        the new model reply, for the graph to merge into its message history.
    """
    # Await the async entry point: calling the synchronous ``invoke`` inside a
    # coroutine would block the event loop for the whole network round-trip.
    response = await llm_with_tools.ainvoke(state["messages"])
    return {"messages": [response]}


def _stdio_mcp_client(server_name: str, script_path: str) -> MultiServerMCPClient:
    """Create an MCP client for a single server launched over stdio.

    Args:
        server_name: Key under which the server is registered in the client.
        script_path: Python script that starts the MCP server.

    Returns:
        A ``MultiServerMCPClient`` configured with that one server.
    """
    import sys

    return MultiServerMCPClient(
        {
            server_name: {
                # Launch with the interpreter running this script so the server
                # sees the same environment/installed packages; a bare "python"
                # is resolved through PATH and may be a different interpreter.
                "command": sys.executable,
                "args": [script_path],
                "transport": "stdio",
            }
        }
    )


async def _run_chat_loop(app, human_resume_node: HumanResumeNode) -> None:
    """Read user lines from the terminal and stream graph runs until ``q``.

    Args:
        app: Compiled LangGraph application (compiled with a checkpointer).
        human_resume_node: Builds the resume input when the previous run
            stopped on an interrupt.
    """
    # One thread id for the whole session so the checkpointer keeps the
    # conversation (and any pending interrupt) across turns.
    config = {
        "configurable": {"thread_id": str(uuid.uuid4())},
        "recursion_limit": 100,
    }

    # True when the last run paused on an interrupt: the next line typed is the
    # human's answer to it rather than a new chat message.
    awaiting_resume = False
    user_input = input("user >")
    while user_input.lower() != "q":
        if awaiting_resume:
            app_input = human_resume_node.call_human_interrupt_agent(user_input)
            awaiting_resume = False
        else:
            app_input = {"messages": [HumanMessage(content=user_input)]}

        async for res in app.astream(app_input, config=config, stream_mode="values"):
            if "messages" in res:
                pprint.pprint(res["messages"][-1])
            elif "__interrupt__" in res:
                pprint.pprint(res["__interrupt__"][0])
                awaiting_resume = True
            pprint.pprint("-------------------------------------")
        user_input = input("user >")


async def main():
    """Run the interactive Trello/GitHub project-management assistant.

    Starts the Trello and GitHub MCP servers as stdio subprocesses, wires the
    planner, executor and tool-using agents into a LangGraph state graph with
    an in-memory checkpointer, then chats with the user on the terminal until
    they enter ``q``.
    """
    mcp_client_trello = _stdio_mcp_client(
        "trello", "pmcp/mcp_server/trello_server/mcp_trello_main.py"
    )
    mcp_client_github = _stdio_mcp_client(
        "github", "pmcp/mcp_server/github_server/mcp_github_main.py"
    )

    memory = MemorySaver()

    trello_tools = await mcp_client_trello.get_tools()
    github_tools = await mcp_client_github.get_tools()

    # A single tool node executes calls for either server's tools.
    tool_node = ToolNode(github_tools + trello_tools)

    llm = ChatOpenAI(
        model="Qwen/Qwen2.5-32B-Instruct",
        temperature=0.0,
        api_key=os.getenv("NEBIUS_API_KEY"),
        base_url="https://api.studio.nebius.com/v1/",
    )

    trello_agent = TrelloAgent(tools=trello_tools, llm=llm)
    github_agent = GithubAgent(llm=llm, tools=github_tools)
    planner_agent = PlannerAgent(llm=llm)
    executor_agent = ExecutorAgent(llm=llm)
    human_interrupt_node = HumanInterruptNode(llm=llm)
    human_resume_node = HumanResumeNode(llm=llm)

    # The routing functions below read state fields as attributes
    # (state.messages, state.current_step), matching their PlanningState
    # annotation. MessagesState is a TypedDict, so with it they would receive a
    # plain dict and fail with AttributeError.
    graph = StateGraph(PlanningState)
    graph.add_node(planner_agent.agent.agent_name, planner_agent.acall_planner_agent)
    graph.add_node(trello_agent.agent.agent_name, trello_agent.acall_trello_agent)
    graph.add_node(github_agent.agent.agent_name, github_agent.acall_github_agent)
    graph.add_node(executor_agent.agent.agent_name, executor_agent.acall_executor_agent)
    graph.add_node("tool", tool_node)
    # NOTE(review): "human_interrupt" gets no outgoing edges here; presumably it
    # routes onward itself (e.g. via Command(goto=...)) — confirm.
    graph.add_node("human_interrupt", human_interrupt_node.call_human_interrupt_agent)
    graph.set_entry_point(planner_agent.agent.agent_name)

    def should_continue(state: PlanningState):
        # Tool calls pause for human approval; otherwise hand back to the executor.
        last_message = state.messages[-1]
        if getattr(last_message, "tool_calls", None):
            return "human_interrupt"
        return executor_agent.agent.agent_name

    def execute_agent(state: PlanningState):
        # Dispatch to the agent owning the current plan step, or stop when none.
        if state.current_step:
            return state.current_step.agent
        return END

    def route_tool_result(state: PlanningState):
        # Send tool output back to the agent that owns the current step. Two
        # plain edges out of "tool" would fan out and run BOTH agents after
        # every tool call, whichever one requested it.
        if state.current_step:
            return state.current_step.agent
        return executor_agent.agent.agent_name

    graph.add_conditional_edges(trello_agent.agent.agent_name, should_continue)
    graph.add_conditional_edges(github_agent.agent.agent_name, should_continue)
    graph.add_conditional_edges(executor_agent.agent.agent_name, execute_agent)
    graph.add_conditional_edges("tool", route_tool_result)
    graph.add_edge(planner_agent.agent.agent_name, executor_agent.agent.agent_name)

    app = graph.compile(checkpointer=memory)

    await _run_chat_loop(app, human_resume_node)


if __name__ == "__main__":
    # Script entry point: drive the async session on a fresh event loop.
    from asyncio import run as run_event_loop

    run_event_loop(main())