from fastapi import WebSocket
from typing import Dict, List, Optional, Tuple
import logging

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Manages WebSocket connections for masters and slaves.

    Each robot has at most one master connection and any number of slave
    connections. Every connection is also recorded in ``connection_registry``
    under its caller-supplied ``connection_id`` so it can later be
    disconnected by that id.

    WebSockets are matched by identity (``is``), never by ``==``: Starlette's
    ``WebSocket`` is a ``Mapping`` over its ASGI scope, so equality compares
    scope contents rather than asking "is this the same connection?".

    No locks are used. Instead, every method finishes updating the
    bookkeeping dicts before it awaits anything (e.g. ``websocket.close()``),
    so other coroutines that run during that await see consistent state.
    """

    def __init__(self):
        # robot_id -> websocket
        self.master_connections: Dict[str, WebSocket] = {}
        # robot_id -> list of websockets
        self.slave_connections: Dict[str, List[WebSocket]] = {}
        # connection_id -> (robot_id, websocket)
        self.connection_registry: Dict[str, Tuple[str, WebSocket]] = {}

    @staticmethod
    def _remove_by_identity(sockets: List[WebSocket], websocket: WebSocket) -> bool:
        """Remove *websocket* (by identity) from *sockets*; return True if found."""
        for index, candidate in enumerate(sockets):
            if candidate is websocket:
                del sockets[index]
                return True
        return False

    async def connect_master(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Connect a master to a robot.

        Only one master per robot is allowed. An existing master for
        *robot_id* is removed from the bookkeeping and its websocket is
        closed before the new one is registered.
        """
        if robot_id in self.master_connections:
            logger.warning(f"Disconnecting existing master for robot {robot_id}")
            await self.disconnect_master_by_robot(robot_id)

        self.master_connections[robot_id] = websocket
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(f"Master {connection_id} connected to robot {robot_id}")

    async def connect_slave(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Connect a slave to a robot (any number of slaves may share a robot)."""
        slaves = self.slave_connections.setdefault(robot_id, [])
        slaves.append(websocket)
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(f"Slave {connection_id} connected to robot {robot_id} ({len(slaves)} total slaves)")

    async def disconnect_master(self, connection_id: str):
        """Forget a master connection by its connection id.

        The websocket itself is not closed here. The robot's master slot is
        cleared only if it still holds this entry's websocket, so a stale id,
        or a slave's id, cannot evict a different, currently active master.
        Unknown ids are ignored.
        """
        entry = self.connection_registry.pop(connection_id, None)
        if entry is None:
            return
        robot_id, websocket = entry
        if self.master_connections.get(robot_id) is websocket:
            del self.master_connections[robot_id]
        logger.info(f"Master {connection_id} disconnected from robot {robot_id}")

    async def disconnect_master_by_robot(self, robot_id: str):
        """Remove and close the master connection of *robot_id*, if any.

        Bookkeeping is cleaned up before awaiting ``close()``. Errors raised
        while closing are logged, not propagated.
        """
        websocket = self.master_connections.pop(robot_id, None)
        if websocket is None:
            return

        # Drop the matching registry entry (same robot AND same socket object).
        for conn_id, (r_id, ws) in list(self.connection_registry.items()):
            if r_id == robot_id and ws is websocket:
                del self.connection_registry[conn_id]
                break

        try:
            await websocket.close()
        except Exception as e:
            logger.error(f"Error closing master websocket for robot {robot_id}: {e}")

    async def disconnect_slave(self, connection_id: str):
        """Forget a slave connection by its connection id.

        The websocket itself is not closed here. The robot's slave list is
        deleted once it becomes empty. Unknown ids are ignored.
        """
        entry = self.connection_registry.pop(connection_id, None)
        if entry is None:
            return
        robot_id, websocket = entry

        slaves = self.slave_connections.get(robot_id)
        if slaves is not None:
            if self._remove_by_identity(slaves, websocket):
                if not slaves:
                    # Remove empty list
                    del self.slave_connections[robot_id]
            else:
                logger.warning(f"Slave websocket not found in connections for robot {robot_id}")

        logger.info(f"Slave {connection_id} disconnected from robot {robot_id}")

    def get_master_connection(self, robot_id: str) -> Optional[WebSocket]:
        """Get master connection for a robot (None if there is none)."""
        return self.master_connections.get(robot_id)

    def get_slave_connections(self, robot_id: str) -> List[WebSocket]:
        """Get all slave connections for a robot.

        Returns a snapshot copy. A caller that awaits while iterating (e.g.
        broadcasting) is then unaffected by slaves connecting or
        disconnecting in the meantime.
        """
        return list(self.slave_connections.get(robot_id, ()))

    def get_connection_count(self) -> int:
        """Get total number of active connections (masters plus slaves)."""
        master_count = len(self.master_connections)
        slave_count = sum(len(slaves) for slaves in self.slave_connections.values())
        return master_count + slave_count

    def get_robot_connection_info(self, robot_id: str) -> dict:
        """Get connection information for a robot."""
        has_master = robot_id in self.master_connections
        slave_count = len(self.slave_connections.get(robot_id, []))
        return {
            "robot_id": robot_id,
            "has_master": has_master,
            "slave_count": slave_count,
            "total_connections": (1 if has_master else 0) + slave_count,
        }

    async def cleanup_robot_connections(self, robot_id: str):
        """Clean up and close all connections for a robot.

        All bookkeeping for *robot_id* is detached synchronously first; only
        then are the sockets closed. Handlers that disconnect while a
        ``close()`` is pending therefore cannot mutate a list being iterated
        here or make a later ``del`` fail. Close errors are logged, not
        propagated.
        """
        master = self.master_connections.pop(robot_id, None)
        slaves = self.slave_connections.pop(robot_id, [])
        stale_ids = [
            conn_id
            for conn_id, (r_id, _) in self.connection_registry.items()
            if r_id == robot_id
        ]
        for conn_id in stale_ids:
            del self.connection_registry[conn_id]

        # Close master connection
        if master is not None:
            try:
                await master.close()
            except Exception as e:
                logger.error(f"Error closing master connection for robot {robot_id}: {e}")

        # Close slave connections
        for websocket in slaves:
            try:
                await websocket.close()
            except Exception as e:
                logger.error(f"Error closing slave connection for robot {robot_id}: {e}")

        logger.info(f"Cleaned up all connections for robot {robot_id}")

    def list_all_connections(self) -> dict:
        """List all active connections for debugging."""
        return {
            "masters": {robot_id: "connected" for robot_id in self.master_connections},
            "slaves": {robot_id: len(slaves) for robot_id, slaves in self.slave_connections.items()},
            "total_connections": self.get_connection_count(),
        }