JuanjoSG5 commited on
Commit
76d4323
·
1 Parent(s): 2d2877d

test: testing agent

Browse files
Files changed (1) hide show
  1. agent_test.py +360 -0
agent_test.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import json
4
+ import base64
5
+ from typing import List, Dict, Any, Union
6
+ from contextlib import AsyncExitStack
7
+ from io import BytesIO
8
+ from PIL import Image
9
+ import gradio as gr
10
+ from gradio.components.chatbot import ChatMessage
11
+ from mcp import ClientSession, StdioServerParameters
12
+ from mcp.client.stdio import stdio_client
13
+ from dotenv import load_dotenv
14
+ from langchain_openai import ChatOpenAI
15
+
16
# Load environment variables (OPENROUTER_API_KEY / OPENROUTER_API_BASE_URL)
# from a local .env file before anything below reads them.
load_dotenv()

# Gradio callbacks are synchronous, so one module-level event loop is created
# up front and reused (via run_until_complete) to drive the async MCP calls.
# NOTE(review): this assumes no other event loop runs in this thread —
# confirm if this app is ever embedded in an async host.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
20
+
21
class MCPClientWrapper:
    """Wraps an MCP client session plus an OpenRouter-hosted chat model.

    Holds the stdio transport, the MCP session, the discovered tool schemas,
    and the LLM used to decide when those tools should be called.
    """

    def __init__(self, model_name: str = "mistralai/mistral-small", temperature: float = 0.7):
        """Create a disconnected wrapper.

        Args:
            model_name: OpenRouter model identifier (previously hard-coded).
            temperature: Sampling temperature for the chat model
                (previously hard-coded to 0.7).
        """
        self.session = None      # mcp.ClientSession once connect() succeeds
        self.exit_stack = None   # AsyncExitStack owning transport + session
        # Credentials and endpoint come from the environment (.env via load_dotenv).
        self.mistral = ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            openai_api_key=os.getenv("OPENROUTER_API_KEY"),
            openai_api_base=os.getenv("OPENROUTER_API_BASE_URL"),
        )
        self.tools = []          # tool schemas discovered from the MCP server
27
+
28
+ def connect(self, server_path: str) -> str:
29
+ return loop.run_until_complete(self._connect(server_path))
30
+
31
+ async def _connect(self, server_path: str) -> str:
32
+ if self.exit_stack:
33
+ await self.exit_stack.aclose()
34
+
35
+ self.exit_stack = AsyncExitStack()
36
+
37
+ is_python = server_path.endswith('.py')
38
+ command = "python" if is_python else "node"
39
+
40
+ server_params = StdioServerParameters(
41
+ command=command,
42
+ args=[server_path],
43
+ env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
44
+ )
45
+
46
+ stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
47
+ self.stdio, self.write = stdio_transport
48
+
49
+ self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
50
+ await self.session.initialize()
51
+
52
+ response = await self.session.list_tools()
53
+ self.tools = [{
54
+ "name": tool.name,
55
+ "description": tool.description,
56
+ "input_schema": tool.inputSchema
57
+ } for tool in response.tools]
58
+
59
+ tool_names = [tool["name"] for tool in self.tools]
60
+ return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
61
+
62
+ def process_message(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]) -> tuple:
63
+ if not self.session:
64
+ return history + [
65
+ {"role": "user", "content": message},
66
+ {"role": "assistant", "content": "Please connect to an MCP server first."}
67
+ ], gr.Textbox(value="")
68
+
69
+ new_messages = loop.run_until_complete(self._process_query(message, history))
70
+ return history + [{"role": "user", "content": message}] + new_messages, gr.Textbox(value="")
71
+
72
+ async def _process_query(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]):
73
+ claude_messages = []
74
+ for msg in history:
75
+ if isinstance(msg, ChatMessage):
76
+ role, content = msg.role, msg.content
77
+ else:
78
+ role, content = msg.get("role"), msg.get("content")
79
+
80
+ if role in ["user", "assistant", "system"]:
81
+ claude_messages.append({"role": role, "content": content})
82
+
83
+ claude_messages.append({"role": "user", "content": message})
84
+
85
+ response = self.mistral.messages.create(
86
+ model="claude-3-5-sonnet-20241022",
87
+ max_tokens=1000,
88
+ messages=claude_messages,
89
+ tools=self.tools
90
+ )
91
+
92
+ result_messages = []
93
+
94
+ for content in response.content:
95
+ if content.type == 'text':
96
+ result_messages.append({
97
+ "role": "assistant",
98
+ "content": content.text
99
+ })
100
+
101
+ elif content.type == 'tool_use':
102
+ tool_name = content.name
103
+ tool_args = content.input
104
+
105
+ result_messages.append({
106
+ "role": "assistant",
107
+ "content": f"I'll use the {tool_name} tool to help answer your question.",
108
+ "metadata": {
109
+ "title": f"Using tool: {tool_name}",
110
+ "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
111
+ "status": "pending",
112
+ "id": f"tool_call_{tool_name}"
113
+ }
114
+ })
115
+
116
+ result_messages.append({
117
+ "role": "assistant",
118
+ "content": "```json\n" + json.dumps(tool_args, indent=2, ensure_ascii=True) + "\n```",
119
+ "metadata": {
120
+ "parent_id": f"tool_call_{tool_name}",
121
+ "id": f"params_{tool_name}",
122
+ "title": "Tool Parameters"
123
+ }
124
+ })
125
+
126
+ result = await self.session.call_tool(tool_name, tool_args)
127
+
128
+ if result_messages and "metadata" in result_messages[-2]:
129
+ result_messages[-2]["metadata"]["status"] = "done"
130
+
131
+ result_messages.append({
132
+ "role": "assistant",
133
+ "content": "Here are the results from the tool:",
134
+ "metadata": {
135
+ "title": f"Tool Result for {tool_name}",
136
+ "status": "done",
137
+ "id": f"result_{tool_name}"
138
+ }
139
+ })
140
+
141
+ result_content = result.content
142
+ if isinstance(result_content, list):
143
+ result_content = "\n".join(str(item) for item in result_content)
144
+
145
+ try:
146
+ result_json = json.loads(result_content)
147
+ if isinstance(result_json, dict) and "type" in result_json:
148
+ if result_json["type"] == "image" and "url" in result_json:
149
+ result_messages.append({
150
+ "role": "assistant",
151
+ "content": {"path": result_json["url"], "alt_text": result_json.get("message", "Generated image")},
152
+ "metadata": {
153
+ "parent_id": f"result_{tool_name}",
154
+ "id": f"image_{tool_name}",
155
+ "title": "Generated Image"
156
+ }
157
+ })
158
+ else:
159
+ result_messages.append({
160
+ "role": "assistant",
161
+ "content": "```\n" + result_content + "\n```",
162
+ "metadata": {
163
+ "parent_id": f"result_{tool_name}",
164
+ "id": f"raw_result_{tool_name}",
165
+ "title": "Raw Output"
166
+ }
167
+ })
168
+ except:
169
+ result_messages.append({
170
+ "role": "assistant",
171
+ "content": "```\n" + result_content + "\n```",
172
+ "metadata": {
173
+ "parent_id": f"result_{tool_name}",
174
+ "id": f"raw_result_{tool_name}",
175
+ "title": "Raw Output"
176
+ }
177
+ })
178
+
179
+ claude_messages.append({"role": "user", "content": f"Tool result for {tool_name}: {result_content}"})
180
+ next_response = self.mistral.messages.create(
181
+ model="claude-3-5-sonnet-20241022",
182
+ max_tokens=1000,
183
+ messages=claude_messages,
184
+ )
185
+
186
+ if next_response.content and next_response.content[0].type == 'text':
187
+ result_messages.append({
188
+ "role": "assistant",
189
+ "content": next_response.content[0].text
190
+ })
191
+
192
+ return result_messages
193
+
194
+ # New methods for image processing
195
+ def image_to_base64(self, image):
196
+ """Convert PIL image to base64 string"""
197
+ if image is None:
198
+ return None
199
+ buffered = BytesIO()
200
+ image.save(buffered, format="PNG")
201
+ img_str = base64.b64encode(buffered.getvalue()).decode()
202
+ return img_str
203
+
204
+ async def process_image(self, image, operation, target_format=None, width=None, height=None):
205
+ """Process an image using MCP tools"""
206
+ if not self.session:
207
+ return None, "Please connect to an MCP server first."
208
+
209
+ if image is None:
210
+ return None, "No image provided."
211
+
212
+ try:
213
+ img_base64 = self.image_to_base64(image)
214
+
215
+ if operation == "Remove Background":
216
+ result = await self.session.call_tool("remove_background_from_url", {"url": img_base64})
217
+
218
+ elif operation == "Change Format":
219
+ if not target_format:
220
+ return None, "Please select a target format."
221
+ result = await self.session.call_tool("change_format", {
222
+ "image_base64": img_base64,
223
+ "target_format": target_format.lower()
224
+ })
225
+
226
+ elif operation == "Resize Image":
227
+ if not width or not height:
228
+ return None, "Please provide width and height."
229
+ result = await self.session.call_tool("resize_image", {
230
+ "image_base64": img_base64,
231
+ "width": int(width),
232
+ "height": int(height)
233
+ })
234
+
235
+ elif operation == "Visualize Image":
236
+ result = await self.session.call_tool("visualize_base64_image", {"image_base64": img_base64})
237
+
238
+ else:
239
+ return None, "Unknown operation."
240
+
241
+ # Process the result
242
+ result_content = result.content
243
+ if isinstance(result_content, str):
244
+ try:
245
+ result_data = json.loads(result_content)
246
+ if "image_base64" in result_data:
247
+ # Convert result base64 back to image
248
+ img_data = base64.b64decode(result_data["image_base64"])
249
+ result_img = Image.open(BytesIO(img_data))
250
+ return result_img, "Image processed successfully."
251
+ else:
252
+ return None, f"Unexpected result format: {result_content}"
253
+ except json.JSONDecodeError:
254
+ return None, f"Error decoding result: {result_content}"
255
+ else:
256
+ return None, f"Unexpected result type: {type(result_content)}"
257
+
258
+ except Exception as e:
259
+ return None, f"Error processing image: {str(e)}"
260
+
261
# Module-level singleton shared by every Gradio callback below.
client = MCPClientWrapper()
262
+
263
def gradio_interface():
    """Build the Gradio Blocks app: connection bar, chat tab, image tab.

    BUG FIX: `update_options` previously returned raw booleans, which Gradio
    applies as new component *values* (e.g. setting the width Number to True),
    not as visibility; it now returns `gr.update(visible=...)` so selecting an
    operation actually shows/hides the relevant inputs.
    """
    with gr.Blocks(title="MCP Assistant") as demo:
        gr.Markdown("# MCP Assistant")
        gr.Markdown("Connect to your MCP server to chat or process images")

        # --- Connection bar -------------------------------------------------
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                server_path = gr.Textbox(
                    label="Server Script Path",
                    placeholder="Enter path to server script",
                    value="mcp_server.py"
                )
            with gr.Column(scale=1):
                connect_btn = gr.Button("Connect")

        status = gr.Textbox(label="Connection Status", interactive=False)

        with gr.Tabs() as tabs:
            # --- Chat tab ---------------------------------------------------
            with gr.TabItem("Chat Interface"):
                chatbot = gr.Chatbot(
                    value=[],
                    height=500,
                    type="messages",
                    show_copy_button=True,
                    avatar_images=("👤", "🤖")
                )

                with gr.Row(equal_height=True):
                    msg = gr.Textbox(
                        label="Your Question",
                        placeholder="Ask about the available tools or how to process images",
                        scale=4
                    )
                    clear_btn = gr.Button("Clear Chat", scale=1)

            # --- Image-processing tab ---------------------------------------
            with gr.TabItem("Image Processing"):
                with gr.Row():
                    with gr.Column():
                        input_image = gr.Image(label="Input Image", type="pil")
                        operation = gr.Radio(
                            ["Remove Background", "Change Format", "Resize Image", "Visualize Image"],
                            label="Select Operation",
                            value="Visualize Image"
                        )

                        # Hidden until "Change Format" is selected.
                        with gr.Group() as format_options:
                            target_format = gr.Dropdown(
                                ["png", "jpeg", "webp"],
                                label="Target Format",
                                value="png",
                                visible=False
                            )

                        # Hidden until "Resize Image" is selected.
                        with gr.Group() as resize_options:
                            with gr.Row():
                                width = gr.Number(label="Width", value=300, visible=False)
                                height = gr.Number(label="Height", value=300, visible=False)

                        process_btn = gr.Button("Process Image")

                    with gr.Column():
                        output_image = gr.Image(label="Processed Image")
                        output_message = gr.Textbox(label="Status")

        # --- Event wiring ---------------------------------------------------
        connect_btn.click(client.connect, inputs=server_path, outputs=status)

        msg.submit(client.process_message, [msg, chatbot], [chatbot, msg])
        clear_btn.click(lambda: [], None, chatbot)

        def update_options(op):
            # Toggle visibility (not value) of the per-operation inputs.
            return (
                gr.update(visible=op == "Change Format"),
                gr.update(visible=op == "Resize Image"),
                gr.update(visible=op == "Resize Image"),
            )

        operation.change(update_options, inputs=operation, outputs=[target_format, width, height])

        def process_image_wrapper(image, operation, target_format, width, height):
            # Bridge Gradio's synchronous callback into the async MCP client.
            return loop.run_until_complete(
                client.process_image(image, operation, target_format, width, height)
            )

        process_btn.click(
            process_image_wrapper,
            inputs=[input_image, operation, target_format, width, height],
            outputs=[output_image, output_message]
        )

    return demo
354
+
355
if __name__ == "__main__":
    # Surface a missing API key early — the UI still launches, but model
    # calls will fail without credentials.
    if not os.getenv("OPENROUTER_API_KEY"):
        print("Warning: OPENROUTER_API_KEY not found in environment. Please set it in your .env file.")

    app = gradio_interface()
    app.launch(debug=True)