import asyncio
import json
import traceback
from contextlib import asynccontextmanager, suppress
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from mcp_client import MCPClient

# Single MCP client shared by the whole application. It is connected in
# `lifespan`, and the tool endpoints connect it on demand when no session
# is open (e.g. because the startup attempt failed).
mcp = MCPClient()


class ChatMessage(BaseModel):
    """One chat message: a role and its text content."""

    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    """Request body for the (currently disabled) chat-completion endpoint."""

    model: str = "gemini-2.5-pro-exp-03-25"
    messages: List[ChatMessage]
    # Pydantic copies mutable field defaults per instance, so this [] is not shared.
    tools: Optional[list] = []
    max_tokens: Optional[int] = None


class ChatCompletionResponseChoice(BaseModel):
    """A single completion choice wrapping the assistant's message."""

    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"


class ChatCompletionResponse(BaseModel):
    """Chat-completion response envelope."""

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect to the MCP server at startup and close it at shutdown.

    A failed startup connection is only reported, not fatal: the tool
    endpoints try to connect again when they are called.
    """
    try:
        await mcp.connect()
        print("Connexion au MCP réussi !")
    except Exception as e:
        print("Warning ! : Connexion au MCP impossible\n", str(e))
    yield
    # NOTE(review): only closes when a session exists; if `connect()` can fail
    # after entering some contexts on `exit_stack`, those are not closed — confirm.
    if mcp.session:
        try:
            await mcp.exit_stack.aclose()
            print("MCP déconnecté !")
        except Exception as e:
            print("Erreur à la fermeture du MCP\n", str(e))


app = FastAPI(lifespan=lifespan)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


class ConnectionManager:
    """Track open WebSockets and the "source" name each one announced.

    Messages are routed by matching a message's destination against the
    source names registered with `set_source`.
    """

    def __init__(self):
        # websocket -> announced source name (None until `set_source` is called)
        self.active_connections: Dict[WebSocket, Optional[str]] = {}
        # request_id -> single-slot queue awaited by `wait_for_response`
        self.response_queues: Dict[str, asyncio.Queue] = {}

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register the socket with no source yet."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: WebSocket, source: str):
        """Record the source name of an already-registered socket."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_dest(self, destination: str, message: str):
        """Send `message` to every socket registered under `destination`.

        The recipients are snapshotted first: each `await send_text` yields to
        other handlers, which may add or remove connections, and mutating the
        dict during iteration would raise RuntimeError. A failed send drops only
        that recipient instead of propagating the error to the sender.
        """
        targets = [ws for ws, src in self.active_connections.items() if src == destination]
        for ws in targets:
            try:
                await ws.send_text(message)
            except Exception as e:
                print(f"Échec d'envoi vers '{destination}', connexion retirée :", str(e))
                self.remove(ws)

    def remove(self, websocket: WebSocket):
        """Forget a socket; unknown sockets are ignored."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: int = 30):
        """Wait up to `timeout` seconds for the response to `request_id`.

        Raises asyncio.TimeoutError if nothing arrives in time. The queue is
        unregistered afterwards in every case.
        """
        # NOTE(review): nothing in this module puts into these queues; a resolver
        # must `put` the response for `request_id` for this to ever succeed.
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            self.response_queues.pop(request_id, None)


manager = ConnectionManager()


@app.get("/")
async def index_page():
    """Serve the front-end page."""
    return FileResponse("index.html")


# NOTE(review): disabled OpenAI-compatible endpoint kept for reference. If it is
# re-enabled it needs `import time` and `import uuid`, something must resolve the
# `wait_for_response` queue, and it currently forwards the *first* user message
# (the latest one is probably intended — confirm).
# @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
# async def chat_completions(request: ChatCompletionRequest):
#     request_id = str(uuid.uuid4())
#     proxy_ws = next((ws for ws, src in manager.active_connections.items() if src == "proxy"), None)
#     if not proxy_ws:
#         raise HTTPException(503, "Proxy client not connected !")
#     user_msg = next((m for m in request.messages if m.role == "user"), None)
#     if not user_msg:
#         raise HTTPException(400, "No user message found !")
#     proxy_msg = {
#         "request_id": request_id,
#         "content": user_msg.content,
#         "source": "api",
#         "destination": "proxy",
#         "model": request.model,
#         "tools": request.tools,
#         "max_tokens": request.max_tokens
#     }
#     await proxy_ws.send_text(json.dumps(proxy_msg))
#     try:
#         response_content = await manager.wait_for_response(request_id)
#     except asyncio.TimeoutError:
#         raise HTTPException(504, "Proxy response timeout")
#     return ChatCompletionResponse(
#         id=request_id,
#         created=int(time.time()),
#         model=request.model,
#         choices=[ChatCompletionResponseChoice(
#             message=ChatMessage(role="assistant", content=response_content)
#         )]
#     )


class ToolCallRequest(BaseModel):
    """Body of /call-tools: a list of tool calls to execute in order."""

    tool_calls: List[Dict[str, Any]]


async def _ensure_mcp_connected():
    """Open an MCP session if none is open; raise HTTP 503 if that fails."""
    if mcp.session:
        return
    try:
        await mcp.connect()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Connexion au MCP impossible !\n{str(e)}") from e


def _parse_tool_arguments(raw: Any) -> Any:
    """Return a tool call's arguments in decoded form.

    A JSON string is decoded (the format the original code expected); an
    already-decoded value is passed through; None or "" means no arguments.
    """
    if raw is None or raw == "":
        return {}
    if isinstance(raw, str):
        return json.loads(raw)
    return raw


def _result_text(result: Any) -> str:
    """Join the `text` of every content part of a tool result with newlines.

    Parts without a `text` attribute are skipped; empty content yields "".
    """
    texts = [part.text for part in result.content if getattr(part, "text", None) is not None]
    return "\n".join(texts)


@app.get("/list-tools", response_model=List[Dict[str, Any]])
async def list_tools():
    """Return the tools exposed by the MCP server, connecting first if needed.

    Raises:
        HTTPException(503): the MCP server cannot be reached.
        HTTPException(500): listing the tools failed.
    """
    await _ensure_mcp_connected()
    try:
        return await mcp.list_tools()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur lors de la récupération des outils: {str(e)}") from e


@app.post("/call-tools")
async def call_tools(request: ToolCallRequest):
    """Execute each requested tool call on the MCP server, in order.

    Each entry of `request.tool_calls` must contain
    {"function": {"name": ..., "arguments": ...}}. Returns one
    {"role": "user", "content": <tool text output>} dict per call.

    Raises:
        HTTPException(503): the MCP server cannot be reached.
        HTTPException(500): a tool call was malformed or failed.
    """
    await _ensure_mcp_connected()
    try:
        result_tools = []
        for tool_call in request.tool_calls:
            print(tool_call)
            function = tool_call["function"]
            arguments = _parse_tool_arguments(function.get("arguments"))
            result = await mcp.session.call_tool(function["name"], arguments)
            result_tools.append({
                "role": "user",
                "content": _result_text(result),
            })
        print("Finished !")
        return result_tools
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur lors de l'appel des outils: {str(e)}") from e


def _message_destination(message: str) -> Any:
    """Return the "destination" of a JSON-object frame, or None if invalid/absent."""
    try:
        data = json.loads(message)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    return data.get("destination")


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Relay JSON text frames between WebSocket clients.

    The first frame must be JSON; if it is an object with a "source" key the
    socket is registered under that name. Each later frame is forwarded
    verbatim to every socket whose source equals the frame's "destination".
    Frames that are not JSON objects with a "destination" are skipped.
    """
    await manager.connect(websocket)
    try:
        init_msg = json.loads(await websocket.receive_text())
        if isinstance(init_msg, dict) and "source" in init_msg:
            manager.set_source(websocket, init_msg["source"])
            print(init_msg["source"])
        while True:
            message = await websocket.receive_text()
            destination = _message_destination(message)
            if destination is None:
                print("Message WebSocket ignoré : JSON invalide ou champ 'destination' manquant")
                continue
            await manager.send_to_dest(destination, message)
    except WebSocketDisconnect:
        # The peer already closed the socket; calling close() again would raise.
        pass
    except Exception:
        traceback.print_exc()
        # Best-effort close: the socket may already be unusable.
        with suppress(Exception):
            await websocket.close()
    finally:
        # Also runs on task cancellation, so no stale entry is left behind.
        manager.remove(websocket)