dkolarova commited on
Commit
bb3becf
·
verified ·
1 Parent(s): 7585692

Update tools_agent.py

Browse files
Files changed (1) hide show
  1. tools_agent.py +17 -14
tools_agent.py CHANGED
@@ -1,14 +1,19 @@
1
  from typing import Dict, List, Any
2
  from tool_registry import Tool
3
- import openai
4
  import os
5
  import json
6
 
7
 
8
  class Agent:
9
- def __init__(self, client):
 
 
 
 
 
 
10
  """Initialize Agent with empty tool registry."""
11
- self.client = client
12
  self.tools: Dict[str, Tool] = {}
13
 
14
  def add_tool(self, tool: Tool) -> None:
@@ -153,9 +158,7 @@ class Agent:
153
 
154
  return f"""You are an AI assistant that helps users by providing direct answers or using tools when necessary.
155
  Configuration, instructions, and available tools are provided in JSON format below:
156
-
157
  {json.dumps(tools_json, indent=2)}
158
-
159
  Always respond with a JSON object following the response_format schema above.
160
  Remember to use tools only when they are actually needed for the task."""
161
 
@@ -166,12 +169,13 @@ Remember to use tools only when they are actually needed for the task."""
166
  {"role": "user", "content": user_query}
167
  ]
168
 
169
- response = self.client.chat.completions.create(
170
- model="gpt-4o-mini",
171
- messages=messages,
172
- temperature=0
 
173
  )
174
-
175
  try:
176
  return json.loads(response.choices[0].message.content)
177
  except json.JSONDecodeError:
@@ -181,7 +185,7 @@ Remember to use tools only when they are actually needed for the task."""
181
  """Execute the full pipeline: plan and execute tools."""
182
  try:
183
  plan = self.plan(user_query)
184
-
185
  if not plan.get("requires_tools", True):
186
  return plan["direct_response"]
187
 
@@ -207,8 +211,7 @@ def main():
207
  agent = Agent()
208
  agent.add_tool(convert_currency)
209
 
210
- query_list = ["I am traveling to Japan from Serbia, I have 1500 of local currency, how much of Japaese currency will I be able to get?",
211
- "How are you doing?"]
212
 
213
  for query in query_list:
214
  print(f"\nQuery: {query}")
@@ -216,4 +219,4 @@ def main():
216
  print(result)
217
 
218
  if __name__ == "__main__":
219
- main()
 
1
  from typing import Dict, List, Any
2
  from tool_registry import Tool
3
+ from huggingface_hub import InferenceClient
4
  import os
5
  import json
6
 
7
 
8
  class Agent:
9
+ def __init__(self):
10
+ """
11
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
12
+ """
13
+ # self.client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
14
+ self.client = InferenceClient("Qwen/Qwen2.5-Coder-32B-Instruct")
15
+
16
  """Initialize Agent with empty tool registry."""
 
17
  self.tools: Dict[str, Tool] = {}
18
 
19
  def add_tool(self, tool: Tool) -> None:
 
158
 
159
  return f"""You are an AI assistant that helps users by providing direct answers or using tools when necessary.
160
  Configuration, instructions, and available tools are provided in JSON format below:
 
161
  {json.dumps(tools_json, indent=2)}
 
162
  Always respond with a JSON object following the response_format schema above.
163
  Remember to use tools only when they are actually needed for the task."""
164
 
 
169
  {"role": "user", "content": user_query}
170
  ]
171
 
172
+ response = self.client.chat_completion(
173
+ messages,
174
+ max_tokens=512,
175
+ temperature=0,
176
+ top_p=0.95,
177
  )
178
+ print(response.choices[0])
179
  try:
180
  return json.loads(response.choices[0].message.content)
181
  except json.JSONDecodeError:
 
185
  """Execute the full pipeline: plan and execute tools."""
186
  try:
187
  plan = self.plan(user_query)
188
+
189
  if not plan.get("requires_tools", True):
190
  return plan["direct_response"]
191
 
 
211
  agent = Agent()
212
  agent.add_tool(convert_currency)
213
 
214
+ query_list = ["I am traveling to Japan from Serbia, I have 1500 of local currency, how much of Japaese currency will I be able to get?"]
 
215
 
216
  for query in query_list:
217
  print(f"\nQuery: {query}")
 
219
  print(result)
220
 
221
  if __name__ == "__main__":
222
+ main()