dkolarova committed on
Commit
ed734e9
·
verified ·
1 Parent(s): 185f5a7

Create planning_agent.py

Browse files
Files changed (1) hide show
  1. planning_agent.py +298 -0
planning_agent.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from dataclasses import dataclass
3
+ from datetime import datetime
4
+ import os
5
+ import json
6
+ import requests
7
+
8
+
9
@dataclass
class Interaction:
    """Record of a single interaction with the agent"""
    timestamp: datetime    # when the plan was created (local time via datetime.now())
    query: str             # the raw user query that was planned for
    plan: Dict[str, Any]   # parsed JSON plan from the LLM; Agent.execute later
                           # overwrites this with {initial_plan, reflection, final_plan}
15
+
16
class Agent:
    """LLM-backed planning agent.

    Sends a user query to an OpenAI-compatible chat-completions endpoint
    together with a JSON "tool manifest" system prompt, parses the returned
    JSON plan, optionally reflects on and revises it, and records every
    exchange in an in-memory interaction history (``self.interactions``).
    """

    # OpenAI-compatible chat endpoint of the HF serverless inference API.
    API_URL = "https://api-inference.huggingface.co/v1/chat/completions"

    def __init__(self, model: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
                 max_tokens: int = 150):
        """Initialize Agent with empty interaction history.

        Args:
            model: Hugging Face model id used for chat completions.
            max_tokens: completion-length cap sent with every request
                (default 150 preserves the original behavior).
        """
        self.interactions: List["Interaction"] = []  # Working memory
        self.model = model
        self.max_tokens = max_tokens

    def _query_llm(self, messages: List[Dict[str, str]]) -> str:
        """POST `messages` to the chat endpoint and return the reply text.

        Raises:
            requests.HTTPError: if the API returns an error status code.
            KeyError: if the response body does not follow the
                OpenAI-compatible ``choices[0].message.content`` schema.
        """
        headers = {"Content-Type": "application/json"}
        # NOTE(review): the inference API normally requires a bearer token;
        # `os` was imported but never used in the original, so the token is
        # read from the conventional HF_TOKEN env var when present.
        token = os.environ.get("HF_TOKEN")
        if token:
            headers["Authorization"] = f"Bearer {token}"
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
        }
        response = requests.post(self.API_URL, headers=headers,
                                 data=json.dumps(payload))
        # Fail loudly on HTTP errors rather than crashing later while
        # parsing an error body as a chat completion.
        response.raise_for_status()
        # BUG FIX: `requests.Response` has no `.choices` attribute (the
        # original `response.choices[0].message.content` raised
        # AttributeError); the reply lives in the parsed JSON body.
        return response.json()["choices"][0]["message"]["content"]

    def create_system_prompt(self) -> str:
        """Create the system prompt for the LLM with available tools."""
        tools_json = {
            "role": "AI Assistant",
            "capabilities": [
                "Using provided tools to help users when necessary",
                "Responding directly without tools for questions that don't require tool usage",
                "Planning efficient tool usage sequences",
                "If asked by the user, reflecting on the plan and suggesting changes if needed"
            ],
            "instructions": [
                "Use tools only when they are necessary for the task",
                "If a query can be answered directly, respond with a simple message instead of using tools",
                "When tools are needed, plan their usage efficiently to minimize tool calls",
                "If asked by the user, reflect on the plan and suggest changes if needed"
            ],
            "tools": [
                {
                    "name": "convert_currency",
                    "description": "Converts currency using latest exchange rates.",
                    "parameters": {
                        "amount": {
                            "type": "float",
                            "description": "Amount to convert"
                        },
                        "from_currency": {
                            "type": "str",
                            "description": "Source currency code (e.g., USD)"
                        },
                        "to_currency": {
                            "type": "str",
                            "description": "Target currency code (e.g., EUR)"
                        }
                    }
                }
            ],
            "response_format": {
                "type": "json",
                "schema": {
                    "requires_tools": {
                        "type": "boolean",
                        "description": "whether tools are needed for this query"
                    },
                    "direct_response": {
                        "type": "string",
                        "description": "response when no tools are needed",
                        "optional": True
                    },
                    "thought": {
                        "type": "string",
                        "description": "reasoning about how to solve the task (when tools are needed)",
                        "optional": True
                    },
                    "plan": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "steps to solve the task (when tools are needed)",
                        "optional": True
                    },
                    "tool_calls": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "tool": {
                                    "type": "string",
                                    "description": "name of the tool"
                                },
                                "args": {
                                    "type": "object",
                                    "description": "parameters for the tool"
                                }
                            }
                        },
                        "description": "tools to call in sequence (when tools are needed)",
                        "optional": True
                    }
                },
                "examples": [
                    {
                        "query": "Convert 100 USD to EUR",
                        "response": {
                            "requires_tools": True,
                            "thought": "I need to use the currency conversion tool to convert USD to EUR",
                            "plan": [
                                "Use convert_currency tool to convert 100 USD to EUR",
                                "Return the conversion result"
                            ],
                            "tool_calls": [
                                {
                                    "tool": "convert_currency",
                                    "args": {
                                        "amount": 100,
                                        "from_currency": "USD",
                                        "to_currency": "EUR"
                                    }
                                }
                            ]
                        }
                    },
                    {
                        "query": "What's 500 Japanese Yen in British Pounds?",
                        "response": {
                            "requires_tools": True,
                            "thought": "I need to convert JPY to GBP using the currency converter",
                            "plan": [
                                "Use convert_currency tool to convert 500 JPY to GBP",
                                "Return the conversion result"
                            ],
                            "tool_calls": [
                                {
                                    "tool": "convert_currency",
                                    "args": {
                                        "amount": 500,
                                        "from_currency": "JPY",
                                        "to_currency": "GBP"
                                    }
                                }
                            ]
                        }
                    },
                    {
                        "query": "What currency does Japan use?",
                        "response": {
                            "requires_tools": False,
                            "direct_response": "Japan uses the Japanese Yen (JPY) as its official currency. This is common knowledge that doesn't require using the currency conversion tool."
                        }
                    }
                ]
            }
        }

        return f"""You are an AI assistant that helps users by providing direct answers or using tools when necessary.
Configuration, instructions, and available tools are provided in JSON format below:
{json.dumps(tools_json, indent=2)}
Always respond with a JSON object following the response_format schema above.
Remember to use tools only when they are actually needed for the task."""

    def plan(self, user_query: str) -> Dict:
        """Use LLM to create a plan and store it in memory.

        Raises:
            ValueError: if the LLM reply is not valid JSON (the original
                JSONDecodeError is chained as the cause).
        """
        messages = [
            {"role": "system", "content": self.create_system_prompt()},
            {"role": "user", "content": user_query}
        ]

        response = self._query_llm(messages=messages)

        try:
            plan = json.loads(response)
        except json.JSONDecodeError as err:
            # Chain the cause so the original parse error is not lost.
            raise ValueError("Failed to parse LLM response as JSON") from err

        # Store the interaction immediately after planning.
        interaction = Interaction(
            timestamp=datetime.now(),
            query=user_query,
            plan=plan
        )
        self.interactions.append(interaction)
        return plan

    def reflect_on_plan(self) -> Dict[str, Any]:
        """Reflect on the most recent plan using interaction history."""
        if not self.interactions:
            return {"reflection": "No plan to reflect on", "requires_changes": False}

        latest_interaction = self.interactions[-1]

        reflection_prompt = {
            "task": "reflection",
            "context": {
                "user_query": latest_interaction.query,
                "generated_plan": latest_interaction.plan
            },
            "instructions": [
                "Review the generated plan for potential improvements",
                "Consider if the chosen tools are appropriate",
                "Verify tool parameters are correct",
                "Check if the plan is efficient",
                "Determine if tools are actually needed"
            ],
            "response_format": {
                "type": "json",
                "schema": {
                    "requires_changes": {
                        "type": "boolean",
                        "description": "whether the plan needs modifications"
                    },
                    "reflection": {
                        "type": "string",
                        "description": "explanation of what changes are needed or why no changes are needed"
                    },
                    "suggestions": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "specific suggestions for improvements",
                        "optional": True
                    }
                }
            }
        }

        messages = [
            {"role": "system", "content": self.create_system_prompt()},
            {"role": "user", "content": json.dumps(reflection_prompt, indent=2)}
        ]

        response = self._query_llm(messages=messages)

        try:
            return json.loads(response)
        except json.JSONDecodeError:
            # BUG FIX: `response` is already the reply string returned by
            # `_query_llm`; the original dereferenced `.choices` on a str,
            # which raised AttributeError instead of falling back gracefully.
            return {"reflection": response}

    def execute(self, user_query: str) -> str:
        """Execute the full pipeline: plan, reflect, and potentially replan.

        Returns a human-readable summary string, or an
        "Error executing plan: ..." message if any stage fails.
        """
        try:
            # Create initial plan (this also stores it in memory).
            initial_plan = self.plan(user_query)

            # Reflect on the plan using memory.
            reflection = self.reflect_on_plan()

            # Check if reflection suggests changes.
            if reflection.get("requires_changes", False):
                # Generate new plan based on reflection.
                messages = [
                    {"role": "system", "content": self.create_system_prompt()},
                    {"role": "user", "content": user_query},
                    {"role": "assistant", "content": json.dumps(initial_plan)},
                    {"role": "user", "content": f"Please revise the plan based on this feedback: {json.dumps(reflection)}"}
                ]

                response = self._query_llm(messages=messages)

                try:
                    final_plan = json.loads(response)
                except json.JSONDecodeError:
                    final_plan = initial_plan  # Fallback to initial plan if parsing fails
            else:
                final_plan = initial_plan

            # Update the stored interaction with all information.
            self.interactions[-1].plan = {
                "initial_plan": initial_plan,
                "reflection": reflection,
                "final_plan": final_plan
            }

            # Return the appropriate response.  `.get` with defaults keeps a
            # schema-violating LLM reply from raising KeyError here (the
            # original used bare item access on optional keys).
            if final_plan.get("requires_tools", True):
                return f"""Initial Thought: {initial_plan.get('thought', '')}
Initial Plan: {'. '.join(initial_plan.get('plan', []))}
Reflection: {reflection.get('reflection', 'No improvements suggested')}
Final Plan: {'. '.join(final_plan.get('plan', []))}"""
            else:
                return f"""Response: {final_plan.get('direct_response', '')}
Reflection: {reflection.get('reflection', 'No improvements suggested')}"""

        except Exception as e:  # top-level boundary: report, don't crash
            return f"Error executing plan: {str(e)}"