File size: 3,170 Bytes
5c751cb
 
 
 
 
 
 
 
2826c36
1d755bf
5c751cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d755bf
5c751cb
 
1d755bf
5c751cb
 
 
 
1d755bf
5c751cb
 
 
 
 
 
 
 
 
 
 
 
1d755bf
5c751cb
 
 
 
 
 
 
 
 
 
 
420e08f
5c751cb
 
 
 
 
 
 
420e08f
1d755bf
5c751cb
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import Annotated, Sequence, TypedDict
from langchain_community.llms import HuggingFaceHub
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langgraph.graph import StateGraph, END
from langchain_core.agents import AgentAction, AgentFinish
from langchain.agents import create_react_agent
from langchain import hub
from ai_tools import get_tools  # 导入自定义工具集
import operator 

class AgentState(TypedDict):
    """Graph state shared by the LangGraph nodes created in ``build_graph``.

    Both fields are ``Annotated`` with ``operator.add`` as their reducer, so
    the list a node returns for a key is concatenated onto the existing value
    instead of replacing it.
    """

    # Conversation messages; BasicAgent.__call__ seeds this with the user's
    # HumanMessage, and graph nodes may append further messages.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Agent step records accumulated over the agent/tool loop; starts empty
    # (see BasicAgent.__call__). NOTE(review): the element type is not
    # constrained (plain ``list``) — see build_graph for what is stored here.
    intermediate_steps: Annotated[list, operator.add]

def _first_action(outcome):
    """Return the primary decision contained in an agent outcome.

    The agent runnable may yield a single ``AgentAction``/``AgentFinish`` or a
    list of actions. For a list only the first entry is used, because this
    graph executes one tool per turn. An empty list yields ``None``.
    """
    if isinstance(outcome, list):
        return outcome[0] if outcome else None
    return outcome


def _latest_question(messages):
    """Return the content of the most recent ``HumanMessage`` in *messages*.

    The ReAct prompt needs the user's question on every agent turn. Reading
    ``messages[-1]`` would instead pick up a tool observation after the first
    tool call. Returns ``""`` when no human message is present.
    """
    for message in reversed(messages):
        if isinstance(message, HumanMessage):
            return message.content
    return ""


def _completed_steps(entries):
    """Return only the ``(AgentAction, observation)`` pairs found in *entries*.

    The graph's ``intermediate_steps`` also stores raw agent outcomes (used for
    routing). The ReAct agent's scratchpad formatter iterates
    ``intermediate_steps`` as ``(action, observation)`` pairs, so every other
    entry is filtered out.
    """
    return [
        entry for entry in entries
        if isinstance(entry, tuple) and len(entry) == 2
    ]


def build_graph():
    """Build and compile the ReAct agent/tool LangGraph workflow.

    The graph alternates between an ``agent`` node, which asks the ReAct agent
    for its next decision, and a ``tool`` node, which runs the requested tool.
    It ends as soon as the agent returns anything other than an
    ``AgentAction`` (normally an ``AgentFinish``).

    State conventions (see ``AgentState``):
      * ``messages`` starts with the user's ``HumanMessage``. Tool observations
        and the final answer are appended as ``AIMessage`` objects, so the
        final answer is the last message when the run ends on ``AgentFinish``.
      * ``intermediate_steps`` receives every raw agent outcome (read by the
        router). After a tool runs, it also receives an
        ``(AgentAction, observation)`` pair, which feeds the agent's scratchpad.

    Returns:
        The compiled graph, ready for ``invoke``.
    """
    # 1. Model: HuggingFace hosted inference endpoint (free tier).
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={"temperature": 0.1, "max_new_tokens": 500},
    )

    # 2. ReAct agent built from the public hwchase17/react prompt.
    prompt = hub.pull("hwchase17/react")
    tools = get_tools()  # custom toolset provided by ai_tools
    agent = create_react_agent(llm, tools, prompt)

    # 3. Node behaviour.
    def agent_node(state: AgentState):
        """Ask the agent for its next decision, given the completed steps."""
        outcome = agent.invoke({
            "input": _latest_question(state["messages"]),
            "intermediate_steps": _completed_steps(state["intermediate_steps"]),
        })
        update = {"intermediate_steps": [outcome]}
        decision = _first_action(outcome)
        if isinstance(decision, AgentFinish):
            # Surface the final answer as the last message so callers can read it.
            answer = decision.return_values.get("output", "")
            update["messages"] = [AIMessage(content=str(answer))]
        elif not isinstance(decision, AgentAction):
            # The router ends the run on this outcome; say why.
            update["messages"] = [AIMessage(content="Invalid action format")]
        return update

    def tool_node(state: AgentState):
        """Run the tool requested by the latest agent action and record the result."""
        action = _first_action(state["intermediate_steps"][-1])

        if not isinstance(action, AgentAction):
            # Defensive only: route_action sends just AgentAction outcomes here.
            return {"messages": [AIMessage(content="Invalid action format")]}

        # Execute the tool call (first tool with a matching name wins).
        tool = next((t for t in tools if t.name == action.tool), None)
        if tool is None:
            # Fed back as an observation so the agent can choose another tool
            # instead of looping on the same request.
            observation = f"Tool {action.tool} not found"
        else:
            observation = tool.run(action.tool_input)

        # Recording the (action, observation) pair lets the agent see the
        # result on its next turn; without it the agent repeats the action.
        return {
            "intermediate_steps": [(action, observation)],
            "messages": [AIMessage(content=str(observation))],
        }

    # 4. Assemble the state graph.
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tool", tool_node)

    # 5. Edges and routing.
    def route_action(state: AgentState):
        """Send AgentAction outcomes to the tool node; end on anything else."""
        decision = _first_action(state["intermediate_steps"][-1])
        return "tool" if isinstance(decision, AgentAction) else END

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        route_action,
        {"tool": "tool", END: END},
    )
    workflow.add_edge("tool", "agent")

    return workflow.compile()

class BasicAgent:
    """Callable wrapper around the compiled LangGraph ReAct agent."""

    def __init__(self):
        print("BasicAgent initialized.")
        self.graph = build_graph()

    def __call__(self, question: str) -> str:
        """Run the agent graph on *question* and return its final answer.

        Args:
            question: The user's question; sent to the graph as a single
                ``HumanMessage`` with empty ``intermediate_steps``.

        Returns:
            The agent's final answer with surrounding whitespace stripped.
        """
        print(f"Agent received question: {question[:50]}...")
        result = self.graph.invoke({
            "messages": [HumanMessage(content=question)],
            "intermediate_steps": [],
        })
        return self._extract_answer(result)

    @staticmethod
    def _extract_answer(result) -> str:
        """Extract the final answer text from a graph result mapping.

        Prefers the ``output`` of a trailing ``AgentFinish`` in
        ``intermediate_steps``. The last message alone can be a tool
        observation or the original question rather than the answer. Falls
        back to the content of the last message in ``messages``. Non-string
        content is converted with ``str`` so ``strip`` cannot fail.
        """
        steps = result.get("intermediate_steps") or []
        if steps:
            last = steps[-1]
            if isinstance(last, list):
                last = last[0] if last else None
            if isinstance(last, AgentFinish):
                return str(last.return_values.get("output", "")).strip()

        content = result["messages"][-1].content
        if not isinstance(content, str):
            content = str(content)
        return content.strip()