Spaces:
Running
Running
File size: 6,450 Bytes
3aea7c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
from fastapi import WebSocket
from typing import Dict, List, Optional
import logging
# Module-level logger, named after this module; ConnectionManager uses it to
# report connection lifecycle events and socket-close failures.
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Manages WebSocket connections for masters and slaves.

    Each robot has at most one master connection and any number of slave
    connections. Every connection is additionally recorded in a registry
    keyed by connection ID, so it can be removed knowing only that ID.

    Methods that close sockets first detach them from all bookkeeping and
    only then await ``close()``. Other coroutines that run during that await
    (for example a handler unregistering the very socket being closed)
    therefore cannot invalidate the cleanup that is in progress.
    """

    def __init__(self):
        # robot_id -> websocket
        self.master_connections: Dict[str, WebSocket] = {}
        # robot_id -> list of websockets
        self.slave_connections: Dict[str, List[WebSocket]] = {}
        # connection_id -> (robot_id, websocket)
        self.connection_registry: Dict[str, tuple] = {}

    async def connect_master(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Register ``websocket`` as the master of ``robot_id``.

        Only one master per robot is allowed: an existing master is
        unregistered and closed before the new one is recorded.

        Args:
            connection_id: Registry key identifying this connection.
            robot_id: Robot the master attaches to.
            websocket: The master's socket.
        """
        if robot_id in self.master_connections:
            logger.warning(f"Disconnecting existing master for robot {robot_id}")
            await self.disconnect_master_by_robot(robot_id)

        self.master_connections[robot_id] = websocket
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(f"Master {connection_id} connected to robot {robot_id}")

    async def connect_slave(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Register ``websocket`` as an additional slave of ``robot_id``.

        Args:
            connection_id: Registry key identifying this connection.
            robot_id: Robot the slave attaches to.
            websocket: The slave's socket.
        """
        slaves = self.slave_connections.setdefault(robot_id, [])
        slaves.append(websocket)
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(f"Slave {connection_id} connected to robot {robot_id} ({len(slaves)} total slaves)")

    async def disconnect_master(self, connection_id: str):
        """Unregister the master connection known as ``connection_id``.

        The socket itself is not closed here. Unknown connection IDs are
        ignored.
        """
        if connection_id not in self.connection_registry:
            return

        robot_id, websocket = self.connection_registry.pop(connection_id)
        # Evict the robot's master only if it really is this connection.
        # Otherwise a stale or non-master connection ID would silently drop
        # a live master that is still in use.
        if (
            robot_id in self.master_connections
            and self.master_connections[robot_id] is websocket
        ):
            del self.master_connections[robot_id]
        logger.info(f"Master {connection_id} disconnected from robot {robot_id}")

    async def disconnect_master_by_robot(self, robot_id: str):
        """Unregister and close the master of ``robot_id``, if there is one.

        Errors raised while closing the socket are logged, not propagated.
        """
        if robot_id not in self.master_connections:
            return

        # Detach from all bookkeeping before awaiting close().
        websocket = self.master_connections.pop(robot_id)
        for conn_id, (r_id, ws) in list(self.connection_registry.items()):
            if r_id == robot_id and ws is websocket:
                del self.connection_registry[conn_id]
                break

        try:
            await websocket.close()
        except Exception as e:
            logger.error(f"Error closing master websocket for robot {robot_id}: {e}")

    async def disconnect_slave(self, connection_id: str):
        """Unregister the slave connection known as ``connection_id``.

        The socket itself is not closed here. Unknown connection IDs are
        ignored. A robot's slave list is dropped once it becomes empty.
        """
        if connection_id not in self.connection_registry:
            return

        robot_id, websocket = self.connection_registry.pop(connection_id)
        slaves = self.slave_connections.get(robot_id)
        if slaves is not None:
            try:
                slaves.remove(websocket)
            except ValueError:
                logger.warning(f"Slave websocket not found in connections for robot {robot_id}")
            else:
                if not slaves:  # Remove empty list
                    del self.slave_connections[robot_id]
        logger.info(f"Slave {connection_id} disconnected from robot {robot_id}")

    def get_master_connection(self, robot_id: str) -> Optional[WebSocket]:
        """Return the master socket for ``robot_id``, or ``None`` if absent."""
        return self.master_connections.get(robot_id)

    def get_slave_connections(self, robot_id: str) -> List[WebSocket]:
        """Return the slave sockets for ``robot_id`` (empty list if none).

        The result is a copy: callers may iterate it across ``await`` points
        or modify it without affecting, or being affected by, later
        connects and disconnects.
        """
        return list(self.slave_connections.get(robot_id, ()))

    def get_connection_count(self) -> int:
        """Return the total number of registered master and slave sockets."""
        master_count = len(self.master_connections)
        slave_count = sum(len(slaves) for slaves in self.slave_connections.values())
        return master_count + slave_count

    def get_robot_connection_info(self, robot_id: str) -> dict:
        """Return a summary dict of the connections attached to ``robot_id``.

        Keys: ``robot_id``, ``has_master``, ``slave_count`` and
        ``total_connections``.
        """
        has_master = robot_id in self.master_connections
        slave_count = len(self.slave_connections.get(robot_id, []))
        return {
            "robot_id": robot_id,
            "has_master": has_master,
            "slave_count": slave_count,
            "total_connections": (1 if has_master else 0) + slave_count,
        }

    async def cleanup_robot_connections(self, robot_id: str):
        """Unregister and close every connection attached to ``robot_id``.

        Errors raised while closing individual sockets are logged and do not
        stop the remaining sockets from being closed.
        """
        # Detach everything up front. Closing a socket yields to the event
        # loop, and a handler reacting to that close may call
        # disconnect_master()/disconnect_slave(); with the state already
        # removed those calls are harmless no-ops.
        to_close = []
        if robot_id in self.master_connections:
            to_close.append(("master", self.master_connections.pop(robot_id)))
        for slave in self.slave_connections.pop(robot_id, []):
            to_close.append(("slave", slave))

        stale_ids = [
            conn_id
            for conn_id, (r_id, _) in self.connection_registry.items()
            if r_id == robot_id
        ]
        for conn_id in stale_ids:
            del self.connection_registry[conn_id]

        for role, websocket in to_close:
            try:
                await websocket.close()
            except Exception as e:
                logger.error(f"Error closing {role} connection for robot {robot_id}: {e}")

        logger.info(f"Cleaned up all connections for robot {robot_id}")

    def list_all_connections(self) -> dict:
        """Return a debugging overview of all registered connections.

        ``masters`` maps each robot with a master to ``"connected"``,
        ``slaves`` maps each robot to its slave count, and
        ``total_connections`` is the overall socket count.
        """
        return {
            "masters": {robot_id: "connected" for robot_id in self.master_connections},
            "slaves": {robot_id: len(slaves) for robot_id, slaves in self.slave_connections.items()},
            "total_connections": self.get_connection_count(),
        }