Spaces:
Running
Running
import asyncio | |
import json | |
import logging | |
from collections.abc import Callable | |
from urllib.parse import urlparse | |
import aiohttp | |
import websockets | |
# Module-wide logger used by the client classes and factory functions below.
logger: logging.Logger = logging.getLogger(__name__)
class RoboticsClientCore: | |
"""Base client for LeRobot Arena robotics API""" | |
def __init__(self, base_url: str = "http://localhost:8000"): | |
self.base_url = base_url.rstrip("/") | |
self.api_base = f"{self.base_url}/robotics" | |
# WebSocket connection | |
self.websocket: websockets.WebSocketServerProtocol | None = None | |
self.workspace_id: str | None = None | |
self.room_id: str | None = None | |
self.role: str | None = None | |
self.participant_id: str | None = None | |
self.connected = False | |
# Background task for message handling | |
self._message_task: asyncio.Task | None = None | |
# ============= REST API METHODS ============= | |
async def list_rooms(self, workspace_id: str) -> list[dict]: | |
"""List all available rooms in a workspace""" | |
async with aiohttp.ClientSession() as session: | |
async with session.get( | |
f"{self.api_base}/workspaces/{workspace_id}/rooms" | |
) as response: | |
response.raise_for_status() | |
result = await response.json() | |
# Extract the rooms list from the response | |
return result.get("rooms", []) | |
async def create_room( | |
self, workspace_id: str | None = None, room_id: str | None = None | |
) -> tuple[str, str]: | |
"""Create a new room and return (workspace_id, room_id)""" | |
# Generate workspace ID if not provided | |
final_workspace_id = workspace_id or self._generate_workspace_id() | |
payload = {} | |
if room_id: | |
payload["room_id"] = room_id | |
async with aiohttp.ClientSession() as session: | |
async with session.post( | |
f"{self.api_base}/workspaces/{final_workspace_id}/rooms", json=payload | |
) as response: | |
response.raise_for_status() | |
result = await response.json() | |
return result["workspace_id"], result["room_id"] | |
async def delete_room(self, workspace_id: str, room_id: str) -> bool: | |
"""Delete a room""" | |
async with aiohttp.ClientSession() as session: | |
async with session.delete( | |
f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}" | |
) as response: | |
if response.status == 404: | |
return False | |
response.raise_for_status() | |
result = await response.json() | |
return result["success"] | |
async def get_room_state(self, workspace_id: str, room_id: str) -> dict: | |
"""Get current room state""" | |
async with aiohttp.ClientSession() as session: | |
async with session.get( | |
f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}/state" | |
) as response: | |
response.raise_for_status() | |
result = await response.json() | |
# Extract the state from the response | |
return result.get("state", {}) | |
async def get_room_info(self, workspace_id: str, room_id: str) -> dict: | |
"""Get basic room information""" | |
async with aiohttp.ClientSession() as session: | |
async with session.get( | |
f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}" | |
) as response: | |
response.raise_for_status() | |
result = await response.json() | |
# Extract the room data from the response | |
return result.get("room", {}) | |
# ============= WEBSOCKET CONNECTION ============= | |
async def connect_to_room( | |
self, | |
workspace_id: str, | |
room_id: str, | |
role: str, | |
participant_id: str | None = None, | |
) -> bool: | |
"""Connect to a room as producer or consumer""" | |
if self.connected: | |
await self.disconnect() | |
self.workspace_id = workspace_id | |
self.room_id = room_id | |
self.role = role | |
self.participant_id = participant_id or f"{role}_{id(self)}" | |
# Convert HTTP URL to WebSocket URL | |
parsed = urlparse(self.base_url) | |
ws_scheme = "wss" if parsed.scheme == "https" else "ws" | |
ws_url = f"{ws_scheme}://{parsed.netloc}/robotics/workspaces/{workspace_id}/rooms/{room_id}/ws" | |
initial_state_sync = None | |
try: | |
self.websocket = await websockets.connect(ws_url) | |
# Send join message | |
join_message = {"participant_id": self.participant_id, "role": role} | |
await self.websocket.send(json.dumps(join_message)) | |
# Wait for server response to join message | |
try: | |
response_text = await asyncio.wait_for( | |
self.websocket.recv(), timeout=5.0 | |
) | |
response = json.loads(response_text) | |
if response.get("type") == "error": | |
logger.error( | |
f"Server rejected connection: {response.get('message')}" | |
) | |
await self.websocket.close() | |
return False | |
if response.get("type") == "state_sync": | |
# Consumer receives initial state sync, store it and wait for joined message | |
logger.debug("Received initial state sync") | |
initial_state_sync = response | |
# Wait for the joined message | |
response_text = await asyncio.wait_for( | |
self.websocket.recv(), timeout=5.0 | |
) | |
response = json.loads(response_text) | |
if response.get("type") == "joined": | |
logger.info(f"Successfully joined room {room_id} as {role}") | |
elif response.get("type") == "error": | |
logger.error( | |
f"Server rejected connection: {response.get('message')}" | |
) | |
await self.websocket.close() | |
return False | |
else: | |
logger.warning(f"Unexpected response from server: {response}") | |
elif response.get("type") == "joined": | |
logger.info(f"Successfully joined room {room_id} as {role}") | |
# Connection successful, continue with setup | |
else: | |
logger.warning(f"Unexpected response from server: {response}") | |
except TimeoutError: | |
logger.error("Timeout waiting for server response") | |
await self.websocket.close() | |
return False | |
except json.JSONDecodeError: | |
logger.error("Invalid JSON response from server") | |
await self.websocket.close() | |
return False | |
# Start message handling task | |
self._message_task = asyncio.create_task(self._handle_messages()) | |
self.connected = True | |
logger.info(f"Connected to room {room_id} as {role}") | |
await self._on_connected() | |
# Process initial state sync if we received one | |
if initial_state_sync: | |
await self._process_message(initial_state_sync) | |
return True | |
except Exception as e: | |
logger.error(f"Failed to connect to room {room_id}: {e}") | |
return False | |
async def disconnect(self): | |
"""Disconnect from current room""" | |
if self._message_task: | |
self._message_task.cancel() | |
try: | |
await self._message_task | |
except asyncio.CancelledError: | |
pass | |
self._message_task = None | |
if self.websocket: | |
await self.websocket.close() | |
self.websocket = None | |
self.connected = False | |
self.workspace_id = None | |
self.room_id = None | |
self.role = None | |
self.participant_id = None | |
await self._on_disconnected() | |
logger.info("Disconnected from room") | |
# ============= MESSAGE HANDLING ============= | |
async def _handle_messages(self): | |
"""Handle incoming WebSocket messages""" | |
try: | |
async for message in self.websocket: | |
try: | |
data = json.loads(message) | |
await self._process_message(data) | |
except json.JSONDecodeError: | |
logger.error(f"Invalid JSON received: {message}") | |
except Exception as e: | |
logger.error(f"Error processing message: {e}") | |
except websockets.exceptions.ConnectionClosed: | |
logger.info("WebSocket connection closed") | |
except Exception as e: | |
logger.error(f"WebSocket error: {e}") | |
finally: | |
self.connected = False | |
await self._on_disconnected() | |
async def _process_message(self, data: dict): | |
"""Process incoming message based on type - to be overridden by subclasses""" | |
msg_type = data.get("type") | |
if msg_type == "joined": | |
logger.info( | |
f"Successfully joined room {data.get('room_id')} as {data.get('role')}" | |
) | |
elif msg_type == "heartbeat_ack": | |
logger.debug("Heartbeat acknowledged") | |
else: | |
# Let subclasses handle specific message types | |
await self._handle_role_specific_message(data) | |
async def _handle_role_specific_message(self, data: dict): | |
"""Handle role-specific messages - to be overridden by subclasses""" | |
# ============= UTILITY METHODS ============= | |
async def send_heartbeat(self): | |
"""Send heartbeat to server""" | |
if not self.connected: | |
return | |
message = {"type": "heartbeat"} | |
await self.websocket.send(json.dumps(message)) | |
def is_connected(self) -> bool: | |
"""Check if client is connected""" | |
return self.connected | |
def get_connection_info(self) -> dict: | |
"""Get current connection information""" | |
return { | |
"connected": self.connected, | |
"workspace_id": self.workspace_id, | |
"room_id": self.room_id, | |
"role": self.role, | |
"participant_id": self.participant_id, | |
"base_url": self.base_url, | |
} | |
# ============= HOOKS FOR SUBCLASSES ============= | |
async def _on_connected(self): | |
"""Called when connection is established - to be overridden by subclasses""" | |
async def _on_disconnected(self): | |
"""Called when connection is lost - to be overridden by subclasses""" | |
# ============= CONTEXT MANAGER SUPPORT ============= | |
async def __aenter__(self): | |
return self | |
async def __aexit__(self, exc_type, exc_val, exc_tb): | |
await self.disconnect() | |
# ============= WORKSPACE HELPERS ============= | |
def _generate_workspace_id(self) -> str: | |
"""Generate a UUID-like workspace ID""" | |
import uuid | |
return str(uuid.uuid4()) | |
class RoboticsProducer(RoboticsClientCore): | |
"""Producer client for controlling robots""" | |
def __init__(self, base_url: str = "http://localhost:8000"): | |
super().__init__(base_url) | |
self._on_error_callback: Callable[[str], None] | None = None | |
self._on_connected_callback: Callable[[], None] | None = None | |
self._on_disconnected_callback: Callable[[], None] | None = None | |
async def connect( | |
self, workspace_id: str, room_id: str, participant_id: str | None = None | |
) -> bool: | |
"""Connect as producer to a room""" | |
return await self.connect_to_room( | |
workspace_id, room_id, "producer", participant_id | |
) | |
# ============= PRODUCER METHODS ============= | |
async def send_joint_update(self, joints: list[dict]): | |
"""Send joint updates""" | |
if not self.connected: | |
raise ValueError("Must be connected to send joint updates") | |
message = {"type": "joint_update", "data": joints} | |
await self.websocket.send(json.dumps(message)) | |
async def send_state_sync(self, state: dict): | |
"""Send state synchronization (convert dict to list format)""" | |
joints = [{"name": name, "value": value} for name, value in state.items()] | |
await self.send_joint_update(joints) | |
async def send_emergency_stop(self, reason: str = "Emergency stop"): | |
"""Send emergency stop signal""" | |
if not self.connected: | |
raise ValueError("Must be connected to send emergency stop") | |
message = {"type": "emergency_stop", "reason": reason} | |
await self.websocket.send(json.dumps(message)) | |
# ============= EVENT CALLBACKS ============= | |
def on_error(self, callback: Callable[[str], None]): | |
"""Set callback for error events""" | |
self._on_error_callback = callback | |
def on_connected(self, callback: Callable[[], None]): | |
"""Set callback for connection events""" | |
self._on_connected_callback = callback | |
def on_disconnected(self, callback: Callable[[], None]): | |
"""Set callback for disconnection events""" | |
self._on_disconnected_callback = callback | |
# ============= OVERRIDDEN HOOKS ============= | |
async def _on_connected(self): | |
if self._on_connected_callback: | |
self._on_connected_callback() | |
async def _on_disconnected(self): | |
if self._on_disconnected_callback: | |
self._on_disconnected_callback() | |
async def _handle_role_specific_message(self, data: dict): | |
"""Handle producer-specific messages""" | |
msg_type = data.get("type") | |
if msg_type == "emergency_stop": | |
logger.warning(f"🚨 Emergency stop: {data.get('reason', 'Unknown reason')}") | |
if self._on_error_callback: | |
self._on_error_callback( | |
f"Emergency stop: {data.get('reason', 'Unknown reason')}" | |
) | |
elif msg_type == "error": | |
error_msg = data.get("message", "Unknown error") | |
logger.error(f"Server error: {error_msg}") | |
if self._on_error_callback: | |
self._on_error_callback(error_msg) | |
else: | |
logger.warning(f"Unknown message type for producer: {msg_type}") | |
class RoboticsConsumer(RoboticsClientCore): | |
"""Consumer client for receiving robot commands""" | |
def __init__(self, base_url: str = "http://localhost:8000"): | |
super().__init__(base_url) | |
self._on_state_sync_callback: Callable[[dict], None] | None = None | |
self._on_joint_update_callback: Callable[[list], None] | None = None | |
self._on_error_callback: Callable[[str], None] | None = None | |
self._on_connected_callback: Callable[[], None] | None = None | |
self._on_disconnected_callback: Callable[[], None] | None = None | |
async def connect( | |
self, workspace_id: str, room_id: str, participant_id: str | None = None | |
) -> bool: | |
"""Connect as consumer to a room""" | |
return await self.connect_to_room( | |
workspace_id, room_id, "consumer", participant_id | |
) | |
# ============= CONSUMER METHODS ============= | |
async def get_state_sync(self) -> dict: | |
"""Get current state synchronously""" | |
if not self.workspace_id or not self.room_id: | |
raise ValueError("Must be connected to a room") | |
state = await self.get_room_state(self.workspace_id, self.room_id) | |
return state.get("joints", {}) | |
# ============= EVENT CALLBACKS ============= | |
def on_state_sync(self, callback: Callable[[dict], None]): | |
"""Set callback for state synchronization events""" | |
self._on_state_sync_callback = callback | |
def on_joint_update(self, callback: Callable[[list], None]): | |
"""Set callback for joint update events""" | |
self._on_joint_update_callback = callback | |
def on_error(self, callback: Callable[[str], None]): | |
"""Set callback for error events""" | |
self._on_error_callback = callback | |
def on_connected(self, callback: Callable[[], None]): | |
"""Set callback for connection events""" | |
self._on_connected_callback = callback | |
def on_disconnected(self, callback: Callable[[], None]): | |
"""Set callback for disconnection events""" | |
self._on_disconnected_callback = callback | |
# ============= OVERRIDDEN HOOKS ============= | |
async def _on_connected(self): | |
if self._on_connected_callback: | |
self._on_connected_callback() | |
async def _on_disconnected(self): | |
if self._on_disconnected_callback: | |
self._on_disconnected_callback() | |
async def _handle_role_specific_message(self, data: dict): | |
"""Handle consumer-specific messages""" | |
msg_type = data.get("type") | |
if msg_type == "state_sync": | |
if self._on_state_sync_callback: | |
self._on_state_sync_callback(data.get("data", {})) | |
elif msg_type == "joint_update": | |
if self._on_joint_update_callback: | |
self._on_joint_update_callback(data.get("data", [])) | |
elif msg_type == "emergency_stop": | |
logger.warning(f"🚨 Emergency stop: {data.get('reason', 'Unknown reason')}") | |
if self._on_error_callback: | |
self._on_error_callback( | |
f"Emergency stop: {data.get('reason', 'Unknown reason')}" | |
) | |
elif msg_type == "error": | |
error_msg = data.get("message", "Unknown error") | |
logger.error(f"Server error: {error_msg}") | |
if self._on_error_callback: | |
self._on_error_callback(error_msg) | |
else: | |
logger.warning(f"Unknown message type for consumer: {msg_type}") | |
# ============= FACTORY FUNCTIONS ============= | |
def create_client(role: str, base_url: str = "http://localhost:8000"): | |
"""Factory function to create the appropriate client based on role""" | |
if role == "producer": | |
return RoboticsProducer(base_url) | |
if role == "consumer": | |
return RoboticsConsumer(base_url) | |
raise ValueError(f"Invalid role: {role}. Must be 'producer' or 'consumer'") | |
async def create_producer_client(
    base_url: str = "http://localhost:8000",
    workspace_id: str | None = None,
    room_id: str | None = None,
    participant_id: str | None = None,
) -> RoboticsProducer:
    """Create a room over REST and return a producer connected to it.

    Args:
        base_url: Server root.
        workspace_id: Workspace for the new room; generated by ``create_room``
            when None.
        room_id: Requested room id; omitted from the request when None, and
            the id returned by the server is used.
        participant_id: Optional explicit participant id for the producer.

    Returns:
        The producer client. ``connect`` reports failure by returning False
        rather than raising, so a failed connection is logged here and can be
        detected with ``client.is_connected()``.

    Raises:
        aiohttp.ClientResponseError: propagated from ``create_room`` when room
            creation fails.
    """
    client = RoboticsProducer(base_url)
    workspace_id, room_id = await client.create_room(workspace_id, room_id)
    # The original silently ignored this result and returned a dead client.
    if not await client.connect(workspace_id, room_id, participant_id):
        logger.warning(
            f"Producer client created room {room_id} in workspace "
            f"{workspace_id} but failed to connect"
        )
    return client
async def create_consumer_client(
    workspace_id: str,
    room_id: str,
    base_url: str = "http://localhost:8000",
    participant_id: str | None = None,
) -> RoboticsConsumer:
    """Return a consumer connected to an existing room.

    Args:
        workspace_id: Workspace containing the room.
        room_id: Room to join.
        base_url: Server root.
        participant_id: Optional explicit participant id for the consumer.

    Returns:
        The consumer client. ``connect`` reports failure by returning False
        rather than raising, so a failed connection is logged here and can be
        detected with ``client.is_connected()``.
    """
    client = RoboticsConsumer(base_url)
    # The original silently ignored this result and returned a dead client.
    if not await client.connect(workspace_id, room_id, participant_id):
        logger.warning(
            f"Consumer client failed to connect to room {room_id} in workspace "
            f"{workspace_id}"
        )
    return client