superone001 commited on
Commit
1310897
·
verified ·
1 Parent(s): 21ef4cf

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +110 -94
agent.py CHANGED
@@ -1,100 +1,116 @@
1
- from typing import Dict, Any, Optional
2
- from langgraph.graph import StateGraph, START, END
3
- from langgraph.prebuilt import ToolNode
4
- from langgraph.checkpoint.sqlite import SqliteSaver
5
- from ai_tools import AITools
6
 
7
- class CustomAgent:
8
- def __init__(self):
9
- self.tools = AITools()
10
- self.workflow = self._build_workflow()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def _build_workflow(self):
13
- # 定义工具节点
14
- tool_node = ToolNode([self._route_to_tool])
15
-
16
- # 构建工作流
17
- workflow = StateGraph(State)
18
- workflow.add_node("tools", tool_node)
19
- workflow.add_node("process", self._process_result)
20
-
21
- # 设置入口点
22
- workflow.set_entry_point("tools")
23
-
24
- # 添加条件边
25
- workflow.add_conditional_edges(
26
- "tools",
27
- self._decide_next_step,
28
- {
29
- "continue": "process",
30
- "end": END
31
- }
32
- )
33
-
34
- workflow.add_edge("process", END)
35
-
36
- # 添加持久化
37
- workflow.checkpointer = SqliteSaver.from_conn_string(":memory:")
38
-
39
- return workflow.compile()
40
 
41
- def _route_to_tool(self, state: Dict[str, Any]):
42
- """路由问题到适当的工具"""
43
- question = state.get("question", "")
44
- file_name = state.get("file_name", "")
45
-
46
- # 处理反转文本问题
47
- if "rewsna" in question or "dnatsrednu" in question:
48
- return {"result": self.tools.reverse_text(question.split('"')[1])}
49
-
50
- # 处理蔬菜分类问题
51
- if "grocery list" in question.lower() or "vegetables" in question.lower():
52
- items = re.findall(r"[a-zA-Z]+(?=\W|\Z)", question)
53
- return {"result": ", ".join(self.tools.categorize_vegetables(items))}
54
-
55
- # 处理棋局问题
56
- if "chess position" in question.lower() and file_name.endswith(".png"):
57
- return {"result": self.tools.analyze_chess_position(file_name)}
58
-
59
- # 处理音频文件问题
60
- if file_name.endswith(".mp3"):
61
- transcript = self.tools.extract_audio_transcript(file_name)
62
- if "page numbers" in question.lower():
63
- return {"result": transcript}
64
- else:
65
- return {"result": ", ".join(sorted(transcript.split(", ")))}
66
-
67
- # 处理表格运算问题
68
- if "* on the set S" in question:
69
- table_data = {"operation": "*", "set": ["a", "b", "c", "d", "e"]}
70
- return {"result": self.tools.process_table_operation(table_data)}
71
-
72
- # 处理Python代码问题
73
- if file_name.endswith(".py"):
74
- return {"result": self.tools.analyze_python_code(file_name)}
75
-
76
- # 处理Excel文件问题
77
- if file_name.endswith(".xlsx"):
78
- return {"result": self.tools.process_excel_file(file_name)}
79
-
80
- return {"result": "I don't have a tool to answer this question."}
81
 
82
- def _process_result(self, state: Dict[str, Any]):
83
- """处理工具返回的结果"""
84
- result = state.get("result", "")
85
- return {"answer": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- def _decide_next_step(self, state: Dict[str, Any]):
88
- """决定下一步"""
89
- result = state.get("result", "")
90
- if result == "I don't have a tool to answer this question.":
91
- return "end"
92
- return "continue"
93
 
94
- def __call__(self, question: str, file_name: str = "") -> str:
95
- """执行Agent"""
96
- state = {"question": question, "file_name": file_name}
97
- for step in self.workflow.stream(state):
98
- if "__end__" in step:
99
- return step["__end__"]["answer"]
100
- return "No answer generated."
 
1
+ from typing import TypedDict, Annotated, Sequence
2
+ import operator
3
+ from langgraph.graph import StateGraph, END
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
+ from ai_tools import Calculator, DocRetriever, WebSearcher
6
 
7
+ # Configuration
8
+ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
11
+ llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
12
+
13
+ # Define tools
14
+ tools = [Calculator(), WebSearcher()]
15
+ doc_retriever = DocRetriever()
16
+ tool_map = {tool.name: tool for tool in tools}
17
+ tool_map["DocRetriever"] = doc_retriever
18
+
19
+ # Agent State
20
+ class AgentState(TypedDict):
21
+ input: str
22
+ context: Annotated[Sequence[str], operator.add]
23
+ last_tool: str
24
+
25
+ # Tool calling prompt template
26
+ TOOL_PROMPT = """<|system|>
27
+ You're an expert problem solver. Use these tools:
28
+ {tool_descriptions}
29
+
30
+ Respond ONLY in this format:
31
+ Thought: <strategy>
32
+ Action: <tool_name>
33
+ Action Input: <input>
34
+ </s>
35
+ <|user|>
36
+ {input}
37
+ Context: {context}
38
+ </s>
39
+ <|assistant|>
40
+ """
41
+
42
+ # Initialize graph
43
+ graph = StateGraph(AgentState)
44
+
45
+ # Node: Generate tool calls
46
+ def agent_node(state):
47
+ tool_list = "\n".join([f"- {t.name}: {t.description}" for t in tools])
48
+ prompt = TOOL_PROMPT.format(
49
+ tool_descriptions=tool_list,
50
+ input=state["input"],
51
+ context=state["context"]
52
+ )
53
 
54
+ response = llm_pipeline(
55
+ prompt,
56
+ max_new_tokens=150,
57
+ do_sample=True,
58
+ temperature=0.2,
59
+ pad_token_id=tokenizer.eos_token_id
60
+ )[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # Extract tool call
63
+ action_match = re.search(r"Action: (\w+)", response)
64
+ action_input_match = re.search(r"Action Input: (.+?)\n", response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ if action_match and action_input_match:
67
+ tool_name = action_match.group(1)
68
+ tool_input = action_input_match.group(1).strip()
69
+ return {
70
+ "last_tool": tool_name,
71
+ "tool_input": tool_input,
72
+ "thought": response
73
+ }
74
+ else:
75
+ return {"last_tool": "FINISH", "output": response}
76
+
77
+ # Node: Execute tools
78
+ def tool_node(state):
79
+ tool = tool_map.get(state["last_tool"])
80
+ if not tool:
81
+ return {"context": f"Error: Unknown tool {state['last_tool']}"}
82
+
83
+ result = tool.run(state["tool_input"])
84
+ return {"context": f"Tool {tool.name} returned: {result}"}
85
+
86
+ # Define graph structure
87
+ graph.add_node("agent", agent_node)
88
+ graph.add_node("tool", tool_node)
89
+ graph.set_entry_point("agent")
90
+
91
+ # Conditional edges
92
+ def route_action(state):
93
+ if state["last_tool"] == "FINISH":
94
+ return END
95
+ return "tool"
96
+
97
+ graph.add_edge("agent", "tool")
98
+ graph.add_conditional_edges("tool", route_action, {"agent": "agent", END: END})
99
+ graph.add_edge("tool", "agent") # Loop back after tool use
100
+
101
+ # Compile the agent
102
+ agent = graph.compile()
103
+
104
+ # Interface function
105
+ def run_agent(query: str, document: str = ""):
106
+ doc_retriever.document = document # Load document
107
+ state = {"input": query, "context": [], "last_tool": ""}
108
 
109
+ for step in agent.stream(state):
110
+ for node, value in step.items():
111
+ if node == "agent":
112
+ print(f"THOUGHT: {value['thought']}")
113
+ if node == "tool":
114
+ print(f"TOOL RESULT: {value['context']}")
115
 
116
+ return state["context"][-1] if state["context"] else "No output"