JuanjoSG5 committed
Commit 26e31ab · 2 parents: b2b7174, 14c9c39

Merge branch 'test'

Files changed (2)
  1. agent_test.py +240 -0
  2. gradio_interface/app.py +201 -4
agent_test.py ADDED
@@ -0,0 +1,240 @@
+import asyncio
+import os
+import json
+from typing import List, Dict, Any, Union
+from contextlib import AsyncExitStack
+import gradio as gr
+from gradio.components.chatbot import ChatMessage
+from mcp import ClientSession, StdioServerParameters
+from mcp.client.stdio import stdio_client
+from dotenv import load_dotenv
+from langchain_openai import ChatOpenAI
+
+load_dotenv()
+
+# A single long-lived event loop so the synchronous Gradio callbacks can
+# drive the async MCP client.
+loop = asyncio.new_event_loop()
+asyncio.set_event_loop(loop)
+
+class MCPClientWrapper:
+    def __init__(self):
+        self.session = None
+        self.exit_stack = None
+        # OpenRouter is OpenAI-compatible, so the base URL must point at
+        # OpenRouter instead of the OpenAI default.
+        self.mistral = ChatOpenAI(
+            model_name="mistralai/mistral-small",
+            temperature=0.7,
+            openai_api_key=os.getenv("OPENROUTER_API_KEY"),
+            openai_api_base=os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"),
+        )
+        self.tools = []
+
+    def connect(self, server_path: str) -> str:
+        return loop.run_until_complete(self._connect(server_path))
+
+    async def _connect(self, server_path: str) -> str:
+        if self.exit_stack:
+            await self.exit_stack.aclose()
+
+        self.exit_stack = AsyncExitStack()
+
+        is_python = server_path.endswith('.py')
+        command = "python" if is_python else "node"
+
+        server_params = StdioServerParameters(
+            command=command,
+            args=[server_path],
+            env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
+        )
+
+        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
+        self.stdio, self.write = stdio_transport
+
+        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
+        await self.session.initialize()
+
+        response = await self.session.list_tools()
+        self.tools = [{
+            "name": tool.name,
+            "description": tool.description,
+            "input_schema": tool.inputSchema
+        } for tool in response.tools]
+
+        tool_names = [tool["name"] for tool in self.tools]
+        return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
+
+    def process_message(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]) -> tuple:
+        if not self.session:
+            return history + [
+                {"role": "user", "content": message},
+                {"role": "assistant", "content": "Please connect to an MCP server first."}
+            ], gr.Textbox(value="")
+
+        new_messages = loop.run_until_complete(self._process_query(message, history))
+        return history + [{"role": "user", "content": message}] + new_messages, gr.Textbox(value="")
+
+    async def _process_query(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]):
+        # Rebuild the conversation as OpenAI-style role/content dicts.
+        chat_messages = []
+        for msg in history:
+            if isinstance(msg, ChatMessage):
+                role, content = msg.role, msg.content
+            else:
+                role, content = msg.get("role"), msg.get("content")
+
+            if role in ["user", "assistant", "system"]:
+                chat_messages.append({"role": role, "content": content})
+
+        chat_messages.append({"role": "user", "content": message})
+
+        # Expose the MCP tools in OpenAI function-calling format; ChatOpenAI
+        # has no Anthropic-style `messages.create` API, so tool use goes
+        # through LangChain's bind_tools / tool_calls instead.
+        openai_tools = [{
+            "type": "function",
+            "function": {
+                "name": tool["name"],
+                "description": tool["description"],
+                "parameters": tool["input_schema"]
+            }
+        } for tool in self.tools]
+        llm = self.mistral.bind_tools(openai_tools) if openai_tools else self.mistral
+        response = await llm.ainvoke(chat_messages)
+
+        result_messages = []
+
+        if isinstance(response.content, str) and response.content:
+            result_messages.append({
+                "role": "assistant",
+                "content": response.content
+            })
+
+        for tool_call in response.tool_calls:
+            tool_name = tool_call["name"]
+            tool_args = tool_call["args"]
+
+            result_messages.append({
+                "role": "assistant",
+                "content": f"I'll use the {tool_name} tool to help answer your question.",
+                "metadata": {
+                    "title": f"Using tool: {tool_name}",
+                    "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
+                    "status": "pending",
+                    "id": f"tool_call_{tool_name}"
+                }
+            })
+
+            result_messages.append({
+                "role": "assistant",
+                "content": "```json\n" + json.dumps(tool_args, indent=2, ensure_ascii=True) + "\n```",
+                "metadata": {
+                    "parent_id": f"tool_call_{tool_name}",
+                    "id": f"params_{tool_name}",
+                    "title": "Tool Parameters"
+                }
+            })
+
+            # Run the tool on the connected MCP server.
+            result = await self.session.call_tool(tool_name, tool_args)
+
+            if result_messages and "metadata" in result_messages[-2]:
+                result_messages[-2]["metadata"]["status"] = "done"
+
+            result_messages.append({
+                "role": "assistant",
+                "content": "Here are the results from the tool:",
+                "metadata": {
+                    "title": f"Tool Result for {tool_name}",
+                    "status": "done",
+                    "id": f"result_{tool_name}"
+                }
+            })
+
+            result_content = result.content
+            if isinstance(result_content, list):
+                result_content = "\n".join(str(item) for item in result_content)
+
+            # Render image payloads inline; everything else is shown raw.
+            image_payload = None
+            try:
+                result_json = json.loads(result_content)
+                if isinstance(result_json, dict) and result_json.get("type") == "image" and "url" in result_json:
+                    image_payload = result_json
+            except (json.JSONDecodeError, TypeError):
+                pass
+
+            if image_payload:
+                result_messages.append({
+                    "role": "assistant",
+                    "content": {"path": image_payload["url"], "alt_text": image_payload.get("message", "Generated image")},
+                    "metadata": {
+                        "parent_id": f"result_{tool_name}",
+                        "id": f"image_{tool_name}",
+                        "title": "Generated Image"
+                    }
+                })
+            else:
+                result_messages.append({
+                    "role": "assistant",
+                    "content": "```\n" + result_content + "\n```",
+                    "metadata": {
+                        "parent_id": f"result_{tool_name}",
+                        "id": f"raw_result_{tool_name}",
+                        "title": "Raw Output"
+                    }
+                })
+
+            # Feed the tool result back to the model for a final answer.
+            chat_messages.append({"role": "user", "content": f"Tool result for {tool_name}: {result_content}"})
+            next_response = await self.mistral.ainvoke(chat_messages)
+
+            if isinstance(next_response.content, str) and next_response.content:
+                result_messages.append({
+                    "role": "assistant",
+                    "content": next_response.content
+                })
+
+        return result_messages
+
+client = MCPClientWrapper()
+
+def gradio_interface():
+    with gr.Blocks(title="MCP Weather Client") as demo:
+        gr.Markdown("# MCP Weather Assistant")
+        gr.Markdown("Connect to your MCP weather server and chat with the assistant")
+
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=4):
+                server_path = gr.Textbox(
+                    label="Server Script Path",
+                    placeholder="Enter path to server script (e.g., weather.py)",
+                    value="gradio_mcp_server.py"
+                )
+            with gr.Column(scale=1):
+                connect_btn = gr.Button("Connect")
+
+        status = gr.Textbox(label="Connection Status", interactive=False)
+
+        chatbot = gr.Chatbot(
+            value=[],
+            height=500,
+            type="messages",
+            show_copy_button=True,
+            avatar_images=("👤", "🤖")
+        )
+
+        with gr.Row(equal_height=True):
+            msg = gr.Textbox(
+                label="Your Question",
+                placeholder="Ask about weather or alerts (e.g., What's the weather in New York?)",
+                scale=4
+            )
+            clear_btn = gr.Button("Clear Chat", scale=1)
+
+        connect_btn.click(client.connect, inputs=server_path, outputs=status)
+        msg.submit(client.process_message, [msg, chatbot], [chatbot, msg])
+        clear_btn.click(lambda: [], None, chatbot)
+
+    return demo
+
+if __name__ == "__main__":
+    if not os.getenv("OPENROUTER_API_KEY"):
+        print("Warning: OPENROUTER_API_KEY not found in environment. Please set it in your .env file.")
+
+    interface = gradio_interface()
+    interface.launch(debug=True)
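
For reference, a minimal sketch of the LangChain tool-calling round trip that the rewritten `_process_query` relies on, runnable on its own. The model name, key lookup, and `get_weather` tool are illustrative assumptions, not part of the commit:

import os
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage

# OpenRouter exposes an OpenAI-compatible endpoint, so ChatOpenAI works
# against it once the base URL is overridden.
llm = ChatOpenAI(
    model_name="mistralai/mistral-small",
    openai_api_key=os.getenv("OPENROUTER_API_KEY"),
    openai_api_base="https://openrouter.ai/api/v1",
)

# One tool in the OpenAI function-calling format that bind_tools accepts;
# get_weather is a hypothetical stand-in for a real MCP tool schema.
weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

response = llm.bind_tools([weather_tool]).invoke(
    [HumanMessage(content="What's the weather in New York?")]
)
# response.tool_calls holds parsed arguments, e.g.
# [{"name": "get_weather", "args": {"city": "New York"}, "id": "..."}].
for call in response.tool_calls:
    print(call["name"], call["args"])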
gradio_interface/app.py CHANGED
@@ -1,7 +1,204 @@
+import os
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+from os import getenv
+import base64
+from io import BytesIO
+from dotenv import load_dotenv
+import requests
+import socket
+import logging
+
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
+from langchain_core.callbacks import StreamingStdOutCallbackHandler
+
+# Load environment
+dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
+load_dotenv(dotenv_path=dotenv_path)
+
+# Connectivity test
+def test_connectivity(url="https://openrouter.helicone.ai/api/v1"):
+    try:
+        return requests.get(url, timeout=5).status_code == 200
+    except (requests.RequestException, socket.error):
+        return False
+
+# Helper to make direct API calls to OpenRouter when LangChain fails
+def direct_api_call(messages, api_key, base_url):
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}",
+        "HTTP-Referer": "https://your-app-domain.com",  # Add your domain
+        "X-Title": "Image Analysis App"
+    }
+
+    if getenv("HELICONE_API_KEY"):
+        headers["Helicone-Auth"] = f"Bearer {getenv('HELICONE_API_KEY')}"
+
+    payload = {
+        "model": "google/gemini-flash-1.5",
+        "messages": messages,
+        "stream": False,
+    }
+
+    try:
+        response = requests.post(
+            f"{base_url}/chat/completions",
+            headers=headers,
+            json=payload,
+            timeout=30
+        )
+        response.raise_for_status()
+        return response.json()["choices"][0]["message"]["content"]
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+# Initialize LLM with streaming enabled
+def init_llm():
+    if not test_connectivity():
+        raise RuntimeError("No connection to OpenRouter. Check your network and API keys.")
+    return ChatOpenAI(
+        openai_api_key=getenv("OPENROUTER_API_KEY"),
+        openai_api_base=getenv("OPENROUTER_BASE_URL"),
+        model_name="google/gemini-flash-1.5",
+        streaming=True,
+        callbacks=[StreamingStdOutCallbackHandler()],
+        model_kwargs={
+            "extra_headers": {"Helicone-Auth": f"Bearer {getenv('HELICONE_API_KEY')}"}
+        },
+    )
+
+# Try to initialize the LLM but handle failures gracefully
+try:
+    llm = init_llm()
+except Exception as e:
+    logging.warning(f"LLM initialization failed: {e}")
+    llm = None
+
+# Helpers
+def encode_image_to_base64(pil_image):
+    buffer = BytesIO()
+    pil_image.save(buffer, format="PNG")
+    return base64.b64encode(buffer.getvalue()).decode()
+
+# Core logic
+def generate_response(message, chat_history, image):
+    # Convert chat history to standard role/content dicts
+    formatted_history = []
+    for msg in chat_history:
+        role = msg.get('role')
+        content = msg.get('content')
+        if role == 'user':
+            formatted_history.append({"role": "user", "content": content})
+        else:
+            formatted_history.append({"role": "assistant", "content": content})
+
+    # Prepare system message
+    system_msg = {"role": "system", "content": "You are an expert image analysis assistant. Answer succinctly."}
+
+    # Prepare the latest message, attaching the image if provided
+    if image:
+        base64_image = encode_image_to_base64(image)
+
+        # Format for direct API call (OpenRouter/OpenAI format)
+        api_messages = [system_msg] + formatted_history + [{
+            "role": "user",
+            "content": [
+                {"type": "text", "text": message},
+                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
+            ]
+        }]
+
+        # For LangChain format
+        content_for_langchain = [
+            {"type": "text", "text": message},
+            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
+        ]
+    else:
+        api_messages = [system_msg] + formatted_history + [{"role": "user", "content": message}]
+        content_for_langchain = message
+
+    # Build LangChain messages (system prompt as a proper SystemMessage)
+    lc_messages = [SystemMessage(content="You are an expert image analysis assistant. Answer succinctly.")]
+    for msg in chat_history:
+        role = msg.get('role')
+        content = msg.get('content')
+        if role == 'user':
+            lc_messages.append(HumanMessage(content=content))
+        else:
+            lc_messages.append(AIMessage(content=content))
+
+    lc_messages.append(HumanMessage(content=content_for_langchain))
+
+    try:
+        if llm:
+            # First try streaming through LangChain
+            try:
+                partial = ""
+                for chunk in llm.stream(lc_messages):
+                    content = getattr(chunk, 'content', None)
+                    if not content:
+                        continue
+                    partial += content
+                    yield partial
+                # If we got this far, streaming worked
+                return
+            except Exception as e:
+                logging.warning(f"Streaming failed: {e}. Falling back to non-streaming mode")
+
+            # Then try a non-streaming LangChain call; on failure, fall
+            # through to the direct API call instead of re-raising.
+            try:
+                response = llm.invoke(lc_messages)
+                yield response.content
+                return
+            except Exception as e:
+                logging.warning(f"LangChain call failed: {e}. Falling back to direct API call")
+
+        # Last resort: call OpenRouter directly
+        response_text = direct_api_call(
+            api_messages,
+            getenv("OPENROUTER_API_KEY"),
+            getenv("OPENROUTER_BASE_URL")
+        )
+        yield response_text
+
+    except Exception as e:
+        import traceback
+        logging.error(traceback.format_exc())
+        yield f"⚠️ Error generating response: {str(e)}. Please try again later."
+
+# Gradio interface
+def process_message(message, chat_history, image):
+    if chat_history is None:
+        chat_history = []
+    if image is None:
+        # This is a generator, so the update must be yielded, not returned.
+        chat_history.append({'role': 'assistant', 'content': 'Please upload an image.'})
+        yield "", chat_history
+        return
+    chat_history.append({'role': 'user', 'content': message})
+    chat_history.append({'role': 'assistant', 'content': '⏳ Processing...'})
+    yield "", chat_history
+    # Exclude the just-appended user message and placeholder from the history
+    # passed to the model, so the current turn is not sent twice.
+    for chunk in generate_response(message, chat_history[:-2], image):
+        chat_history[-1]['content'] = chunk
+        yield "", chat_history
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column(scale=2):
+            chatbot = gr.Chatbot(type='messages', height=600)
+            msg = gr.Textbox(label="Message", placeholder="Type your question...")
+            clear = gr.ClearButton([msg, chatbot])
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Upload Image")
+            info = gr.Textbox(label="Image Info", interactive=False)
+
+    msg.submit(process_message, [msg, chatbot, image_input], [msg, chatbot])
+    image_input.change(lambda img: f"Size: {img.size}" if img else "No image.", [image_input], [info])
+
+demo.launch()
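
For reference, a minimal sketch of the multimodal payload the app assembles: one text part plus one base64 `data:` URL part, the shape OpenAI-compatible endpoints such as OpenRouter accept. The image filename is a hypothetical placeholder:

import base64
from io import BytesIO
from PIL import Image

# Encode a local image the same way encode_image_to_base64 does.
img = Image.open("example.png")  # hypothetical input file
buffer = BytesIO()
img.save(buffer, format="PNG")
b64 = base64.b64encode(buffer.getvalue()).decode()

# This dict can be passed as-is to direct_api_call, or its "content" list
# wrapped in HumanMessage(content=...) for llm.invoke / llm.stream.
user_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
    ],
}
print(user_message["content"][0]["text"])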