Spaces:
Sleeping
Sleeping
File size: 6,312 Bytes
8227e25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import asyncio
import json
import traceback
from contextlib import asynccontextmanager
from typing import List, Optional, Any, Dict

from fastapi import FastAPI, WebSocket
from fastapi import HTTPException
from fastapi import WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from mcp_client import MCPClient
# Single MCP client instance shared by the whole module (created at import time).
mcp = MCPClient()
# One chat turn: who spoke and what was said.
class ChatMessage(BaseModel):
    # Speaker label for the turn (presumably "user" / "assistant" -- confirm with callers).
    role: str
    # Text of the turn.
    content: str
# Request payload for a chat completion.
class ChatCompletionRequest(BaseModel):
    # Model identifier; falls back to the default below when omitted.
    model: str = "gemini-2.5-pro-exp-03-25"
    # Conversation history, one ChatMessage per turn.
    messages: List[ChatMessage]
    # Optional tool descriptions; an empty list when the client sends none.
    tools: Optional[list] = []
    # Optional cap on generated tokens; None means "not specified".
    max_tokens: Optional[int] = None
# One candidate answer inside a chat completion response.
class ChatCompletionResponseChoice(BaseModel):
    # Position of this choice in the response's list of choices.
    index: int = 0
    # The message carried by this choice.
    message: ChatMessage
    # Why generation ended; defaults to "stop".
    finish_reason: str = "stop"
# Response envelope for a chat completion.
class ChatCompletionResponse(BaseModel):
    # Identifier of this completion.
    id: str
    # Object type tag; defaults to "chat.completion".
    object: str = "chat.completion"
    # Creation time as an integer (presumably a Unix timestamp -- confirm with producers).
    created: int
    # Model identifier reported back to the client.
    model: str
    # Candidate answers.
    choices: List[ChatCompletionResponseChoice]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect the shared MCP client on startup and release it on shutdown.

    A failed connection at startup is reported but does not stop the
    application from starting; the tool endpoints retry the connection on
    demand.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    try:
        await mcp.connect()
        print("Connexion au MCP réussi !")
    except Exception as e:
        # Best effort: keep serving even when the MCP backend is unavailable.
        print("Warning ! : Connexion au MCP impossible\n", str(e))
    try:
        yield
    finally:
        # Run the cleanup even when an exception or cancellation is thrown
        # into the generator at the yield point, so the MCP resources are
        # not leaked on an abnormal shutdown.
        if mcp.session:
            try:
                await mcp.exit_stack.aclose()
                print("MCP déconnecté !")
            except Exception as e:
                print("Erreur à la fermeture du MCP\n", str(e))
# Application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# Expose the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# NOTE(review): every origin, method and header is allowed while credentials
# are enabled. That combination lets any site make credentialed requests to
# this API -- confirm it is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ConnectionManager:
    """Registry of open WebSocket connections and of pending response queues.

    Attributes are plain dicts:
      * ``active_connections`` maps each accepted socket to the source name
        it announced (``None`` until ``set_source`` is called).
      * ``response_queues`` maps a request id to the single-slot queue that
        ``wait_for_response`` is currently awaiting.
    """

    def __init__(self):
        self.active_connections = {}
        self.response_queues = {}

    async def connect(self, websocket: "WebSocket"):
        """Accept ``websocket`` and register it with no source name yet."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: "WebSocket", source: str):
        """Record ``source`` as the name of an already-registered socket.

        Sockets that are not registered are ignored.
        """
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_dest(self, destination: str, message: str):
        """Send ``message`` to every socket whose source name is ``destination``.

        Exceptions raised by a socket's ``send_text`` propagate to the caller.
        """
        # Snapshot the recipients first: each send awaits, and another task
        # may register or remove connections in the meantime. Iterating the
        # live dict across those awaits raises
        # "RuntimeError: dictionary changed size during iteration".
        recipients = [
            ws for ws, src in self.active_connections.items() if src == destination
        ]
        for ws in recipients:
            # Skip sockets that were removed while earlier sends were pending.
            if ws in self.active_connections:
                await ws.send_text(message)

    def remove(self, websocket: "WebSocket"):
        """Forget ``websocket``; a no-op when it is not registered."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: int = 30):
        """Wait for one item to be put on the queue registered for ``request_id``.

        Args:
            request_id: Key under which the queue is published in
                ``response_queues`` while waiting.
            timeout: Maximum number of seconds to wait.

        Returns:
            The first item placed on the queue.

        Raises:
            asyncio.TimeoutError: If nothing arrives within ``timeout``.
        """
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            # Always unregister so abandoned requests do not accumulate.
            self.response_queues.pop(request_id, None)
# Module-wide registry of WebSocket connections.
manager = ConnectionManager()
@app.get("/")
async def index_page():
    # Serve the landing page; "index.html" is a relative path, so it is
    # resolved against the process working directory.
    page = FileResponse("index.html")
    return page
# @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
# async def chat_completions(request: ChatCompletionRequest):
# request_id = str(uuid.uuid4())
# proxy_ws = next((ws for ws, src in manager.active_connections.items() if src == "proxy"), None)
# if not proxy_ws:
# raise HTTPException(503, "Proxy client not connected !")
# user_msg = next((m for m in request.messages if m.role == "user"), None)
# if not user_msg:
# raise HTTPException(400, "No user message found !")
# proxy_msg = {
# "request_id": request_id,
# "content": user_msg.content,
# "source": "api",
# "destination": "proxy",
# "model": request.model,
# "tools": request.tools,
# "max_tokens": request.max_tokens
# }
# await proxy_ws.send_text(json.dumps(proxy_msg))
# try:
# response_content = await manager.wait_for_response(request_id)
# except asyncio.TimeoutError:
# raise HTTPException(504, "Proxy response timeout")
# return ChatCompletionResponse(
# id=request_id,
# created=int(time.time()),
# model=request.model,
# choices=[ChatCompletionResponseChoice(
# message=ChatMessage(role="assistant", content=response_content)
# )]
# )
# Request payload for POST /call-tools.
class ToolCallRequest(BaseModel):
    # Tool invocations to run, one free-form mapping per call.
    tool_calls: List[Dict[str, Any]]
@app.get("/list-tools", response_model=List[Dict[str, Any]])
async def list_tools():
    # Returns whatever the MCP client reports as available tools.
    # 503 -> the MCP connection could not be established on demand.
    # 500 -> the client failed while listing the tools.
    session_missing = not mcp.session
    if session_missing:
        # Lazy (re)connection: startup may have failed to reach the MCP.
        try:
            await mcp.connect()
        except Exception as e:
            raise HTTPException(
                status_code=503,
                detail=f"Connexion au MCP impossible !\n{str(e)}",
            )

    try:
        return await mcp.list_tools()
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors de la récupération des outils: {str(e)}",
        )
@app.post("/call-tools")
async def call_tools(request: ToolCallRequest):
    """Run each requested tool through the MCP session and return the outputs.

    Every entry of ``request.tool_calls`` must carry a ``function`` mapping
    with a ``name`` and, optionally, ``arguments`` given either as a JSON
    string or as an already-decoded object.

    Returns:
        One ``{"role": "user", "content": <text>}`` dict per tool call, in
        request order. ``content`` is the text of the first text part of the
        tool result, or an empty string when the result has none.

    Raises:
        HTTPException: 503 when the MCP connection cannot be established,
            500 when a tool call is malformed or fails.
    """
    if not mcp.session:
        try:
            await mcp.connect()
        except Exception as e:
            # Report the connection failure itself, with the same wording
            # as the /list-tools endpoint.
            raise HTTPException(
                status_code=503,
                detail=f"Connexion au MCP impossible !\n{str(e)}",
            ) from e
    try:
        result_tools = []
        for tool_call in request.tool_calls:
            print(tool_call)
            tool = tool_call["function"]
            tool_name = tool["name"]
            tool_args = tool.get("arguments")
            if isinstance(tool_args, (str, bytes, bytearray)):
                # Arguments usually arrive JSON-encoded; a blank string
                # means "no arguments".
                tool_args = json.loads(tool_args) if tool_args.strip() else {}
            elif tool_args is None:
                tool_args = {}
            # Anything else is assumed to be already decoded and is passed
            # through unchanged.
            result = await mcp.session.call_tool(tool_name, tool_args)
            # A tool may return no content, or non-text parts: pick the
            # first part exposing text instead of indexing blindly.
            text = next(
                (
                    part.text
                    for part in (result.content or [])
                    if getattr(part, "text", None) is not None
                ),
                "",
            )
            result_tools.append({
                "role": "user",
                "content": text,
            })
        print("Finished !")
        return result_tools
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors de l'appel des outils: {str(e)}",
        ) from e
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Relay JSON text frames between WebSocket peers.

    The first frame is parsed as JSON; when it contains a ``source`` key the
    peer is registered under that name. Every following frame is parsed as
    JSON and forwarded unchanged to all peers registered under the name given
    by its ``destination`` key.

    The connection is unregistered when the handler ends, whatever the
    reason. An unexpected error (malformed JSON, missing ``destination``,
    relay failure) is logged and closes the socket.
    """
    await manager.connect(websocket)
    try:
        init_msg = json.loads(await websocket.receive_text())
        if 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])
            print(init_msg['source'])
        while True:
            message = await websocket.receive_text()
            msg_data = json.loads(message)
            await manager.send_to_dest(msg_data["destination"], message)
    except WebSocketDisconnect:
        # The peer is already gone; closing the socket again would raise.
        pass
    except Exception:
        # Do not lose the reason the connection is being dropped.
        traceback.print_exc()
        try:
            await websocket.close()
        except RuntimeError:
            # The socket was already closed underneath us; nothing to do.
            pass
    finally:
        manager.remove(websocket)