superone001 commited on
Commit
d58ea8f
·
verified ·
1 Parent(s): d9df2c4

Create graph.py

Browse files
Files changed (1) hide show
  1. graph.py +98 -0
graph.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict, Annotated, Sequence
2
+ import operator
3
+ import re
4
+ from langgraph.graph import StateGraph, END
5
+ from .ai_tools import Calculator, DocRetriever, WebSearcher
6
+
7
+ class AgentState(TypedDict):
8
+ input: str
9
+ context: Annotated[Sequence[str], operator.add]
10
+ last_tool: str
11
+ output: str
12
+
13
+ class GaiaGraph:
14
+ def __init__(self, model, tokenizer, tools):
15
+ self.model = model
16
+ self.tokenizer = tokenizer
17
+ self.tools = tools
18
+ self.tool_map = {tool.name: tool for tool in tools}
19
+ self.graph = self._build_graph()
20
+
21
+ def _build_graph(self):
22
+ graph = StateGraph(AgentState)
23
+
24
+ graph.add_node("agent", self._agent_node)
25
+ graph.add_node("tool", self._tool_node)
26
+ graph.set_entry_point("agent")
27
+
28
+ graph.add_edge("agent", "tool")
29
+ graph.add_conditional_edges(
30
+ "tool",
31
+ self._route_action,
32
+ {"continue": "agent", "end": END}
33
+ )
34
+
35
+ return graph.compile()
36
+
37
+ def _agent_node(self, state: AgentState) -> dict:
38
+ tool_list = "\n".join([f"- {t.name}: {t.description}" for t in self.tools])
39
+ prompt = f"""<|system|>
40
+ You're an expert problem solver. Use these tools when needed:
41
+ {tool_list}
42
+
43
+ Respond ONLY in this format:
44
+ Thought: <your reasoning>
45
+ Action: <tool_name or 'FINISH'>
46
+ Action Input: <input for tool>
47
+ </s>
48
+ <|user|>
49
+ {state['input']}
50
+ Context: {state['context']}
51
+ </s>
52
+ <|assistant|>
53
+ """
54
+
55
+ response = self.model(
56
+ prompt,
57
+ max_new_tokens=200,
58
+ do_sample=True,
59
+ temperature=0.2,
60
+ pad_token_id=self.tokenizer.eos_token_id
61
+ )[0]['generated_text']
62
+
63
+ # Extract tool call
64
+ action_match = re.search(r"Action: (\w+)", response)
65
+ action_input_match = re.search(r"Action Input: (.+?)\n", response, re.DOTALL)
66
+
67
+ if action_match and action_input_match:
68
+ tool_name = action_match.group(1)
69
+ tool_input = action_input_match.group(1).strip()
70
+ return {
71
+ "last_tool": tool_name,
72
+ "tool_input": tool_input,
73
+ "thought": response
74
+ }
75
+ else:
76
+ return {"last_tool": "FINISH", "output": response}
77
+
78
+ def _tool_node(self, state: AgentState) -> dict:
79
+ if state["last_tool"] == "FINISH":
80
+ return {"output": state.get("output", "No output generated")}
81
+
82
+ tool = self.tool_map.get(state["last_tool"])
83
+ if not tool:
84
+ return {"context": f"Error: Unknown tool {state['last_tool']}"}
85
+
86
+ result = tool.run(state["tool_input"])
87
+ return {"context": f"Tool {tool.name} returned: {result}"}
88
+
89
+ def _route_action(self, state: AgentState) -> str:
90
+ return "end" if state["last_tool"] == "FINISH" else "continue"
91
+
92
+ def run(self, input: str) -> str:
93
+ state = {"input": input, "context": [], "last_tool": "", "output": ""}
94
+ for step in self.graph.stream(state):
95
+ for node, value in step.items():
96
+ if node == "__end__":
97
+ return value["output"]
98
+ return "Execution completed without output"