superone001 committed on
Commit
1d755bf
·
verified ·
1 Parent(s): d58ea8f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +43 -113
agent.py CHANGED
@@ -1,116 +1,46 @@
1
import operator
import re
from typing import TypedDict, Annotated, Sequence

from langgraph.graph import StateGraph, END
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from ai_tools import Calculator, DocRetriever, WebSearcher
6
-
7
# Configuration
# NOTE(review): the model weights are downloaded/loaded at import time,
# which makes importing this module slow and non-optional — confirm this
# is intentional.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Define tools
# DocRetriever is deliberately kept out of `tools` (so it is absent from the
# prompt's tool descriptions) but remains invokable via tool_map under the
# hard-coded key "DocRetriever".
tools = [Calculator(), WebSearcher()]
doc_retriever = DocRetriever()
tool_map = {tool.name: tool for tool in tools}
tool_map["DocRetriever"] = doc_retriever
18
-
19
# Agent State
class _AgentScratch(TypedDict, total=False):
    # Keys written by individual graph nodes; absent from the initial state.
    # Declaring them here keeps node return values inside the state schema
    # (agent_node returns "tool_input"/"thought"/"output", tool_node reads
    # "tool_input" — none of these were declared in the original TypedDict).
    tool_input: str  # argument extracted for the chosen tool
    thought: str     # raw LLM response containing the Thought/Action text
    output: str      # final answer text when the agent decides to finish


class AgentState(_AgentScratch):
    """Shared state flowing through the agent graph.

    ``context`` is annotated with ``operator.add`` so the graph merges node
    updates by concatenation instead of replacement.
    """
    input: str
    context: Annotated[Sequence[str], operator.add]
    last_tool: str
24
-
25
# Tool calling prompt template.
# Uses the Zephyr chat format (<|system|>/<|user|>/<|assistant|> tags with
# </s> terminators). Placeholders are filled by agent_node:
#   {tool_descriptions} — one "- name: description" line per tool
#   {input}             — the user's question
#   {context}           — accumulated tool results so far
TOOL_PROMPT = """<|system|>
You're an expert problem solver. Use these tools:
{tool_descriptions}

Respond ONLY in this format:
Thought: <strategy>
Action: <tool_name>
Action Input: <input>
</s>
<|user|>
{input}
Context: {context}
</s>
<|assistant|>
"""

# Initialize graph
graph = StateGraph(AgentState)
44
-
45
# Node: Generate tool calls
def agent_node(state):
    """Ask the LLM which tool to run next for the current state.

    Returns a partial state update: either ``{"last_tool", "tool_input",
    "thought"}`` when the model emitted a parseable ``Action``, or
    ``{"last_tool": "FINISH", "output": <raw response>}`` when it did not
    (the raw response is then treated as the final answer).
    """
    tool_list = "\n".join(f"- {t.name}: {t.description}" for t in tools)
    prompt = TOOL_PROMPT.format(
        tool_descriptions=tool_list,
        input=state["input"],
        context=state["context"],
    )

    # Low temperature keeps the Action format stable while still sampling.
    response = llm_pipeline(
        prompt,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.2,
        pad_token_id=tokenizer.eos_token_id,
    )[0]["generated_text"]

    # Extract tool call. (`re` was previously used here without ever being
    # imported — it is now imported at module level.)  The input pattern no
    # longer requires a trailing newline, so an "Action Input:" on the last
    # line of the generation still parses.
    action_match = re.search(r"Action: (\w+)", response)
    action_input_match = re.search(r"Action Input:\s*(.+)", response)

    if action_match and action_input_match:
        return {
            "last_tool": action_match.group(1),
            "tool_input": action_input_match.group(1).strip(),
            "thought": response,
        }
    # No parseable action: treat the raw response as the final answer.
    return {"last_tool": "FINISH", "output": response}
76
-
77
# Node: Execute tools
def tool_node(state):
    """Run the tool chosen by agent_node and append its result to context.

    ``context`` is returned as a one-element list: AgentState.context is
    merged with ``operator.add`` over a Sequence[str], so returning a bare
    string (as the original did) would attempt ``list + str`` and raise
    TypeError.
    """
    tool = tool_map.get(state["last_tool"])
    if not tool:
        # Unknown tool name from the LLM — surface it in context instead of
        # crashing, so the agent can recover on the next turn.
        return {"context": [f"Error: Unknown tool {state['last_tool']}"]}

    result = tool.run(state["tool_input"])
    return {"context": [f"Tool {tool.name} returned: {result}"]}
85
-
86
# Define graph structure
graph.add_node("agent", agent_node)
graph.add_node("tool", tool_node)
graph.set_entry_point("agent")


# Conditional edges
def route_action(state):
    """Return END once the agent has produced a final answer, else "tool"."""
    if state["last_tool"] == "FINISH":
        return END
    return "tool"


# The finish/continue decision is produced by the *agent* node, so the
# conditional edge must start there. The original attached it to "tool"
# with a mapping {"agent", END} that did not contain route_action's
# "tool" return value (KeyError at runtime), and additionally added an
# unconditional agent -> tool edge, so the graph could never reach END.
graph.add_conditional_edges("agent", route_action, {"tool": "tool", END: END})
graph.add_edge("tool", "agent")  # loop back after tool use

# Compile the agent
agent = graph.compile()
103
-
104
# Interface function
def run_agent(query: str, document: str = ""):
    """Run the compiled agent on *query*, optionally with a document loaded.

    Streams the graph, printing each agent thought and tool result, and
    returns the agent's final output if it produced one, otherwise the last
    tool result. (The original inspected the *initial* ``state`` dict, which
    ``agent.stream`` never mutates, so it always returned "No output"; it
    also indexed ``value['thought']`` unconditionally, raising KeyError on
    the FINISH path.)
    """
    doc_retriever.document = document  # Load document for the retriever tool
    state = {"input": query, "context": [], "last_tool": ""}

    final_output = None
    last_context = None
    for step in agent.stream(state):
        for node, value in step.items():
            if node == "agent":
                if "thought" in value:
                    print(f"THOUGHT: {value['thought']}")
                final_output = value.get("output", final_output)
            if node == "tool":
                ctx = value.get("context")
                print(f"TOOL RESULT: {ctx}")
                # Accept both list-valued context (schema-conformant) and a
                # bare string, keeping only the most recent entry.
                if isinstance(ctx, (list, tuple)) and ctx:
                    last_context = ctx[-1]
                elif ctx:
                    last_context = ctx

    if final_output is not None:
        return final_output
    return last_context if last_context is not None else "No output"
 
 
 
 
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
+ from .ai_tools import Calculator, DocRetriever, WebSearcher
3
+ from .graph import GaiaGraph
4
+
5
class GaiaAgent:
    """Agent pairing a local HF text-generation pipeline with a set of tools
    wired together through a LangGraph workflow."""

    def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
        # Language-model stack: tokenizer + causal LM wrapped in a pipeline.
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.llm_pipeline = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

        # Tool instances. The retriever stays individually addressable so
        # load_document() can push text into it later.
        self.calculator = Calculator()
        self.doc_retriever = DocRetriever()
        self.web_searcher = WebSearcher()
        self.tools = [self.calculator, self.web_searcher, self.doc_retriever]

        # Assemble the LangGraph workflow around the model and tools.
        self.graph = GaiaGraph(
            model=self.llm_pipeline,
            tokenizer=self.tokenizer,
            tools=self.tools,
        )

        print(f"GaiaAgent initialized with model: {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def load_document(self, document_text: str):
38
+ """Load document content for retrieval"""
39
+ self.doc_retriever.load_document(document_text)
40
+ print(f"Document loaded ({len(document_text)} characters)")
 
 
41
 
42
+ def __call__(self, question: str) -> str:
43
+ print(f"Agent received question: {question[:50]}{'...' if len(question) > 50 else ''}")
44
+ result = self.graph.run(question)
45
+ print(f"Agent returning answer: {result[:50]}{'...' if len(result) > 50 else ''}")
46
+ return result