superone001 commited on
Commit
5c751cb
·
verified ·
1 Parent(s): 8557302

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +82 -38
agent.py CHANGED
@@ -1,46 +1,90 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
- from .ai_tools import Calculator, DocRetriever, WebSearcher
3
- from .graph import GaiaGraph
 
 
 
 
 
4
 
5
- class GaiaAgent:
6
- def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
7
- self.model_name = model_name
8
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- self.model = AutoModelForCausalLM.from_pretrained(model_name)
10
- self.llm_pipeline = pipeline(
11
- "text-generation",
12
- model=self.model,
13
- tokenizer=self.tokenizer
14
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Initialize tools
17
- self.calculator = Calculator()
18
- self.doc_retriever = DocRetriever()
19
- self.web_searcher = WebSearcher()
20
 
21
- # Create tool list
22
- self.tools = [
23
- self.calculator,
24
- self.web_searcher,
25
- self.doc_retriever
26
- ]
27
 
28
- # Build LangGraph workflow
29
- self.graph = GaiaGraph(
30
- model=self.llm_pipeline,
31
- tokenizer=self.tokenizer,
32
- tools=self.tools
33
- )
 
 
 
 
 
 
34
 
35
- print(f"GaiaAgent initialized with model: {model_name}")
 
 
 
 
 
 
 
 
 
 
36
 
37
- def load_document(self, document_text: str):
38
- """Load document content for retrieval"""
39
- self.doc_retriever.load_document(document_text)
40
- print(f"Document loaded ({len(document_text)} characters)")
 
 
 
41
 
42
  def __call__(self, question: str) -> str:
43
- print(f"Agent received question: {question[:50]}{'...' if len(question) > 50 else ''}")
44
- result = self.graph.run(question)
45
- print(f"Agent returning answer: {result[:50]}{'...' if len(result) > 50 else ''}")
46
- return result
 
 
 
 
 
 
 
1
+ from typing import Annotated, Sequence, TypedDict
2
+ from langchain_community.llms import HuggingFaceHub
3
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
4
+ from langgraph.graph import StateGraph, END
5
+ from langchain_core.agents import AgentAction, AgentFinish
6
+ from langchain.agents import create_react_agent
7
+ from langchain import hub
8
+ from ai_tools import get_tools # 导入自定义工具集
9
 
10
+ class AgentState(TypedDict):
11
+ messages: Annotated[Sequence[BaseMessage], operator.add]
12
+ intermediate_steps: Annotated[list, operator.add]
13
+
14
+ def build_graph():
15
+ # 1. 初始化模型 - 使用HuggingFace免费接口
16
+ llm = HuggingFaceHub(
17
+ repo_id="mistralai/Mistral-7B-Instruct-v0.2",
18
+ model_kwargs={"temperature": 0.1, "max_new_tokens": 500}
19
+ )
20
+
21
+ # 2. 创建ReAct代理
22
+ prompt = hub.pull("hwchase17/react")
23
+ tools = get_tools() # 从ai_tools获取工具
24
+ agent = create_react_agent(llm, tools, prompt)
25
+
26
+ # 3. 定义节点行为
27
+ def agent_node(state: AgentState):
28
+ input = state["messages"][-1].content
29
+ result = agent.invoke({
30
+ "input": input,
31
+ "intermediate_steps": state["intermediate_steps"]
32
+ })
33
+ return {"intermediate_steps": [result]}
34
+
35
+ def tool_node(state: AgentState):
36
+ last_step = state["intermediate_steps"][-1]
37
+ action = last_step[0] if isinstance(last_step, list) else last_step
38
 
39
+ if not isinstance(action, AgentAction):
40
+ return {"messages": [AIMessage(content="Invalid action format")]}
 
 
41
 
42
+ # 执行工具调用
43
+ tool = next((t for t in tools if t.name == action.tool), None)
44
+ if not tool:
45
+ return {"messages": [AIMessage(content=f"Tool {action.tool} not found")]}
 
 
46
 
47
+ observation = tool.run(action.tool_input)
48
+ return {"messages": [AIMessage(content=observation)]}
49
+
50
+ # 4. 构建状态图
51
+ workflow = StateGraph(AgentState)
52
+ workflow.add_node("agent", agent_node)
53
+ workflow.add_node("tool", tool_node)
54
+
55
+ # 5. 定义边和条件
56
+ def route_action(state: AgentState):
57
+ last_step = state["intermediate_steps"][-1]
58
+ action = last_step[0] if isinstance(last_step, list) else last_step
59
 
60
+ if isinstance(action, AgentFinish):
61
+ return END
62
+ return "tool"
63
+
64
+ workflow.set_entry_point("agent")
65
+ workflow.add_conditional_edges(
66
+ "agent",
67
+ route_action,
68
+ {"tool": "tool", END: END}
69
+ )
70
+ workflow.add_edge("tool", "agent")
71
 
72
+ return workflow.compile()
73
+
74
+ class BasicAgent:
75
+ """LangGraph智能体封装"""
76
+ def __init__(self):
77
+ print("BasicAgent initialized.")
78
+ self.graph = build_graph()
79
 
80
  def __call__(self, question: str) -> str:
81
+ print(f"Agent received question: {question[:50]}...")
82
+ messages = [HumanMessage(content=question)]
83
+ result = self.graph.invoke({
84
+ "messages": messages,
85
+ "intermediate_steps": []
86
+ })
87
+
88
+ # 提取最终答案
89
+ final_message = result["messages"][-1].content
90
+ return final_message.strip()