import os
import pprint
import uuid

from dotenv import load_dotenv

from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import MessagesState, END, StateGraph
from langgraph.checkpoint.memory import MemorySaver

from pmcp.agents.executor import ExecutorAgent
from pmcp.agents.trello_agent import TrelloAgent
from pmcp.agents.github_agent import GithubAgent
from pmcp.agents.planner import PlannerAgent

from pmcp.nodes.human_interrupt_node import HumanInterruptNode
from pmcp.nodes.human_resume_node import HumanResumeNode

from pmcp.models.state import PlanningState

load_dotenv()
|
|
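# Minimal LLM node: forwards the conversation so far to the model and
# appends its reply to the state. (Not wired into the graph below; kept
# as a reference node.)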
async def call_llm(llm_with_tools: ChatOpenAI, state: MessagesState):
    # Use the async client so the event loop is not blocked.
    response = await llm_with_tools.ainvoke(state["messages"])
    return {"messages": [response]}
|
|
async def main():
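    # Each MCP server is launched as a local Python subprocess and spoken
    # to over stdio; the adapter client exposes its tools as LangChain tools.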
    mcp_client_trello = MultiServerMCPClient(
        {
            "trello": {
                "command": "python",
                "args": ["pmcp/mcp_server/trello_server/mcp_trello_main.py"],
                "transport": "stdio",
            }
        }
    )
    mcp_client_github = MultiServerMCPClient(
        {
            "github": {
                "command": "python",
                "args": ["pmcp/mcp_server/github_server/mcp_github_main.py"],
                "transport": "stdio",
            }
        }
    )
|
    memory = MemorySaver()  # in-memory checkpointer; needed so the graph can pause at interrupts and resume
|
    trello_tools = await mcp_client_trello.get_tools()
    github_tools = await mcp_client_github.get_tools()
|
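    # A single ToolNode serves both agents: LangGraph dispatches each tool
    # call to the matching Trello or GitHub tool by name.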
    tool_node = ToolNode(github_tools + trello_tools)
|
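    # ChatOpenAI works against any OpenAI-compatible endpoint; here it
    # points at Nebius AI Studio serving Qwen2.5-32B-Instruct.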
    llm = ChatOpenAI(
        model="Qwen/Qwen2.5-32B-Instruct",
        temperature=0.0,
        api_key=os.getenv("NEBIUS_API_KEY"),
        base_url="https://api.studio.nebius.com/v1/",
    )
|
    trello_agent = TrelloAgent(tools=trello_tools, llm=llm)
    github_agent = GithubAgent(llm=llm, tools=github_tools)
    planner_agent = PlannerAgent(llm=llm)
    executor_agent = ExecutorAgent(llm=llm)
|
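    # Human-in-the-loop gates: the interrupt node pauses the run before tool
    # calls execute, and the resume node turns the user's follow-up reply
    # into a resume command for the paused run.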
    human_interrupt_node = HumanInterruptNode(llm=llm)
    human_resume_node = HumanResumeNode(llm=llm)
|
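    # Assemble the graph: the planner drafts a plan, the executor walks it
    # step by step, and each step is routed to the Trello or GitHub agent.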
    graph = StateGraph(PlanningState)  # PlanningState, not MessagesState: the routing below reads state.current_step
    graph.add_node(planner_agent.agent.agent_name, planner_agent.acall_planner_agent)
    graph.add_node(trello_agent.agent.agent_name, trello_agent.acall_trello_agent)
    graph.add_node(github_agent.agent.agent_name, github_agent.acall_github_agent)
    graph.add_node(executor_agent.agent.agent_name, executor_agent.acall_executor_agent)
    graph.add_node("tool", tool_node)
    graph.add_node("human_interrupt", human_interrupt_node.call_human_interrupt_agent)
    graph.set_entry_point(planner_agent.agent.agent_name)
|
    def should_continue(state: PlanningState):
        # After an agent turn: route to the human-approval gate if the agent
        # requested tool calls, otherwise hand control back to the executor.
        last_message = state.messages[-1]
        if last_message.tool_calls:
            return "human_interrupt"
        return executor_agent.agent.agent_name

    def execute_agent(state: PlanningState):
        # Route to the agent that owns the current plan step; when no step
        # is left, the plan is finished.
        if state.current_step:
            return state.current_step.agent
        return END
|
    graph.add_conditional_edges(trello_agent.agent.agent_name, should_continue)
    graph.add_conditional_edges(github_agent.agent.agent_name, should_continue)
    graph.add_conditional_edges(executor_agent.agent.agent_name, execute_agent)

    graph.add_edge("tool", trello_agent.agent.agent_name)
    graph.add_edge("tool", github_agent.agent.agent_name)
    graph.add_edge(planner_agent.agent.agent_name, executor_agent.agent.agent_name)
|
    app = graph.compile(checkpointer=memory)
    # Print a Mermaid diagram of the compiled graph for inspection.
    print(app.get_graph(xray=True).draw_mermaid())
|
    user_input = input("user >")
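    # A fresh thread_id opens a new checkpointer thread for this session;
    # recursion_limit caps the number of graph steps per invocation.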
    config = {
        "configurable": {"thread_id": str(uuid.uuid4())},
        "recursion_limit": 100,
    }
|
    # Simple REPL: type 'q' to quit. After an interrupt, the next input is
    # treated as a resume command rather than a new message.
    is_message_command = False
    while user_input.lower() != "q":
        if is_message_command:
            # Feed the user's decision back into the paused run.
            app_input = human_resume_node.call_human_interrupt_agent(user_input)
            is_message_command = False
        else:
            app_input = {"messages": [HumanMessage(content=user_input)]}

        async for res in app.astream(
            app_input,
            config=config,
            stream_mode="values",
        ):
            if "messages" in res:
                pprint.pprint(res["messages"][-1])
            else:
                # The graph hit an interrupt; show it and expect a resume
                # command on the next input.
                pprint.pprint(res["__interrupt__"][0])
                is_message_command = True

        pprint.pprint("-------------------------------------")
        user_input = input("user >")
|
|
if __name__ == "__main__":
    import asyncio

    asyncio.run(main())