# LeRobot-Arena: src-python/src/connection_manager.py
# Last commit by blanchon: 3aea7c6 "squash: initial commit"
import logging
from typing import Dict, List, Optional, Tuple

from fastapi import WebSocket
# Module-level logger named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class ConnectionManager:
    """Manages WebSocket connections for masters and slaves.

    Each robot (identified by ``robot_id``) has at most one master websocket
    and any number of slave websockets. Every connection is also recorded in
    ``connection_registry`` under the caller-supplied ``connection_id`` so it
    can later be removed by that id.

    Methods that close sockets remove the affected bookkeeping *before*
    awaiting ``close()``. Another coroutine that runs while a socket is closing
    therefore never sees, or races against, half-removed state.
    """

    def __init__(self):
        # robot_id -> the robot's single master websocket
        self.master_connections: Dict[str, WebSocket] = {}
        # robot_id -> slave websockets, in connection order
        self.slave_connections: Dict[str, List[WebSocket]] = {}
        # connection_id -> (robot_id, websocket), for masters and slaves alike
        self.connection_registry: Dict[str, Tuple[str, WebSocket]] = {}

    async def connect_master(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Register ``websocket`` as the master of ``robot_id``.

        Only one master is allowed per robot. Any existing master is removed
        and its websocket closed (via ``disconnect_master_by_robot``) before the
        new one is registered.
        """
        if robot_id in self.master_connections:
            logger.warning(f"Disconnecting existing master for robot {robot_id}")
            # NOTE(review): this await yields to the event loop. If another
            # connect_master for the same robot completes meanwhile, the
            # assignment below replaces that master without closing it.
            # Confirm whether concurrent master connects for one robot can occur.
            await self.disconnect_master_by_robot(robot_id)
        self.master_connections[robot_id] = websocket
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(f"Master {connection_id} connected to robot {robot_id}")

    async def connect_slave(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Append ``websocket`` to the slaves of ``robot_id`` and register it."""
        slaves = self.slave_connections.setdefault(robot_id, [])
        slaves.append(websocket)
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(
            f"Slave {connection_id} connected to robot {robot_id} ({len(slaves)} total slaves)"
        )

    async def disconnect_master(self, connection_id: str):
        """Forget the master registered under ``connection_id``.

        This does not close the websocket. Unknown ids are ignored. The robot's
        master slot is cleared only if it still holds this connection's
        websocket, so a stale or mismatched id cannot evict a different master.
        """
        entry = self.connection_registry.pop(connection_id, None)
        if entry is None:
            return
        robot_id, websocket = entry
        if self.master_connections.get(robot_id) is websocket:
            del self.master_connections[robot_id]
        logger.info(f"Master {connection_id} disconnected from robot {robot_id}")

    async def disconnect_master_by_robot(self, robot_id: str):
        """Remove the master of ``robot_id`` (if any) and close its websocket.

        The master mapping and its registry entry are removed before ``close()``
        is awaited. Errors raised by ``close()`` are logged, not propagated.
        """
        websocket = self.master_connections.pop(robot_id, None)
        if websocket is None:
            return
        # The registry is keyed by connection_id, so search it for this socket.
        for conn_id, (r_id, ws) in list(self.connection_registry.items()):
            if r_id == robot_id and ws is websocket:
                del self.connection_registry[conn_id]
                break
        try:
            await websocket.close()
        except Exception as e:
            logger.error(f"Error closing master websocket for robot {robot_id}: {e}")

    async def disconnect_slave(self, connection_id: str):
        """Forget the slave registered under ``connection_id``.

        This does not close the websocket. Unknown ids are ignored. A robot
        whose last slave leaves is dropped from ``slave_connections``.
        """
        entry = self.connection_registry.pop(connection_id, None)
        if entry is None:
            return
        robot_id, websocket = entry
        slaves = self.slave_connections.get(robot_id)
        if slaves is not None:
            try:
                slaves.remove(websocket)
            except ValueError:
                logger.warning(f"Slave websocket not found in connections for robot {robot_id}")
            else:
                if not slaves:  # drop the empty list so the robot no longer appears
                    del self.slave_connections[robot_id]
        logger.info(f"Slave {connection_id} disconnected from robot {robot_id}")

    def get_master_connection(self, robot_id: str) -> Optional[WebSocket]:
        """Return the master websocket of ``robot_id``, or ``None`` if it has none."""
        return self.master_connections.get(robot_id)

    def get_slave_connections(self, robot_id: str) -> List[WebSocket]:
        """Return a snapshot list of the slave websockets of ``robot_id``.

        A copy is returned for two reasons. Callers that ``await`` while
        iterating are unaffected by slaves connecting or disconnecting
        mid-iteration. Callers also cannot mutate the manager's internal list.
        """
        return list(self.slave_connections.get(robot_id, ()))

    def get_connection_count(self) -> int:
        """Return the total number of tracked master and slave connections."""
        master_count = len(self.master_connections)
        slave_count = sum(len(slaves) for slaves in self.slave_connections.values())
        return master_count + slave_count

    def get_robot_connection_info(self, robot_id: str) -> dict:
        """Summarize the connections of ``robot_id``.

        Returns:
            A dict with keys ``robot_id``, ``has_master`` (bool),
            ``slave_count`` (int) and ``total_connections`` (int).
        """
        has_master = robot_id in self.master_connections
        slave_count = len(self.slave_connections.get(robot_id, ()))
        return {
            "robot_id": robot_id,
            "has_master": has_master,
            "slave_count": slave_count,
            "total_connections": (1 if has_master else 0) + slave_count,
        }

    async def cleanup_robot_connections(self, robot_id: str):
        """Close and forget every connection (master and slaves) of ``robot_id``.

        All bookkeeping is removed before any websocket is closed. ``close()``
        yields to the event loop, and a disconnect handler running during that
        await could otherwise delete a key this method later ``del``s (raising
        ``KeyError``) or mutate the slave list while it is being iterated.
        Errors raised by ``close()`` are logged, not propagated.
        """
        master = self.master_connections.pop(robot_id, None)
        slaves = self.slave_connections.pop(robot_id, [])
        stale_ids = [
            conn_id
            for conn_id, (r_id, _) in self.connection_registry.items()
            if r_id == robot_id
        ]
        for conn_id in stale_ids:
            del self.connection_registry[conn_id]

        if master is not None:
            try:
                await master.close()
            except Exception as e:
                logger.error(f"Error closing master connection for robot {robot_id}: {e}")
        for websocket in slaves:
            try:
                await websocket.close()
            except Exception as e:
                logger.error(f"Error closing slave connection for robot {robot_id}: {e}")
        logger.info(f"Cleaned up all connections for robot {robot_id}")

    def list_all_connections(self) -> dict:
        """Return a debugging snapshot of all connections.

        Returns:
            A dict with ``masters`` (robot_id -> ``"connected"``), ``slaves``
            (robot_id -> slave count) and ``total_connections``.
        """
        return {
            "masters": dict.fromkeys(self.master_connections, "connected"),
            "slaves": {robot_id: len(slaves) for robot_id, slaves in self.slave_connections.items()},
            "total_connections": self.get_connection_count(),
        }