import operator
import re
from typing import Annotated, Sequence, TypedDict

from langgraph.graph import END, StateGraph

from .ai_tools import Calculator, DocRetriever, WebSearcher


class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # The user's question, as given to ``GaiaGraph.run``.
    input: str
    # Accumulated tool observations. Node updates to this key are combined
    # with ``operator.add``, so every update must itself be a list.
    context: Annotated[Sequence[str], operator.add]
    # Name of the tool picked by the agent node, or "FINISH" when the agent
    # produced no tool call.
    last_tool: str
    # Argument string for the selected tool. Declared here because the agent
    # node writes it and the tool node reads it.
    tool_input: str
    # Model response that produced the current tool call.
    thought: str
    # Final answer text.
    output: str


class GaiaGraph:
    """Agent/tool loop: a model picks a tool, the tool runs, and the result
    is fed back to the model until it stops requesting tools.

    Args:
        model: Callable invoked as ``model(prompt, **generation_kwargs)``
            whose result is indexed as ``[0]["generated_text"]``
            (presumably a Hugging Face text-generation pipeline -- confirm
            against the caller).
        tokenizer: Object exposing ``eos_token_id``, used as the pad token id.
        tools: Iterable of tool objects exposing ``name``, ``description``
            and ``run(tool_input)``.
    """

    # Sentinel stored in ``last_tool`` when the agent did not request a tool.
    _FINISH = "FINISH"

    _ACTION_RE = re.compile(r"Action: (\w+)")
    # Accept end-of-string as well as a newline so that a response ending
    # right after the action input is still recognised as a tool call.
    _ACTION_INPUT_RE = re.compile(r"Action Input: (.+?)(?:\n|$)", re.DOTALL)

    def __init__(self, model, tokenizer, tools):
        self.model = model
        self.tokenizer = tokenizer
        self.tools = tools
        # Name -> tool lookup used by the tool node.
        self.tool_map = {tool.name: tool for tool in tools}
        self.graph = self._build_graph()

    def _build_graph(self):
        """Wire up and compile the two-node graph.

        The agent node always hands over to the tool node; after the tool
        node, ``_route_action`` decides between looping back to the agent
        and ending the run.
        """
        graph = StateGraph(AgentState)
        graph.add_node("agent", self._agent_node)
        graph.add_node("tool", self._tool_node)
        graph.set_entry_point("agent")
        graph.add_edge("agent", "tool")
        graph.add_conditional_edges(
            "tool",
            self._route_action,
            {"continue": "agent", "end": END},
        )
        return graph.compile()

    def _build_prompt(self, state: AgentState) -> str:
        """Render the chat-style prompt for the current state."""
        tool_list = "\n".join(f"- {t.name}: {t.description}" for t in self.tools)
        # NOTE(review): the original line breaks of this prompt were lost in
        # the collapsed source; the wording is unchanged, the layout is
        # reconstructed -- confirm against the intended template.
        return (
            "<|system|>\n"
            "You're an expert problem solver.\n"
            "Use these tools when needed:\n"
            f"{tool_list}\n"
            "\n"
            "Respond ONLY in this format:\n"
            "Thought:\n"
            "Action:\n"
            "Action Input:\n"
            "<|user|>\n"
            f"{state['input']}\n"
            f"Context: {state['context']}\n"
            "<|assistant|>\n"
        )

    @classmethod
    def _parse_action(cls, response: str):
        """Extract ``(tool_name, tool_input)`` from a model response.

        Returns ``None`` unless both an ``Action:`` and an ``Action Input:``
        field are present.
        """
        action_match = cls._ACTION_RE.search(response)
        action_input_match = cls._ACTION_INPUT_RE.search(response)
        if not (action_match and action_input_match):
            return None
        return action_match.group(1), action_input_match.group(1).strip()

    def _agent_node(self, state: AgentState) -> dict:
        """Ask the model for the next step and turn it into a state update.

        Returns:
            A tool-call update (``last_tool``, ``tool_input``, ``thought``)
            when the response contains a parsable action; otherwise a
            ``FINISH`` update carrying the response as ``output``.
        """
        prompt = self._build_prompt(state)
        response = self.model(
            prompt,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.2,
            pad_token_id=self.tokenizer.eos_token_id,
        )[0]["generated_text"]
        # If the model echoes the prompt in front of its completion, drop it
        # so that neither parsing nor the final output sees prompt text.
        # This is a no-op when the response does not start with the prompt.
        response = response.removeprefix(prompt)

        parsed = self._parse_action(response)
        if parsed is None:
            return {"last_tool": self._FINISH, "output": response}

        tool_name, tool_input = parsed
        return {
            "last_tool": tool_name,
            "tool_input": tool_input,
            "thought": response,
        }

    def _tool_node(self, state: AgentState) -> dict:
        """Run the tool chosen by the agent and record its result.

        Returns:
            On ``FINISH``, the final ``output``. Otherwise a ``context``
            update holding either the tool result or an unknown-tool error.
        """
        if state["last_tool"] == self._FINISH:
            # ``run`` seeds ``output`` with "", so fall back on any empty value.
            return {"output": state.get("output") or "No output generated"}

        tool = self.tool_map.get(state["last_tool"])
        # ``context`` is reduced with ``operator.add`` onto a list, so the
        # update has to be a list too; adding a bare string to a list raises
        # ``TypeError``.
        if tool is None:
            return {"context": [f"Error: Unknown tool {state['last_tool']}"]}

        result = tool.run(state["tool_input"])
        return {"context": [f"Tool {tool.name} returned: {result}"]}

    def _route_action(self, state: AgentState) -> str:
        """Return ``"end"`` once the agent has finished, else ``"continue"``."""
        return "end" if state["last_tool"] == self._FINISH else "continue"

    def run(self, input: str) -> str:
        """Execute the graph for one question and return the final answer.

        Args:
            input: The user's question.

        Returns:
            The ``output`` of the ``"__end__"`` stream entry when one is
            emitted; otherwise the last non-empty ``output`` written by any
            node, or a fallback message if no node produced output.
        """
        state = {
            "input": input,
            "context": [],
            "last_tool": "",
            "tool_input": "",
            "thought": "",
            "output": "",
        }
        final_output = ""
        for step in self.graph.stream(state):
            for node, value in step.items():
                if node == "__end__":
                    return value["output"]
                # Keep the latest answer in case the stream never emits an
                # "__end__" entry.
                if isinstance(value, dict) and value.get("output"):
                    final_output = value["output"]
        return final_output or "Execution completed without output"