Spaces:
Running
Running
from fastapi import WebSocket | |
from typing import Dict, List, Optional | |
import logging | |
# Module-level logger named after this module; used by ConnectionManager below.
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Manages WebSocket connections for masters and slaves.

    Connections are grouped by ``robot_id``: each robot has at most one
    master connection and any number of slave connections. Every connection
    is also registered under a caller-supplied ``connection_id`` so it can be
    disconnected later by that ID.

    No locking is performed. Coroutines that ``await`` (closing sockets)
    remove the relevant entries from the internal dicts *before* awaiting, so
    a disconnect call interleaved by another task on the same event loop
    cannot mutate a container this object is still iterating.

    Sockets are compared by identity (``is``) rather than equality, because
    equality of WebSocket objects is not guaranteed to mean "same connection".
    """

    def __init__(self):
        # robot_id -> websocket
        self.master_connections: Dict[str, WebSocket] = {}
        # robot_id -> list of websockets
        self.slave_connections: Dict[str, List[WebSocket]] = {}
        # connection_id -> (robot_id, websocket)
        self.connection_registry: Dict[str, tuple] = {}

    async def connect_master(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Register ``websocket`` as the master for ``robot_id``.

        Only one master is allowed per robot: any existing master is removed
        from the registry and its socket closed before the new one is stored.
        """
        if robot_id in self.master_connections:
            logger.warning("Disconnecting existing master for robot %s", robot_id)
            await self.disconnect_master_by_robot(robot_id)
        self.master_connections[robot_id] = websocket
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info("Master %s connected to robot %s", connection_id, robot_id)

    async def connect_slave(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Add ``websocket`` to the slave list of ``robot_id``."""
        slaves = self.slave_connections.setdefault(robot_id, [])
        slaves.append(websocket)
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(
            "Slave %s connected to robot %s (%d total slaves)",
            connection_id, robot_id, len(slaves),
        )

    async def disconnect_master(self, connection_id: str):
        """Forget the master registered under ``connection_id``.

        Does not close the socket. Unknown IDs are ignored.
        """
        entry = self.connection_registry.pop(connection_id, None)
        if entry is None:
            return
        robot_id, websocket = entry
        # Only free the robot's master slot if it still holds *this* socket;
        # otherwise a stale or non-master ID would evict the current master.
        if self.master_connections.get(robot_id) is websocket:
            del self.master_connections[robot_id]
        logger.info("Master %s disconnected from robot %s", connection_id, robot_id)

    async def disconnect_master_by_robot(self, robot_id: str):
        """Remove and close the master for ``robot_id``, if any.

        State is removed before the socket is closed, so the close's await
        point never observes a half-updated manager. Close errors are logged,
        not raised.
        """
        websocket = self.master_connections.pop(robot_id, None)
        if websocket is None:
            return
        # Find and remove the matching entry from the connection registry.
        for conn_id, (r_id, ws) in list(self.connection_registry.items()):
            if r_id == robot_id and ws is websocket:
                del self.connection_registry[conn_id]
                break
        try:
            await websocket.close()
        except Exception as e:
            logger.error("Error closing master websocket for robot %s: %s", robot_id, e)

    async def disconnect_slave(self, connection_id: str):
        """Forget the slave registered under ``connection_id``.

        Does not close the socket. Unknown IDs are ignored. Empty slave lists
        are removed so ``slave_connections`` only holds robots with slaves.
        """
        entry = self.connection_registry.pop(connection_id, None)
        if entry is None:
            return
        robot_id, websocket = entry
        slaves = self.slave_connections.get(robot_id)
        if slaves is not None:
            # Locate by identity; list.remove() would use == instead.
            index = next((i for i, ws in enumerate(slaves) if ws is websocket), None)
            if index is None:
                logger.warning("Slave websocket not found in connections for robot %s", robot_id)
            else:
                del slaves[index]
                if not slaves:
                    del self.slave_connections[robot_id]
        logger.info("Slave %s disconnected from robot %s", connection_id, robot_id)

    def get_master_connection(self, robot_id: str) -> Optional[WebSocket]:
        """Return the master socket for ``robot_id``, or ``None``."""
        return self.master_connections.get(robot_id)

    def get_slave_connections(self, robot_id: str) -> List[WebSocket]:
        """Return a snapshot list of the slave sockets for ``robot_id``.

        A copy is returned so callers can ``await`` while iterating (e.g. to
        broadcast) without a concurrent disconnect mutating the list.
        """
        return list(self.slave_connections.get(robot_id, ()))

    def get_connection_count(self) -> int:
        """Return the total number of master plus slave connections."""
        master_count = len(self.master_connections)
        slave_count = sum(len(slaves) for slaves in self.slave_connections.values())
        return master_count + slave_count

    def get_robot_connection_info(self, robot_id: str) -> dict:
        """Return a summary dict of the connections held for ``robot_id``."""
        has_master = robot_id in self.master_connections
        slave_count = len(self.slave_connections.get(robot_id, []))
        return {
            "robot_id": robot_id,
            "has_master": has_master,
            "slave_count": slave_count,
            "total_connections": (1 if has_master else 0) + slave_count,
        }

    async def cleanup_robot_connections(self, robot_id: str):
        """Close and forget every connection belonging to ``robot_id``.

        All bookkeeping for the robot is removed *before* any socket is
        closed. Disconnect handlers triggered during the close awaits then
        find nothing to remove, instead of mutating the list being iterated
        or deleting keys this method would later delete (``KeyError``).
        Close errors are logged, not raised.
        """
        master = self.master_connections.pop(robot_id, None)
        slaves = self.slave_connections.pop(robot_id, [])
        stale_ids = [
            conn_id
            for conn_id, (r_id, _) in self.connection_registry.items()
            if r_id == robot_id
        ]
        for conn_id in stale_ids:
            del self.connection_registry[conn_id]

        # Close master connection
        if master is not None:
            try:
                await master.close()
            except Exception as e:
                logger.error("Error closing master connection for robot %s: %s", robot_id, e)

        # Close slave connections (iterating a list no one else references now)
        for websocket in slaves:
            try:
                await websocket.close()
            except Exception as e:
                logger.error("Error closing slave connection for robot %s: %s", robot_id, e)

        logger.info("Cleaned up all connections for robot %s", robot_id)

    def list_all_connections(self) -> dict:
        """Return a debugging summary of all active connections."""
        return {
            "masters": {robot_id: "connected" for robot_id in self.master_connections},
            "slaves": {robot_id: len(slaves) for robot_id, slaves in self.slave_connections.items()},
            "total_connections": self.get_connection_count(),
        }