File size: 6,312 Bytes
8227e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import asyncio
import json
import traceback
from contextlib import asynccontextmanager
from typing import List, Optional, Any, Dict

from fastapi import FastAPI, WebSocket
from fastapi import HTTPException
from fastapi import WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from mcp_client import MCPClient

# Shared MCP client: connected at startup by `lifespan`, and (re)connected
# lazily by the tool endpoints whenever no session is open.
mcp = MCPClient()

class ChatMessage(BaseModel):
    """One chat turn: who produced it (``role``, e.g. "user" or "assistant") and its text."""

    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    """Body of an OpenAI-style chat-completion request."""

    # Model requested when the client does not name one.
    model: str = "gemini-2.5-pro-exp-03-25"
    messages: List[ChatMessage]
    # Mutable default: relies on pydantic copying defaults per instance.
    tools: Optional[list] = []
    max_tokens: Optional[int] = None

class ChatCompletionResponseChoice(BaseModel):
    """A single candidate answer inside a chat-completion response."""

    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"

class ChatCompletionResponse(BaseModel):
    """Response envelope modelled on the OpenAI ``chat.completion`` object."""

    id: str
    object: str = "chat.completion"
    # Integer timestamp; the (disabled) completions endpoint fills it with int(time.time()).
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect the MCP client at startup and close it at shutdown.

    A failed startup connection is only reported, so the app still boots;
    the tool endpoints try to connect again when no session is open.
    At shutdown the client's exit stack is closed only if a session exists,
    and a failure while closing is reported rather than raised.
    """
    # --- startup ---
    try:
        await mcp.connect()
    except Exception as err:
        print("Warning ! : Connexion au MCP impossible\n", str(err))
    else:
        print("Connexion au MCP réussi !")

    yield

    # --- shutdown ---
    if not mcp.session:
        return
    try:
        await mcp.exit_stack.aclose()
    except Exception as err:
        print("Erreur à la fermeture du MCP\n", str(err))
    else:
        print("MCP déconnecté !")

app = FastAPI(lifespan=lifespan)

# Files under ./static are served at /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# NOTE(review): wildcard origins together with allow_credentials=True likely
# lets any site make credentialed requests; restrict allow_origins if cookies
# or auth headers are ever relied on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

class ConnectionManager:
    """Track open WebSocket connections and pending request/response waits.

    Attributes:
        active_connections: maps each accepted socket to the ``source`` name it
            announced, or ``None`` until :meth:`set_source` is called.
        response_queues: maps a request id to the single-slot queue that
            :meth:`wait_for_response` is blocked on.
    """

    def __init__(self):
        self.active_connections = {}
        self.response_queues = {}

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register *websocket* with no source yet."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: WebSocket, source: str):
        """Record the name *websocket* announced; ignored for unregistered sockets."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_dest(self, destination: str, message: str):
        """Send *message* to every connection registered under *destination*.

        A recipient whose send fails is unregistered and skipped, so one dead
        peer neither blocks delivery to the others nor raises into the caller
        (the sender's own handler).
        """
        # Snapshot first: other handlers may connect/remove sockets while this
        # coroutine is suspended in send_text, and mutating a dict while
        # iterating it raises RuntimeError.
        targets = [ws for ws, src in self.active_connections.items() if src == destination]
        for ws in targets:
            if ws not in self.active_connections:
                continue  # removed while an earlier send was in flight
            try:
                await ws.send_text(message)
            except Exception as e:
                print(f"Envoi impossible vers '{destination}', connexion retirée\n", str(e))
                self.remove(ws)

    def remove(self, websocket: WebSocket):
        """Forget *websocket*; a no-op if it is not registered."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: float = 30):
        """Wait up to *timeout* seconds for the value queued under *request_id*.

        The queue is unregistered afterwards in every case, including timeout.

        Raises:
            asyncio.TimeoutError: if nothing is queued in time.
        """
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            self.response_queues.pop(request_id, None)


# Process-wide connection registry used by the /ws relay.
manager = ConnectionManager()

@app.get("/")
async def index_page():
    """Serve the front-end page (``index.html``, resolved against the working directory)."""
    page_path = "index.html"
    return FileResponse(page_path)

# @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
# async def chat_completions(request: ChatCompletionRequest):
#     request_id = str(uuid.uuid4())
#     proxy_ws = next((ws for ws, src in manager.active_connections.items() if src == "proxy"), None)
#     if not proxy_ws:
#         raise HTTPException(503, "Proxy client not connected !")
#     user_msg = next((m for m in request.messages if m.role == "user"), None)
#     if not user_msg:
#         raise HTTPException(400, "No user message found !")

#     proxy_msg = {
#         "request_id": request_id,
#         "content": user_msg.content,
#         "source": "api",
#         "destination": "proxy",
#         "model": request.model,
#         "tools": request.tools,
#         "max_tokens": request.max_tokens
#     }

#     await proxy_ws.send_text(json.dumps(proxy_msg))

#     try:
#         response_content = await manager.wait_for_response(request_id)
#     except asyncio.TimeoutError:
#         raise HTTPException(504, "Proxy response timeout")
#     return ChatCompletionResponse(
#         id=request_id,
#         created=int(time.time()),
#         model=request.model,
#         choices=[ChatCompletionResponseChoice(
#             message=ChatMessage(role="assistant", content=response_content)
#         )]
#     )

class ToolCallRequest(BaseModel):
    """Body of ``POST /call-tools``: a batch of tool calls to execute.

    Each entry is read as ``{"function": {"name": ..., "arguments": ...}}``
    by the endpoint, which appears to follow the OpenAI tool-call shape.
    """

    tool_calls: List[Dict[str, Any]]

@app.get("/list-tools", response_model=List[Dict[str, Any]])
async def list_tools():
    """List the tools published by the MCP server.

    Opens the MCP connection first when no session is open yet.

    Raises:
        HTTPException: 503 when the connection attempt fails, 500 when the
            server cannot list its tools.
    """
    needs_connection = not mcp.session
    if needs_connection:
        try:
            await mcp.connect()
        except Exception as exc:
            raise HTTPException(status_code=503, detail=f"Connexion au MCP impossible !\n{str(exc)}")
    try:
        available_tools = await mcp.list_tools()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Erreur lors de la récupération des outils: {str(exc)}")
    return available_tools

def _parse_tool_call(tool_call: Dict[str, Any]) -> tuple:
    """Return ``(name, arguments)`` for one tool call from the request body.

    ``arguments`` may be a JSON-encoded object (the usual wire format), an
    already-decoded dict, or empty/missing (treated as no arguments).

    Raises:
        HTTPException: 400 when the call has no ``function.name`` or its
            arguments are not a JSON object.
    """
    function = tool_call.get("function")
    if not isinstance(function, dict) or not function.get("name"):
        raise HTTPException(status_code=400, detail=f"Appel d'outil invalide : {tool_call}")
    name = function["name"]
    raw_args = function.get("arguments")
    if raw_args is None or raw_args == "":
        args = {}
    elif isinstance(raw_args, (str, bytes, bytearray)):
        try:
            args = json.loads(raw_args)
        except ValueError as e:  # JSONDecodeError (and bad bytes encoding) are ValueErrors
            raise HTTPException(status_code=400, detail=f"Arguments JSON invalides pour l'outil {name}: {e}") from e
    else:
        args = raw_args
    if not isinstance(args, dict):
        raise HTTPException(status_code=400, detail=f"Les arguments de l'outil {name} doivent être un objet JSON")
    return name, args


def _result_text(result: Any) -> str:
    """Join the text parts of a tool result with newlines ('' when there are none).

    Parts without a string ``text`` attribute (non-text content) are skipped.
    """
    return "\n".join(
        part.text
        for part in (getattr(result, "content", None) or [])
        if isinstance(getattr(part, "text", None), str)
    )


@app.post("/call-tools")
async def call_tools(request: ToolCallRequest):
    """Execute each requested tool on the MCP server and return the results.

    Connects lazily if no session is open. Results are returned in request
    order as ``{"role": "user", "content": <text>}`` messages.

    Raises:
        HTTPException: 503 if the MCP connection cannot be established,
            400 for a malformed tool call, 500 if a tool invocation fails.
    """
    if not mcp.session:
        try:
            await mcp.connect()
        except Exception as e:
            # A connection failure, not a tool-listing error: same message as /list-tools.
            raise HTTPException(status_code=503, detail=f"Connexion au MCP impossible !\n{str(e)}") from e

    result_tools = []
    for tool_call in request.tool_calls:
        print(tool_call)
        # Parsed outside the try below so malformed input stays a 400, not a 500.
        tool_name, tool_args = _parse_tool_call(tool_call)
        try:
            result = await mcp.session.call_tool(tool_name, tool_args)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors de l'appel des outils: {str(e)}") from e
        # NOTE(review): results go back with role "user" rather than "tool";
        # presumably what the calling client expects — confirm.
        result_tools.append({"role": "user", "content": _result_text(result)})
    print("Finished !")
    return result_tools




@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Relay text frames between named WebSocket clients.

    Protocol:
      1. The first frame is parsed as JSON; if it contains ``"source"``, the
         connection is registered under that name.
      2. Every later frame must be a JSON object with a ``"destination"`` key;
         the raw frame text is forwarded unchanged to all connections
         registered under that name. Malformed frames are reported and
         skipped instead of dropping the sender.

    The connection is always unregistered on exit. A client disconnect ends
    the loop quietly. Any other error is printed with its traceback, and the
    socket is closed on a best-effort basis.
    """
    await manager.connect(websocket)
    try:
        init_msg = json.loads(await websocket.receive_text())
        if 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])
            print(init_msg['source'])

        while True:
            message = await websocket.receive_text()
            try:
                destination = json.loads(message)["destination"]
            except (ValueError, KeyError, TypeError) as e:
                # Bad JSON, a non-object payload, or no "destination": skip this frame only.
                print("Message ignoré (JSON invalide ou 'destination' manquante)\n", str(e))
                continue
            await manager.send_to_dest(destination, message)
    except WebSocketDisconnect:
        pass  # the peer closed the socket itself; nothing left to send back
    except Exception:
        traceback.print_exc()
        # Unregister before closing so no relay targets a socket mid-close.
        manager.remove(websocket)
        try:
            await websocket.close()
        except Exception:
            pass  # best effort: the transport may already be gone
    finally:
        # Idempotent; also covers disconnects and task cancellation.
        manager.remove(websocket)