superone001's picture
Update agent.py
2826c36 verified
raw
history blame
3.17 kB
from typing import Annotated, Sequence, TypedDict
from langchain_community.llms import HuggingFaceHub
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langgraph.graph import StateGraph, END
from langchain_core.agents import AgentAction, AgentFinish
from langchain.agents import create_react_agent
from langchain import hub
from ai_tools import get_tools # 导入自定义工具集
import operator
# State shared by every node of the LangGraph workflow. Both keys carry
# ``operator.add`` as their reducer, so the lists a node returns are
# concatenated onto the existing values instead of replacing them.
AgentState = TypedDict(
    "AgentState",
    {
        # Conversation messages (e.g. the user's question and tool output).
        "messages": Annotated[Sequence[BaseMessage], operator.add],
        # Steps recorded by the graph nodes during a single run.
        "intermediate_steps": Annotated[list, operator.add],
    },
)
def _latest_agent_output(steps):
    """Return the most recent agent decision recorded in ``steps``.

    ``steps`` holds two kinds of entries: raw agent outputs (an
    ``AgentAction``, a list of them, or an ``AgentFinish``) and the
    ``(AgentAction, observation)`` tuples appended by the tool node. Only the
    former are agent decisions. Returns ``None`` if there is no decision.
    """
    for step in reversed(steps):
        if not isinstance(step, tuple):
            return step
    return None


def _completed_steps(steps):
    """Return only the ``(AgentAction, observation)`` pairs from ``steps``.

    The ReAct agent renders its scratchpad from exactly these pairs, so raw
    agent decisions must not be passed to it.
    """
    return [step for step in steps if isinstance(step, tuple)]


def _current_question(messages):
    """Return the content of the most recent ``HumanMessage`` in ``messages``.

    Tool observations are appended to ``messages`` as ``AIMessage``s, so the
    last message is not necessarily the user's question. Falls back to the
    last message when no human message is present.
    """
    for message in reversed(messages):
        if isinstance(message, HumanMessage):
            return message.content
    return messages[-1].content


def build_graph():
    """Build and compile the LangGraph ReAct workflow.

    The graph alternates between an ``agent`` node, which asks the ReAct
    agent for its next decision, and a ``tool`` node, which runs the
    requested tool(s). It stops when the agent returns an ``AgentFinish``.

    State conventions (both ``AgentState`` keys append via ``operator.add``):

    * ``messages`` starts with the user's ``HumanMessage``. Each tool
      observation and finally the agent's answer are appended as
      ``AIMessage``s, so after a finished run the last message is the answer.
    * ``intermediate_steps`` receives every agent decision plus one
      ``(AgentAction, observation)`` pair per executed tool call.

    Returns:
        The compiled graph. Invoke it with
        ``{"messages": [HumanMessage(...)], "intermediate_steps": []}``.
    """
    # 1. Model: Mistral-7B-Instruct via the free HuggingFace inference API.
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={"temperature": 0.1, "max_new_tokens": 500},
    )

    # 2. ReAct agent built from the public hwchase17/react prompt.
    prompt = hub.pull("hwchase17/react")
    tools = get_tools()  # custom tool set from ai_tools
    tools_by_name = {tool.name: tool for tool in tools}
    agent = create_react_agent(llm, tools, prompt)

    # 3. Node behaviour.
    def agent_node(state: AgentState):
        """Ask the agent for its next decision given the completed steps."""
        decision = agent.invoke({
            "input": _current_question(state["messages"]),
            "intermediate_steps": _completed_steps(state["intermediate_steps"]),
        })
        update = {"intermediate_steps": [decision]}
        if isinstance(decision, AgentFinish):
            # Surface the final answer as the last message for callers.
            output = decision.return_values.get("output", "")
            update["messages"] = [AIMessage(content=str(output))]
        return update

    def tool_node(state: AgentState):
        """Run the requested tool(s) and record each observation."""
        decision = _latest_agent_output(state["intermediate_steps"])
        actions = decision if isinstance(decision, list) else [decision]
        if not actions or not all(isinstance(a, AgentAction) for a in actions):
            return {"messages": [AIMessage(content="Invalid action format")]}

        completed = []
        observations = []
        for action in actions:
            tool = tools_by_name.get(action.tool)
            if tool is None:
                # Fed back to the agent as an observation so it can recover.
                observation = f"Tool {action.tool} not found"
            else:
                # AIMessage content must be text; tools may return other types.
                observation = str(tool.run(action.tool_input))
            completed.append((action, observation))
            observations.append(AIMessage(content=observation))
        return {"intermediate_steps": completed, "messages": observations}

    # 4. Assemble the state graph.
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tool", tool_node)

    # 5. Edges and routing.
    def route_action(state: AgentState):
        """Stop on ``AgentFinish``; otherwise go to the tool node."""
        decision = _latest_agent_output(state["intermediate_steps"])
        if isinstance(decision, AgentFinish):
            return END
        return "tool"

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        route_action,
        {"tool": "tool", END: END},
    )
    workflow.add_edge("tool", "agent")
    return workflow.compile()
class BasicAgent:
    """Callable wrapper around the compiled LangGraph agent.

    Calling an instance with a question runs the graph to completion and
    returns the agent's final answer as a stripped string.
    """

    def __init__(self):
        print("BasicAgent initialized.")
        self.graph = build_graph()

    def __call__(self, question: str) -> str:
        """Answer ``question`` by running the agent graph.

        Args:
            question: The user's question.

        Returns:
            The final answer with surrounding whitespace removed. This is the
            ``"output"`` of the last ``AgentFinish`` in the returned
            ``intermediate_steps`` when one is present, otherwise the content
            of the last message.
        """
        print(f"Agent received question: {question[:50]}...")
        result = self.graph.invoke({
            "messages": [HumanMessage(content=question)],
            "intermediate_steps": [],
        })
        # Prefer the agent's explicit final answer: the last message may be a
        # tool observation (or the question itself) rather than the answer.
        for step in reversed(result.get("intermediate_steps", [])):
            if isinstance(step, AgentFinish):
                return str(step.return_values.get("output", "")).strip()
        return str(result["messages"][-1].content).strip()