File size: 3,215 Bytes
d58ea8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from typing import TypedDict, Annotated, Sequence
import operator
import re
from langgraph.graph import StateGraph, END
from .ai_tools import Calculator, DocRetriever, WebSearcher

class AgentState(TypedDict):
    """State shared by the agent and tool nodes of ``GaiaGraph``.

    ``context`` is reduced with ``operator.add``: LangGraph concatenates each
    node's ``context`` update onto the current value, so updates must be lists
    of strings rather than bare strings.
    """

    input: str  # the user's question; set once by ``GaiaGraph.run``
    context: Annotated[Sequence[str], operator.add]  # accumulated tool observations
    last_tool: str  # tool picked by the latest agent step, or "FINISH"
    output: str  # final answer, filled in when the agent finishes
    # Written by the agent node and read by the tool node. LangGraph only keeps
    # updates to keys declared in the state schema, so they must be listed here.
    tool_input: str
    thought: str  # raw model reply from the latest agent step


class GaiaGraph:
    """ReAct-style agent implemented as a two-node LangGraph loop.

    The ``agent`` node asks the model for a Thought / Action / Action Input
    reply; the ``tool`` node runs the chosen tool and appends its result to
    ``context``; control returns to the agent until it chooses ``FINISH`` (or
    produces a reply that cannot be parsed, which is treated as the answer).
    """

    # Parsers for the reply format requested in the system prompt. The action
    # name may be quoted, since the prompt itself shows 'FINISH' in quotes.
    _ACTION_RE = re.compile(r"Action:[ \t]*['\"]?([\w-]+)")
    _ACTION_INPUT_RE = re.compile(r"Action Input:[ \t]*(.+)")

    def __init__(self, model, tokenizer, tools):
        """Store the model, tokenizer and tools, then compile the graph.

        Args:
            model: text-generation callable invoked as
                ``model(prompt, max_new_tokens=..., ...)`` and returning
                ``[{"generated_text": str}]`` (presumably a Hugging Face
                ``pipeline`` -- confirm against the caller).
            tokenizer: object whose ``eos_token_id`` is used as the pad token.
            tools: iterable of tool objects exposing ``name``, ``description``
                and ``run(str)``. It is copied into a list because it is
                iterated again on every agent step.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.tools = list(tools)
        self.tool_map = {tool.name: tool for tool in self.tools}
        self.graph = self._build_graph()

    def _build_graph(self):
        """Wire agent -> tool -> (agent | END) and return the compiled graph."""
        graph = StateGraph(AgentState)

        graph.add_node("agent", self._agent_node)
        graph.add_node("tool", self._tool_node)
        graph.set_entry_point("agent")

        # The agent always hands off to the tool node, which also handles the
        # FINISH bookkeeping before routing decides whether to loop or stop.
        graph.add_edge("agent", "tool")
        graph.add_conditional_edges(
            "tool",
            self._route_action,
            {"continue": "agent", "end": END},
        )

        return graph.compile()

    def _agent_node(self, state: AgentState) -> dict:
        """Prompt the model for its next step and parse the chosen action.

        Returns a partial state update: ``last_tool``/``tool_input``/``thought``
        for a tool call, or ``last_tool="FINISH"`` with ``output`` when the
        model finishes (its Action Input becomes the answer) or its reply does
        not follow the requested format (the raw reply becomes the answer).
        """
        tool_list = "\n".join(f"- {t.name}: {t.description}" for t in self.tools)
        # Render observations one per line instead of as a Python list repr.
        context = "\n".join(state.get("context") or [])
        prompt = f"""<|system|>
You're an expert problem solver. Use these tools when needed:
{tool_list}

Respond ONLY in this format:
Thought: <your reasoning>
Action: <tool_name or 'FINISH'>
Action Input: <input for tool>
</s>
<|user|>
{state['input']}
Context: {context}
</s>
<|assistant|>
"""

        generated = self.model(
            prompt,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.2,
            pad_token_id=self.tokenizer.eos_token_id,
        )[0]["generated_text"]
        # Text-generation pipelines echo the prompt by default. Parse only the
        # newly generated text; otherwise the template line
        # "Action Input: <input for tool>" in the prompt would be matched.
        if generated.startswith(prompt):
            response = generated[len(prompt):]
        else:
            response = generated

        action_match = self._ACTION_RE.search(response)
        input_match = self._ACTION_INPUT_RE.search(response)
        if not (action_match and input_match):
            # Reply does not follow the format: treat it as the final answer.
            return {"last_tool": "FINISH", "output": response}

        tool_name = action_match.group(1)
        tool_input = input_match.group(1).strip()
        if tool_name == "FINISH":
            # The Action Input of a FINISH step is the final answer.
            return {
                "last_tool": "FINISH",
                "output": tool_input or response,
                "thought": response,
            }
        return {
            "last_tool": tool_name,
            "tool_input": tool_input,
            "thought": response,
        }

    def _tool_node(self, state: AgentState) -> dict:
        """Run the tool chosen by the agent and record the observation.

        ``context`` updates are one-element lists because the state reducer
        (``operator.add``) concatenates them onto the existing list.
        """
        tool_name = state["last_tool"]
        if tool_name == "FINISH":
            # ``output`` is initialised to "" by ``run``, so test truthiness
            # rather than key presence to make the fallback reachable.
            return {"output": state.get("output") or "No output generated"}

        tool = self.tool_map.get(tool_name)
        if tool is None:
            return {"context": [f"Error: Unknown tool {tool_name}"]}

        try:
            result = tool.run(state.get("tool_input", ""))
        except Exception as exc:
            # Tools are opaque; report the failure to the agent so it can
            # recover, instead of aborting the whole run.
            return {"context": [f"Error: Tool {tool.name} failed: {exc}"]}
        return {"context": [f"Tool {tool.name} returned: {result}"]}

    def _route_action(self, state: AgentState) -> str:
        """Route to ``"end"`` once the agent chose FINISH, else back to the agent."""
        return "end" if state["last_tool"] == "FINISH" else "continue"

    def run(self, input: str, *, recursion_limit: int = 25) -> str:
        """Run the agent loop on ``input`` and return its final answer.

        Args:
            input: the question passed to the agent.
            recursion_limit: maximum number of graph steps (each agent step
                and each tool step counts); forwarded to LangGraph, which
                raises once it is exceeded. 25 matches LangGraph's default.

        Returns:
            The final ``output`` state value, or a fallback message when the
            graph finished without producing one.
        """
        initial_state = {
            "input": input,
            "context": [],
            "last_tool": "",
            "output": "",
            "tool_input": "",
        }
        # invoke() returns the final state directly; streaming and waiting for
        # an "__end__" event does not work with current LangGraph versions.
        final_state = self.graph.invoke(
            initial_state, config={"recursion_limit": recursion_limit}
        )
        return final_state.get("output") or "Execution completed without output"