"""In-memory room registry and WebSocket message routing for robotics rooms."""

import json
import logging
import uuid
from datetime import UTC, datetime

from fastapi import WebSocket, WebSocketDisconnect
from typing_extensions import TypedDict

from .models import MessageType, ParticipantRole

logger = logging.getLogger(__name__)


class ConnectionMetadata(TypedDict):
    """Connection metadata with proper typing"""

    workspace_id: str
    room_id: str
    participant_id: str
    role: ParticipantRole
    connected_at: datetime
    last_activity: datetime
    message_count: int


# ============= SIMPLIFIED ROOM SYSTEM =============


class RoboticsRoom:
    """Simple robotics room with producer/consumer pattern"""

    def __init__(self, room_id: str, workspace_id: str):
        self.id = room_id
        self.workspace_id = workspace_id

        # Participants: at most one producer, any number of consumers
        self.producer: str | None = None
        self.consumers: list[str] = []

        # State: joint name -> last value stored by update_joints
        self.joints: dict[str, float] = {}


class RoboticsCore:
    """Core robotics system - simplified and merged with workspace support"""

    def __init__(self):
        # Nested structure: workspace_id -> room_id -> RoboticsRoom
        self.workspaces: dict[str, dict[str, RoboticsRoom]] = {}
        self.connections: dict[str, WebSocket] = {}  # participant_id -> websocket
        self.connection_metadata: dict[
            str, ConnectionMetadata
        ] = {}  # participant_id -> metadata

    # ============= ROOM MANAGEMENT =============

    def create_room(
        self, workspace_id: str | None = None, room_id: str | None = None
    ) -> tuple[str, str]:
        """Create a new room and return (workspace_id, room_id).

        Missing ids are generated with ``uuid.uuid4()``. The workspace is
        created on demand. If a room with the same id already exists in the
        workspace it is replaced by a fresh, empty ``RoboticsRoom``.
        """
        workspace_id = workspace_id or str(uuid.uuid4())
        room_id = room_id or str(uuid.uuid4())

        # Initialize workspace if it doesn't exist
        rooms = self.workspaces.setdefault(workspace_id, {})
        rooms[room_id] = RoboticsRoom(room_id, workspace_id)

        logger.info(f"Created room {room_id} in workspace {workspace_id}")
        return workspace_id, room_id

    @staticmethod
    def _participants_summary(room: RoboticsRoom) -> dict:
        """Build the ``participants`` payload shared by the room info endpoints.

        The consumer list is copied so callers cannot mutate room state
        through the returned dict.
        """
        return {
            "producer": room.producer,
            "consumers": list(room.consumers),
            "total": len(room.consumers) + (1 if room.producer else 0),
        }

    def list_rooms(self, workspace_id: str) -> list[dict]:
        """List all rooms in a specific workspace.

        Returns an empty list when the workspace is unknown.
        """
        rooms = self.workspaces.get(workspace_id)
        if rooms is None:
            return []

        return [
            {
                "id": room.id,
                "workspace_id": room.workspace_id,
                "participants": self._participants_summary(room),
                "joints_count": len(room.joints),
            }
            for room in rooms.values()
        ]

    def delete_room(self, workspace_id: str, room_id: str) -> bool:
        """Delete a room from a workspace.

        All participants are removed from the room first. Their WebSocket
        entries in ``self.connections`` are left untouched.

        Returns:
            True if the room existed and was deleted, False otherwise.
        """
        room = self._get_room(workspace_id, room_id)
        if room is None:
            return False

        # Disconnect all participants (iterate a copy: leave_room mutates the list)
        for consumer_id in room.consumers[:]:
            self.leave_room(workspace_id, room_id, consumer_id)
        if room.producer:
            self.leave_room(workspace_id, room_id, room.producer)

        del self.workspaces[workspace_id][room_id]
        logger.info(f"Deleted room {room_id} from workspace {workspace_id}")
        return True

    def get_room_state(self, workspace_id: str, room_id: str) -> dict:
        """Get detailed room state.

        Returns ``{"error": "Room not found"}`` when the room does not exist.
        The joints mapping in the result is a copy of the room's state.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return {"error": "Room not found"}

        return {
            "room_id": room_id,
            "workspace_id": workspace_id,
            "joints": dict(room.joints),
            "participants": self._participants_summary(room),
            "timestamp": datetime.now(tz=UTC).isoformat(),
        }

    def get_room_info(self, workspace_id: str, room_id: str) -> dict:
        """Get basic room info.

        Returns ``{"error": "Room not found"}`` when the room does not exist.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return {"error": "Room not found"}

        return {
            "id": room.id,
            "workspace_id": room.workspace_id,
            "participants": self._participants_summary(room),
            "joints_count": len(room.joints),
            "has_producer": room.producer is not None,
            "active_consumers": len(room.consumers),
        }

    def _get_room(self, workspace_id: str, room_id: str) -> RoboticsRoom | None:
        """Get room by workspace and room ID, or None if either is unknown."""
        rooms = self.workspaces.get(workspace_id)
        if rooms is None:
            return None
        return rooms.get(room_id)

    # ============= PARTICIPANT MANAGEMENT =============

    def join_room(
        self,
        workspace_id: str,
        room_id: str,
        participant_id: str,
        role: ParticipantRole,
    ) -> bool:
        """Join room as producer or consumer.

        Returns:
            True when the participant was registered. False when the room
            does not exist, the producer slot is already taken, the consumer
            id is already present, or the role is neither producer nor
            consumer.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return False

        if role == ParticipantRole.PRODUCER:
            if room.producer is None:
                room.producer = participant_id
                logger.info(
                    f"Producer {participant_id} joined room {room_id} in workspace {workspace_id}"
                )
                return True
            # Room already has a producer, reject the new one
            logger.warning(
                f"Producer {participant_id} failed to join room {room_id} - room already has producer {room.producer}"
            )
            return False

        if role == ParticipantRole.CONSUMER:
            if participant_id not in room.consumers:
                room.consumers.append(participant_id)
                logger.info(
                    f"Consumer {participant_id} joined room {room_id} in workspace {workspace_id}"
                )
                return True
            return False

        return False

    def leave_room(self, workspace_id: str, room_id: str, participant_id: str):
        """Leave room.

        Clears the producer slot if ``participant_id`` holds it, otherwise
        removes the id from the consumer list. Does nothing when the room or
        the participant is unknown.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return

        if room.producer == participant_id:
            room.producer = None
            logger.info(
                f"Producer {participant_id} left room {room_id} in workspace {workspace_id}"
            )
        elif participant_id in room.consumers:
            room.consumers.remove(participant_id)
            logger.info(
                f"Consumer {participant_id} left room {room_id} in workspace {workspace_id}"
            )

    # ============= JOINT CONTROL =============

    def update_joints(
        self, workspace_id: str, room_id: str, joint_updates: list[dict]
    ) -> list[dict]:
        """Apply joint updates to a room and return the entries that changed.

        Each update is a dict with ``"name"`` and ``"value"`` keys. An entry
        is stored and reported only when its value differs from the one
        currently held for that joint.

        Raises:
            ValueError: If the room does not exist.
            KeyError: If an update lacks ``"name"`` or ``"value"``; updates
                earlier in the list have already been applied at that point.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            msg = f"Room {room_id} not found in workspace {workspace_id}"
            raise ValueError(msg)

        changed_joints = []
        for joint in joint_updates:
            name = joint["name"]
            value = joint["value"]

            # Only track actual changes
            if room.joints.get(name) != value:
                room.joints[name] = value
                changed_joints.append(joint)

        return changed_joints

    # ============= WEBSOCKET HANDLING =============

    async def handle_websocket(
        self, websocket: WebSocket, workspace_id: str, room_id: str
    ):
        """Handle WebSocket connection.

        Protocol: the first text frame must be a JSON object with
        ``participant_id`` and ``role``. On a successful join the connection
        is registered, consumers receive a state sync, a join confirmation is
        sent, and every following text frame is dispatched through
        ``_handle_message``. A rejected join gets an error message and the
        socket is closed.

        Cleanup in ``finally`` only runs for connections that actually
        joined, so a rejected attempt (e.g. a duplicate participant id) never
        tears down the state of the participant that already holds that id.
        """
        await websocket.accept()

        participant_id: str | None = None
        role: ParticipantRole | None = None
        joined = False  # True once this socket owns a slot in the room

        try:
            # Get join message
            data = await websocket.receive_text()
            join_msg = json.loads(data)

            participant_id = join_msg["participant_id"]
            role = ParticipantRole(join_msg["role"])

            # An empty id would occupy the producer slot with a falsy value
            # that is never released; a non-string id cannot be used as a key.
            if not isinstance(participant_id, str) or not participant_id:
                await websocket.send_text(
                    json.dumps({
                        "type": MessageType.ERROR.value,
                        "message": "Invalid participant_id",
                    })
                )
                await websocket.close()
                return

            # Join room
            if not self.join_room(workspace_id, room_id, participant_id, role):
                await websocket.send_text(
                    json.dumps({
                        "type": MessageType.ERROR.value,
                        "message": "Cannot join room",
                    })
                )
                await websocket.close()
                return

            joined = True
            self.connections[participant_id] = websocket

            # Store connection metadata
            now = datetime.now(tz=UTC)
            self.connection_metadata[participant_id] = ConnectionMetadata(
                workspace_id=workspace_id,
                room_id=room_id,
                participant_id=participant_id,
                role=role,
                connected_at=now,
                last_activity=now,
                message_count=0,
            )

            # Send current state to consumer
            if role == ParticipantRole.CONSUMER:
                room = self._get_room(workspace_id, room_id)
                if room:
                    await websocket.send_text(
                        json.dumps({
                            "type": MessageType.STATE_SYNC.value,
                            "data": room.joints,
                            "timestamp": datetime.now(tz=UTC).isoformat(),
                        })
                    )

            # Send join confirmation
            await websocket.send_text(
                json.dumps({
                    "type": MessageType.JOINED.value,
                    "room_id": room_id,
                    "workspace_id": workspace_id,
                    "role": role.value,
                    "timestamp": datetime.now(tz=UTC).isoformat(),
                })
            )

            # Handle messages
            async for message in websocket.iter_text():
                try:
                    msg = json.loads(message)
                    await self._handle_message(
                        workspace_id, room_id, participant_id, role, msg
                    )
                except json.JSONDecodeError:
                    logger.exception(f"Invalid JSON from {participant_id}")
                except Exception:
                    # Keep the connection alive on a single bad message
                    logger.exception("Message error")

        except WebSocketDisconnect:
            logger.info(f"WebSocket disconnected: {participant_id}")
        except Exception:
            logger.exception("WebSocket error")
        finally:
            # Cleanup: only undo what this connection registered
            if joined and participant_id is not None:
                self.leave_room(workspace_id, room_id, participant_id)
                # Another socket may have re-registered the same id in the
                # meantime; its entries must not be removed.
                if self.connections.get(participant_id) is websocket:
                    del self.connections[participant_id]
                    self.connection_metadata.pop(participant_id, None)

    async def _handle_message(
        self,
        workspace_id: str,
        room_id: str,
        participant_id: str,
        role: ParticipantRole,
        message: dict,
    ):
        """Handle incoming WebSocket message with structured handlers.

        Updates the sender's activity metadata, then dispatches on the
        message ``type``. Unknown types are answered with an error message.
        """
        # Update activity tracking
        metadata = self.connection_metadata.get(participant_id)
        if metadata is not None:
            metadata["last_activity"] = datetime.now(tz=UTC)
            metadata["message_count"] += 1

        try:
            msg_type = MessageType(message.get("type"))
        except ValueError:
            logger.warning(
                f"Unknown message type from {participant_id}: {message.get('type')}"
            )
            await self._handle_error(
                participant_id, f"Unknown message type: {message.get('type')}"
            )
            return

        # Dispatch to specific handlers
        if msg_type == MessageType.JOINT_UPDATE:
            await self._handle_joint_update(
                workspace_id, room_id, participant_id, role, message
            )
        elif msg_type == MessageType.HEARTBEAT:
            await self._handle_heartbeat(participant_id)
        elif msg_type == MessageType.EMERGENCY_STOP:
            await self._handle_emergency_stop(
                workspace_id, room_id, participant_id, message
            )
        else:
            logger.warning(f"Unhandled message type {msg_type} from {participant_id}")

    # ============= STRUCTURED MESSAGE HANDLERS =============

    async def _handle_joint_update(
        self,
        workspace_id: str,
        room_id: str,
        participant_id: str,
        role: ParticipantRole,
        message: dict,
    ):
        """Handle joint update commands from producers.

        Updates from non-producers and empty payloads are logged and ignored.
        Changed joints are broadcast to the room's consumers; any failure is
        logged and reported back to the sender as an error message.
        """
        if role != ParticipantRole.PRODUCER:
            logger.warning(
                f"Non-producer {participant_id} attempted to send joint update"
            )
            return

        joints = message.get("data", [])
        if not joints:
            logger.warning(f"Empty joint data from producer {participant_id}")
            return

        try:
            changed_joints = self.update_joints(workspace_id, room_id, joints)

            if changed_joints:
                await self._broadcast_to_consumers(
                    workspace_id,
                    room_id,
                    {
                        "type": MessageType.JOINT_UPDATE.value,
                        "data": changed_joints,
                        "timestamp": datetime.now(tz=UTC).isoformat(),
                        "source": participant_id,
                    },
                )
                logger.debug(
                    f"Producer {participant_id} sent {len(changed_joints)} joint updates"
                )
        except Exception:
            logger.exception(f"Error processing joint update from {participant_id}")
            await self._handle_error(participant_id, "Failed to process joint update")

    async def _handle_heartbeat(self, participant_id: str):
        """Handle heartbeat messages by replying with a heartbeat ack."""
        try:
            await self._send_to_participant(
                participant_id,
                {
                    "type": MessageType.HEARTBEAT_ACK.value,
                    "timestamp": datetime.now(tz=UTC).isoformat(),
                },
            )
            logger.debug(f"Heartbeat acknowledged for {participant_id}")
        except Exception:
            logger.exception(f"Error handling heartbeat from {participant_id}")

    async def _handle_emergency_stop(
        self, workspace_id: str, room_id: str, participant_id: str, message: dict
    ):
        """Handle emergency stop messages.

        Relays the stop, with the sender's ``reason`` or a default one, to
        the producer and all consumers of the room.
        """
        try:
            reason = message.get("reason", f"Emergency stop from {participant_id}")
            emergency_message = {
                "type": MessageType.EMERGENCY_STOP.value,
                "timestamp": datetime.now(tz=UTC).isoformat(),
                "reason": reason,
                "source": participant_id,
            }

            # Broadcast to all participants in room
            await self._broadcast_to_all_participants(
                workspace_id, room_id, emergency_message
            )

            logger.warning(
                f"🚨 Emergency stop triggered by {participant_id} in room {room_id} (workspace {workspace_id})"
            )
        except Exception:
            logger.exception(f"Error handling emergency stop from {participant_id}")

    async def _handle_error(self, participant_id: str, error_message: str):
        """Send error message to participant"""
        try:
            await self._send_to_participant(
                participant_id,
                {
                    "type": MessageType.ERROR.value,
                    "message": error_message,
                    "timestamp": datetime.now(tz=UTC).isoformat(),
                },
            )
        except Exception:
            logger.exception(f"Error sending error message to {participant_id}")

    async def _broadcast_to_consumers(
        self, workspace_id: str, room_id: str, message: dict
    ):
        """Send message to all consumers in room.

        Consumers whose send fails are removed from the room and from the
        connection tables.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return

        message_text = json.dumps(message)
        failed = []

        # Iterate a snapshot: each await yields to the event loop, and a
        # consumer joining or leaving meanwhile would mutate the live list.
        for consumer_id in list(room.consumers):
            connection = self.connections.get(consumer_id)
            if connection is None:
                continue
            try:
                await connection.send_text(message_text)
            except Exception:
                logger.exception(f"Error sending message to {consumer_id}")
                failed.append(consumer_id)

        # Cleanup failed connections
        for consumer_id in failed:
            # The consumer may already have left while we were awaiting
            if consumer_id in room.consumers:
                room.consumers.remove(consumer_id)
            self.connections.pop(consumer_id, None)
            self.connection_metadata.pop(consumer_id, None)

    async def _broadcast_to_all_participants(
        self, workspace_id: str, room_id: str, message: dict
    ):
        """Send message to all participants (producer + consumers) in room.

        Participants whose send fails are removed from the room recorded in
        their metadata and from the connection tables.
        """
        room = self._get_room(workspace_id, room_id)
        if not room:
            return

        message_text = json.dumps(message)

        # Producer first (if any), then a snapshot of the consumers
        participants = []
        if room.producer:
            participants.append(room.producer)
        participants.extend(room.consumers)

        failed = []
        sent_count = 0
        for participant_id in participants:
            connection = self.connections.get(participant_id)
            if connection is None:
                continue
            try:
                await connection.send_text(message_text)
                sent_count += 1
            except Exception:
                logger.exception(f"Error sending message to {participant_id}")
                failed.append(participant_id)

        # Cleanup failed connections
        for participant_id in failed:
            metadata = self.connection_metadata.get(participant_id)
            if metadata:
                self.leave_room(
                    metadata["workspace_id"], metadata["room_id"], participant_id
                )
            self.connections.pop(participant_id, None)
            self.connection_metadata.pop(participant_id, None)

        logger.debug(
            f"Broadcast message to {sent_count}/{len(participants)} participants in room {room_id}"
        )

    async def _send_to_participant(self, participant_id: str, message: dict):
        """Send message to specific participant.

        Does nothing when the participant has no registered connection. A
        failed send is logged and the connection entry is dropped; room
        membership and metadata are left in place.
        """
        connection = self.connections.get(participant_id)
        if connection is None:
            return

        try:
            await connection.send_text(json.dumps(message))
        except Exception:
            logger.exception(f"Error sending message to {participant_id}")
            self.connections.pop(participant_id, None)

    # ============= CONNECTION MONITORING =============

    def get_connection_stats(self) -> dict:
        """Get connection statistics and metadata.

        Returns totals for connections, workspaces and rooms, per-role and
        per-workspace connection counts, and one entry per tracked
        connection with its timestamps in ISO 8601 form.
        """
        stats: dict = {
            "total_connections": len(self.connections),
            "total_workspaces": len(self.workspaces),
            "total_rooms": sum(len(rooms) for rooms in self.workspaces.values()),
            "connections_by_role": {"producer": 0, "consumer": 0},
            "connections_by_workspace": {},
            "active_connections": [],
        }

        # Count by role and collect active connections
        for participant_id, metadata in self.connection_metadata.items():
            role = metadata["role"]
            workspace_id = metadata["workspace_id"]
            room_id = metadata["room_id"]

            stats["connections_by_role"][role.value] += 1

            workspace_stats = stats["connections_by_workspace"].setdefault(
                workspace_id,
                {
                    "producer": 0,
                    "consumer": 0,
                    "rooms": 0,
                },
            )
            workspace_stats[role.value] += 1

            # Number of rooms currently registered in the workspace
            if workspace_id in self.workspaces:
                workspace_stats["rooms"] = len(self.workspaces[workspace_id])

            stats["active_connections"].append({
                "participant_id": participant_id,
                "workspace_id": workspace_id,
                "room_id": room_id,
                "role": role.value,
                "connected_at": metadata["connected_at"].isoformat(),
                "last_activity": metadata["last_activity"].isoformat(),
                "message_count": metadata["message_count"],
            })

        return stats

    # ============= EXTERNAL API METHODS =============

    async def send_command_to_room(
        self, workspace_id: str, room_id: str, joints: list[dict]
    ):
        """Apply joint updates from the API and broadcast the changes.

        Returns:
            The number of joints whose value changed.

        Raises:
            ValueError: If the room does not exist (from ``update_joints``).
        """
        changed_joints = self.update_joints(workspace_id, room_id, joints)

        if changed_joints:
            await self._broadcast_to_consumers(
                workspace_id,
                room_id,
                {
                    "type": MessageType.JOINT_UPDATE.value,
                    "data": changed_joints,
                    "timestamp": datetime.now(tz=UTC).isoformat(),
                    "source": "api",
                },
            )

        return len(changed_joints)