|
from typing import Annotated, Sequence, TypedDict |
|
from langchain_community.llms import HuggingFaceHub |
|
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage |
|
from langgraph.graph import StateGraph, END |
|
from langchain_core.agents import AgentAction, AgentFinish |
|
from langchain.agents import create_react_agent |
|
from langchain import hub |
|
from ai_tools import get_tools |
|
import operator |
|
|
|
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph.

    Both fields carry ``operator.add`` as ``Annotated`` metadata, so a
    partial update returned by a node is concatenated onto the existing
    value rather than replacing it.
    """

    # Conversation messages accumulated during a run.
    messages: Annotated[Sequence[BaseMessage], operator.add]

    # Agent step records appended by the graph nodes.
    intermediate_steps: Annotated[list, operator.add]
|
|
|
def build_graph():
    """Build and compile the ReAct agent graph (agent -> tool -> agent loop).

    The graph has two nodes:

    * ``agent`` asks the ReAct agent for its next decision.
    * ``tool`` executes the tool the agent selected.

    Entries stored in ``state["intermediate_steps"]`` are one of:

    * ``AgentAction``  - a decision from the agent that has not been run yet;
    * ``(AgentAction, observation)`` tuple - a completed tool call;
    * ``AgentFinish``  - the agent's final answer.

    Only the completed tuples are passed back to the agent, because the
    ReAct scratchpad is built from ``(action, observation)`` pairs.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={"temperature": 0.1, "max_new_tokens": 500},
    )

    prompt = hub.pull("hwchase17/react")
    tools = get_tools()
    agent = create_react_agent(llm, tools, prompt)

    def latest_step(state: AgentState):
        """Return the newest step, unwrapping a list-wrapped step."""
        last_step = state["intermediate_steps"][-1]
        return last_step[0] if isinstance(last_step, list) else last_step

    def agent_node(state: AgentState):
        """Ask the agent for its next action or its final answer."""
        messages = state["messages"]
        # The agent input must stay the user's question. The last message
        # may be a tool observation appended by tool_node, so look for the
        # most recent human message instead of blindly taking messages[-1].
        question = next(
            (m.content for m in reversed(messages) if isinstance(m, HumanMessage)),
            None,
        )
        if question is None:
            question = messages[-1].content

        # Pending AgentAction / AgentFinish entries are not valid scratchpad
        # items; only completed (action, observation) pairs are forwarded.
        completed_steps = [
            step for step in state["intermediate_steps"] if isinstance(step, tuple)
        ]

        result = agent.invoke({
            "input": question,
            "intermediate_steps": completed_steps,
        })

        update = {"intermediate_steps": [result]}
        if isinstance(result, AgentFinish):
            # Surface the final answer as the last message so callers that
            # read result["messages"][-1] receive it.
            answer = result.return_values.get("output", "")
            update["messages"] = [AIMessage(content=str(answer))]
        return update

    def tool_node(state: AgentState):
        """Run the tool chosen by the agent and record its observation."""
        action = latest_step(state)

        if not isinstance(action, AgentAction):
            return {"messages": [AIMessage(content="Invalid action format")]}

        tool = next((t for t in tools if t.name == action.tool), None)
        if tool is None:
            # Report the missing tool as an observation so the agent sees
            # the failure in its scratchpad on the next turn.
            observation = f"Tool {action.tool} not found"
        else:
            # Message content must be text; tools may return other types.
            observation = str(tool.run(action.tool_input))

        return {
            "messages": [AIMessage(content=observation)],
            "intermediate_steps": [(action, observation)],
        }

    def route_action(state: AgentState):
        """Stop when the agent has finished, otherwise go run the tool."""
        if isinstance(latest_step(state), AgentFinish):
            return END
        return "tool"

    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tool", tool_node)

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        route_action,
        {"tool": "tool", END: END},
    )
    workflow.add_edge("tool", "agent")

    return workflow.compile()
|
|
|
class BasicAgent:
    """Callable wrapper around the compiled LangGraph agent."""

    def __init__(self):
        """Build the agent graph once so it can be reused across calls."""
        print("BasicAgent initialized.")
        self.graph = build_graph()

    def __call__(self, question: str) -> str:
        """Run the agent graph on ``question`` and return the answer text.

        Args:
            question: The user's question, sent as the initial human message.

        Returns:
            The agent's final answer with surrounding whitespace removed.
        """
        print(f"Agent received question: {question[:50]}...")

        result = self.graph.invoke({
            "messages": [HumanMessage(content=question)],
            "intermediate_steps": [],
        })

        # The agent node records its decision in intermediate_steps, so a
        # finished run ends with an AgentFinish there. Prefer its output:
        # the last message may be the question or a tool observation.
        steps = result.get("intermediate_steps") or []
        last_step = steps[-1] if steps else None
        if isinstance(last_step, list) and last_step:
            last_step = last_step[0]

        if isinstance(last_step, AgentFinish):
            answer = last_step.return_values.get("output", "")
        else:
            answer = result["messages"][-1].content

        # Message content is not guaranteed to be a plain string.
        return str(answer).strip()