"""Tool-calling agent built on LangGraph and a Hugging Face text-generation model.

The graph alternates between an ``agent`` node (the LLM picks a tool or gives a
final answer) and a ``tool`` node (runs the chosen tool and appends its result
to the shared context) until the LLM stops requesting tools.
"""

import operator
import re
from typing import Annotated, Sequence, TypedDict

from langgraph.graph import StateGraph, END
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from ai_tools import Calculator, DocRetriever, WebSearcher

# Configuration
# NOTE: the model is downloaded/loaded at import time; importing this module is expensive.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Define tools
tools = [Calculator(), WebSearcher()]
doc_retriever = DocRetriever()
# Keys of tool_map are the names the LLM must emit after "Action:".
tool_map = {tool.name: tool for tool in tools}
tool_map["DocRetriever"] = doc_retriever


# Agent State
class AgentState(TypedDict):
    # The user's query.
    input: str
    # Accumulated tool observations; operator.add concatenates the lists
    # returned by nodes, so nodes must return a *list* for this key.
    context: Annotated[Sequence[str], operator.add]
    # Tool chosen by the agent, or "FINISH" when no tool call was parsed.
    last_tool: str
    # Argument string for last_tool, parsed from "Action Input:".
    tool_input: str
    # Raw LLM completion for the latest agent step.
    thought: str
    # Final answer, set only when the agent finishes.
    output: str


# Tool calling prompt template
TOOL_PROMPT = """<|system|>
You're an expert problem solver. Use these tools:
{tool_descriptions}

Respond ONLY in this format:
Thought: <your reasoning>
Action: <tool name>
Action Input: <input for the tool>

When the context already contains enough information to answer, reply with
"Final Answer: <answer>" and no Action line.
<|user|>
{input}
Context: {context}
<|assistant|>
"""

# "[ \t]*" instead of "\s*" so an empty "Action:" line cannot swallow the next line.
_ACTION_RE = re.compile(r"Action:[ \t]*(\w+)")
# "." does not match newlines, so this captures to end of line, including the last line.
_ACTION_INPUT_RE = re.compile(r"Action Input:[ \t]*(.+)")


# Node: Generate tool calls
def agent_node(state: AgentState) -> dict:
    """Ask the LLM for the next step and parse a tool call out of its reply.

    Returns a state update with ``last_tool``/``tool_input``/``thought`` when
    an ``Action:`` and ``Action Input:`` pair is found; otherwise returns
    ``last_tool="FINISH"`` with the answer in ``output``.
    """
    # Describe every callable tool, keyed by the name tool_node looks up,
    # so DocRetriever is discoverable too.
    tool_list = "\n".join(
        f"- {name}: {getattr(tool, 'description', '')}"
        for name, tool in tool_map.items()
    )
    context = state.get("context") or []
    prompt = TOOL_PROMPT.format(
        tool_descriptions=tool_list,
        input=state["input"],
        context="\n".join(context) if context else "None",
    )
    response = llm_pipeline(
        prompt,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.2,
        pad_token_id=tokenizer.eos_token_id,
        # Only the completion: otherwise the regexes could match text in the prompt.
        return_full_text=False,
    )[0]["generated_text"]

    # Extract tool call
    action_match = _ACTION_RE.search(response)
    action_input_match = _ACTION_INPUT_RE.search(response)
    if action_match and action_input_match:
        return {
            "last_tool": action_match.group(1),
            "tool_input": action_input_match.group(1).strip(),
            "thought": response,
        }

    # No tool call: treat the reply as the final answer.
    _, sep, answer = response.partition("Final Answer:")
    return {
        "last_tool": "FINISH",
        "thought": response,
        "output": answer.strip() if sep else response.strip(),
    }


# Node: Execute tools
def tool_node(state: AgentState) -> dict:
    """Run the tool chosen by the agent and append its result to ``context``.

    Unknown tool names and tool exceptions are reported back through
    ``context`` so the agent can recover on its next step.
    """
    tool_name = state["last_tool"]
    tool = tool_map.get(tool_name)
    if tool is None:
        # Wrapped in a list: the operator.add reducer concatenates sequences.
        return {"context": [f"Error: Unknown tool {tool_name}"]}
    try:
        result = tool.run(state["tool_input"])
    except Exception as exc:  # tool boundary: feed the failure back to the LLM
        return {"context": [f"Error: Tool {tool.name} failed: {exc}"]}
    return {"context": [f"Tool {tool.name} returned: {result}"]}


# Conditional edges
def route_action(state: AgentState):
    """Route to END when the agent finished, otherwise to the tool node."""
    if state["last_tool"] == "FINISH":
        return END
    return "tool"


# Define graph structure: agent -> (tool | END), tool -> agent
graph = StateGraph(AgentState)
graph.add_node("agent", agent_node)
graph.add_node("tool", tool_node)
graph.set_entry_point("agent")
graph.add_conditional_edges("agent", route_action, {"tool": "tool", END: END})
graph.add_edge("tool", "agent")  # Loop back after tool use

# Compile the agent
agent = graph.compile()


# Interface function
def run_agent(query: str, document: str = "") -> str:
    """Run the agent on *query*, printing each step, and return its result.

    Args:
        query: The user's question.
        document: Text made available to the DocRetriever tool.

    Returns:
        The agent's final answer. If the run ends without one, returns the
        last tool result, or ``"No output"`` if there is none.
    """
    doc_retriever.document = document  # Load document
    initial_state = {"input": query, "context": [], "last_tool": ""}

    # stream() yields per-node updates and never mutates initial_state,
    # so the results have to be tracked here.
    final_output = None
    last_context = None
    for step in agent.stream(initial_state):
        for node, update in step.items():
            if node == "agent":
                print(f"THOUGHT: {update.get('thought', '')}")
                if update.get("last_tool") == "FINISH":
                    final_output = update.get("output")
            elif node == "tool":
                last_context = update["context"][-1]
                print(f"TOOL RESULT: {last_context}")

    if final_output is not None:
        return final_output
    if last_context is not None:
        return last_context
    return "No output"