# Page header from the Hugging Face file view, kept as a comment so the module parses:
# superone001 · "Update agent.py" · commit 1310897 (verified) · 3.37 kB
import operator
import re
from typing import TypedDict, Annotated, Sequence

from langgraph.graph import StateGraph, END
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from ai_tools import Calculator, DocRetriever, WebSearcher
# Model setup: the tokenizer and weights are loaded once, at import time, and
# wrapped in a single text-generation pipeline used by agent_node.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Tool registry.
# `tools` is the list whose names/descriptions agent_node puts in the prompt;
# `doc_retriever` is held separately because run_agent assigns its document.
tools = [Calculator(), WebSearcher()]
doc_retriever = DocRetriever()
# Name -> tool lookup used by tool_node; the retriever is registered under a
# fixed key rather than its own `name` attribute.
tool_map = {
    **{tool.name: tool for tool in tools},
    "DocRetriever": doc_retriever,
}
# Agent State
class AgentState(TypedDict):
    """State carried between graph nodes.

    Nodes return partial updates; only keys declared here can be stored in
    the graph state, so every key a node writes or reads must be listed.
    """

    # The user's query, set once by run_agent.
    input: str
    # Accumulated tool results; updates are concatenated via operator.add,
    # so nodes must return a sequence (e.g. a one-element list).
    context: Annotated[Sequence[str], operator.add]
    # Name of the tool picked by agent_node, or "FINISH" to stop.
    last_tool: str
    # Argument for the picked tool; written by agent_node, read by tool_node.
    tool_input: str
    # Raw model text of the step that produced a tool call.
    thought: str
    # Raw model text of the step in which no tool call was parsed.
    output: str
# Tool calling prompt template
# Filled in by agent_node via str.format with three placeholders:
#   {tool_descriptions} - one "- name: description" line per tool
#   {input}             - the user's query
#   {context}           - the context collected so far
# The "Action:" / "Action Input:" lines define the reply format that
# agent_node's regular expressions look for.
# NOTE(review): the <|system|>/<|user|>/<|assistant|> and </s> markers are
# presumably the chat markup expected by MODEL_NAME - confirm against the
# model card.
TOOL_PROMPT = """<|system|>
You're an expert problem solver. Use these tools:
{tool_descriptions}
Respond ONLY in this format:
Thought: <strategy>
Action: <tool_name>
Action Input: <input>
</s>
<|user|>
{input}
Context: {context}
</s>
<|assistant|>
"""
# Initialize graph
# The graph is built around AgentState; nodes and edges are registered on this
# object further down before it is compiled.
graph = StateGraph(AgentState)
# Node: Generate tool calls
# Compiled once; "[ \t]*" (not "\s*") keeps a match from spilling onto the
# next line, and ".+" stops at end-of-line so no trailing newline is required.
_ACTION_RE = re.compile(r"Action:[ \t]*(\w+)")
_ACTION_INPUT_RE = re.compile(r"Action Input:[ \t]*(.+)")


def agent_node(state):
    """Ask the LLM for the next step and parse its tool call.

    Formats ``TOOL_PROMPT`` with the tool descriptions, the user's input and
    the context gathered so far, generates a completion, and searches that
    completion for ``Action:`` and ``Action Input:`` lines.

    Args:
        state: Mapping providing at least ``"input"`` and ``"context"``.

    Returns:
        ``{"last_tool", "tool_input", "thought"}`` when both lines were
        found; otherwise ``{"last_tool": "FINISH", "output": completion}``.
    """
    # NOTE(review): only `tools` is advertised here; the "DocRetriever" entry
    # of tool_map is never described to the model - confirm this is intended.
    tool_list = "\n".join(f"- {t.name}: {t.description}" for t in tools)
    prompt = TOOL_PROMPT.format(
        tool_descriptions=tool_list,
        input=state["input"],
        context=state["context"],
    )
    response = llm_pipeline(
        prompt,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.2,
        pad_token_id=tokenizer.eos_token_id,
        # By default the pipeline returns prompt + completion. The prompt
        # itself contains "Action Input: <input>", which would be parsed as
        # the tool input, so only the newly generated text is requested.
        return_full_text=False,
    )[0]["generated_text"]

    # Extract tool call
    action_match = _ACTION_RE.search(response)
    action_input_match = _ACTION_INPUT_RE.search(response)
    if action_match is None or action_input_match is None:
        # No parsable tool call: treat the completion as the final answer.
        return {"last_tool": "FINISH", "output": response}

    return {
        "last_tool": action_match.group(1),
        "tool_input": action_input_match.group(1).strip(),
        "thought": response,
    }
# Node: Execute tools
def tool_node(state):
    """Run the tool selected by the agent and report the result as context.

    Args:
        state: Mapping with ``"last_tool"`` (a key of ``tool_map``) and
            ``"tool_input"`` (the argument passed to the tool's ``run``).

    Returns:
        ``{"context": [message]}``. The message is wrapped in a list because
        ``AgentState.context`` is merged with ``operator.add``; adding a bare
        string to the existing list would raise ``TypeError``.
    """
    tool_name = state["last_tool"]
    tool = tool_map.get(tool_name)
    if tool is None:
        # Report the problem as context instead of raising.
        return {"context": [f"Error: Unknown tool {tool_name}"]}
    result = tool.run(state["tool_input"])
    return {"context": [f"Tool {tool.name} returned: {result}"]}
# Define graph structure
# Two nodes: "agent" (LLM step, agent_node) and "tool" (tool execution,
# tool_node). Every run starts at the "agent" node.
graph.add_node("agent", agent_node)
graph.add_node("tool", tool_node)
graph.set_entry_point("agent")
# Conditional edges
def route_action(state):
    """Pick the next hop: ``END`` once the agent said "FINISH", else "tool"."""
    wants_tool = state["last_tool"] != "FINISH"
    return "tool" if wants_tool else END
# The routing decision is made right after the agent step: route_action
# returns either "tool" or END, and both values must appear in the path map.
# (An unconditional agent -> tool edge would run the tool node even after
# FINISH and the graph would never reach END.)
graph.add_conditional_edges("agent", route_action, {"tool": "tool", END: END})
graph.add_edge("tool", "agent")  # Loop back after tool use
# Compile the agent
agent = graph.compile()
# Interface function
def run_agent(query: str, document: str = "") -> str:
    """Run the agent on *query*, print each step, and return the last tool result.

    Args:
        query: The user's question; becomes ``state["input"]``.
        document: Text assigned to ``doc_retriever.document`` before the run.

    Returns:
        The most recent context entry emitted by the tool node during the
        run, or ``"No output"`` if no entry was produced.
    """
    doc_retriever.document = document  # Load document
    state = {"input": query, "context": [], "last_tool": ""}
    # `state` is only the initial input; nothing in this function writes
    # results back into it, so context entries are collected from the
    # streamed per-node updates instead.
    collected = list(state["context"])
    for step in agent.stream(state):
        for node, value in step.items():
            if not value:
                continue
            if node == "agent":
                # The finishing step carries "output" instead of "thought".
                thought = value.get("thought", value.get("output"))
                if thought is not None:
                    print(f"THOUGHT: {thought}")
            elif node == "tool":
                update = value.get("context", [])
                # Accept both a single message and a list of messages.
                entries = [update] if isinstance(update, str) else list(update)
                for entry in entries:
                    print(f"TOOL RESULT: {entry}")
                collected.extend(entries)
    return collected[-1] if collected else "No output"