from typing import TypedDict, Annotated, Sequence
import operator
import re  # needed to parse the model's "Action:" / "Action Input:" lines

from langgraph.graph import StateGraph, END
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from ai_tools import Calculator, DocRetriever, WebSearcher
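# NOTE: `ai_tools` is a project-local module that is not shown here. The code
# below assumes each tool object exposes the following interface (this is an
# assumption, not part of the original module):
#   .name: str          -- identifier the model writes on its "Action:" line
#   .description: str   -- one-line summary injected into the system prompt
#   .run(input: str) -> str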

# Configuration
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Define tools
tools = [Calculator(), WebSearcher()]
doc_retriever = DocRetriever()
tool_map = {tool.name: tool for tool in tools}
tool_map["DocRetriever"] = doc_retriever

# Agent State
class AgentState(TypedDict):
    input: str
    context: Annotated[Sequence[str], operator.add]  # accumulated via the add reducer
    last_tool: str
    tool_input: str   # agent_node returns these keys, so the schema must
    thought: str      # declare them for LangGraph to track the channels
    output: str

# Tool calling prompt template
TOOL_PROMPT = """<|system|>
You're an expert problem solver. Use these tools:
{tool_descriptions}

Respond ONLY in this format:
Thought: <strategy>
Action: <tool_name>
Action Input: <input>
</s>
<|user|>
{input}
Context: {context}
</s>
<|assistant|>
"""

# Initialize graph
graph = StateGraph(AgentState)

# Node: Generate tool calls
def agent_node(state):
    # Advertise every registered tool under its registry key so the name the
    # model writes after "Action:" always matches the tool_map lookup
    tool_list = "\n".join([f"- {name}: {tool.description}" for name, tool in tool_map.items()])
    prompt = TOOL_PROMPT.format(
        tool_descriptions=tool_list,
        input=state["input"],
        context=state["context"]
    )
    
    # return_full_text=False keeps the prompt (which itself contains literal
    # "Action:" / "Action Input:" lines) out of the regex parsing below
    response = llm_pipeline(
        prompt,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.2,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False
    )[0]['generated_text']
    
    # Extract the tool call (tolerate a response that ends without a newline)
    action_match = re.search(r"Action: (\w+)", response)
    action_input_match = re.search(r"Action Input: (.+?)(?:\n|$)", response)
    
    if action_match and action_input_match:
        tool_name = action_match.group(1)
        tool_input = action_input_match.group(1).strip()
        return {
            "last_tool": tool_name,
            "tool_input": tool_input,
            "thought": response
        }
    else:
        return {"last_tool": "FINISH", "output": response}

# Node: Execute tools
def tool_node(state):
    tool = tool_map.get(state["last_tool"])
    if not tool:
        return {"context": [f"Error: Unknown tool {state['last_tool']}"]}

    # `context` uses an operator.add reducer, so updates must be lists --
    # returning a bare string would concatenate it character by character
    result = tool.run(state["tool_input"])
    return {"context": [f"Tool {tool.name} returned: {result}"]}

# Define graph structure
graph.add_node("agent", agent_node)
graph.add_node("tool", tool_node)
graph.set_entry_point("agent")

# Conditional edges: only the agent node sets "last_tool", so the routing
# decision belongs on "agent" (either call a tool or finish)
def route_action(state):
    if state["last_tool"] == "FINISH":
        return END
    return "tool"

graph.add_conditional_edges("agent", route_action, {"tool": "tool", END: END})
graph.add_edge("tool", "agent")  # Loop back to the agent after each tool use

# Compile the agent
agent = graph.compile()

# Interface function
def run_agent(query: str, document: str = ""):
    doc_retriever.document = document  # Load document for retrieval
    state = {"input": query, "context": [], "last_tool": ""}

    # agent.stream() yields per-node state updates; the initial `state` dict
    # is never mutated, so capture the final answer from the streamed values
    final_output = None
    for step in agent.stream(state):
        for node, value in step.items():
            if node == "agent":
                print(f"THOUGHT: {value.get('thought', value.get('output', ''))}")
                if value.get("output"):
                    final_output = value["output"]
            if node == "tool":
                print(f"TOOL RESULT: {value['context'][0]}")

    return final_output if final_output is not None else "No output"