# blanchon's picture
# Initial commit
# 02eac4b
import asyncio
import json
import logging
from collections.abc import Callable
from urllib.parse import urlparse
import aiohttp
import websockets
logger = logging.getLogger(__name__)
class RoboticsClientCore:
"""Base client for LeRobot Arena robotics API"""
def __init__(self, base_url: str = "http://localhost:8000"):
self.base_url = base_url.rstrip("/")
self.api_base = f"{self.base_url}/robotics"
# WebSocket connection
self.websocket: websockets.WebSocketServerProtocol | None = None
self.workspace_id: str | None = None
self.room_id: str | None = None
self.role: str | None = None
self.participant_id: str | None = None
self.connected = False
# Background task for message handling
self._message_task: asyncio.Task | None = None
# ============= REST API METHODS =============
async def list_rooms(self, workspace_id: str) -> list[dict]:
"""List all available rooms in a workspace"""
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.api_base}/workspaces/{workspace_id}/rooms"
) as response:
response.raise_for_status()
result = await response.json()
# Extract the rooms list from the response
return result.get("rooms", [])
async def create_room(
self, workspace_id: str | None = None, room_id: str | None = None
) -> tuple[str, str]:
"""Create a new room and return (workspace_id, room_id)"""
# Generate workspace ID if not provided
final_workspace_id = workspace_id or self._generate_workspace_id()
payload = {}
if room_id:
payload["room_id"] = room_id
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.api_base}/workspaces/{final_workspace_id}/rooms", json=payload
) as response:
response.raise_for_status()
result = await response.json()
return result["workspace_id"], result["room_id"]
async def delete_room(self, workspace_id: str, room_id: str) -> bool:
"""Delete a room"""
async with aiohttp.ClientSession() as session:
async with session.delete(
f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}"
) as response:
if response.status == 404:
return False
response.raise_for_status()
result = await response.json()
return result["success"]
async def get_room_state(self, workspace_id: str, room_id: str) -> dict:
"""Get current room state"""
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}/state"
) as response:
response.raise_for_status()
result = await response.json()
# Extract the state from the response
return result.get("state", {})
async def get_room_info(self, workspace_id: str, room_id: str) -> dict:
"""Get basic room information"""
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}"
) as response:
response.raise_for_status()
result = await response.json()
# Extract the room data from the response
return result.get("room", {})
# ============= WEBSOCKET CONNECTION =============
async def connect_to_room(
self,
workspace_id: str,
room_id: str,
role: str,
participant_id: str | None = None,
) -> bool:
"""Connect to a room as producer or consumer"""
if self.connected:
await self.disconnect()
self.workspace_id = workspace_id
self.room_id = room_id
self.role = role
self.participant_id = participant_id or f"{role}_{id(self)}"
# Convert HTTP URL to WebSocket URL
parsed = urlparse(self.base_url)
ws_scheme = "wss" if parsed.scheme == "https" else "ws"
ws_url = f"{ws_scheme}://{parsed.netloc}/robotics/workspaces/{workspace_id}/rooms/{room_id}/ws"
initial_state_sync = None
try:
self.websocket = await websockets.connect(ws_url)
# Send join message
join_message = {"participant_id": self.participant_id, "role": role}
await self.websocket.send(json.dumps(join_message))
# Wait for server response to join message
try:
response_text = await asyncio.wait_for(
self.websocket.recv(), timeout=5.0
)
response = json.loads(response_text)
if response.get("type") == "error":
logger.error(
f"Server rejected connection: {response.get('message')}"
)
await self.websocket.close()
return False
if response.get("type") == "state_sync":
# Consumer receives initial state sync, store it and wait for joined message
logger.debug("Received initial state sync")
initial_state_sync = response
# Wait for the joined message
response_text = await asyncio.wait_for(
self.websocket.recv(), timeout=5.0
)
response = json.loads(response_text)
if response.get("type") == "joined":
logger.info(f"Successfully joined room {room_id} as {role}")
elif response.get("type") == "error":
logger.error(
f"Server rejected connection: {response.get('message')}"
)
await self.websocket.close()
return False
else:
logger.warning(f"Unexpected response from server: {response}")
elif response.get("type") == "joined":
logger.info(f"Successfully joined room {room_id} as {role}")
# Connection successful, continue with setup
else:
logger.warning(f"Unexpected response from server: {response}")
except TimeoutError:
logger.error("Timeout waiting for server response")
await self.websocket.close()
return False
except json.JSONDecodeError:
logger.error("Invalid JSON response from server")
await self.websocket.close()
return False
# Start message handling task
self._message_task = asyncio.create_task(self._handle_messages())
self.connected = True
logger.info(f"Connected to room {room_id} as {role}")
await self._on_connected()
# Process initial state sync if we received one
if initial_state_sync:
await self._process_message(initial_state_sync)
return True
except Exception as e:
logger.error(f"Failed to connect to room {room_id}: {e}")
return False
async def disconnect(self):
"""Disconnect from current room"""
if self._message_task:
self._message_task.cancel()
try:
await self._message_task
except asyncio.CancelledError:
pass
self._message_task = None
if self.websocket:
await self.websocket.close()
self.websocket = None
self.connected = False
self.workspace_id = None
self.room_id = None
self.role = None
self.participant_id = None
await self._on_disconnected()
logger.info("Disconnected from room")
# ============= MESSAGE HANDLING =============
async def _handle_messages(self):
"""Handle incoming WebSocket messages"""
try:
async for message in self.websocket:
try:
data = json.loads(message)
await self._process_message(data)
except json.JSONDecodeError:
logger.error(f"Invalid JSON received: {message}")
except Exception as e:
logger.error(f"Error processing message: {e}")
except websockets.exceptions.ConnectionClosed:
logger.info("WebSocket connection closed")
except Exception as e:
logger.error(f"WebSocket error: {e}")
finally:
self.connected = False
await self._on_disconnected()
async def _process_message(self, data: dict):
"""Process incoming message based on type - to be overridden by subclasses"""
msg_type = data.get("type")
if msg_type == "joined":
logger.info(
f"Successfully joined room {data.get('room_id')} as {data.get('role')}"
)
elif msg_type == "heartbeat_ack":
logger.debug("Heartbeat acknowledged")
else:
# Let subclasses handle specific message types
await self._handle_role_specific_message(data)
async def _handle_role_specific_message(self, data: dict):
"""Handle role-specific messages - to be overridden by subclasses"""
# ============= UTILITY METHODS =============
async def send_heartbeat(self):
"""Send heartbeat to server"""
if not self.connected:
return
message = {"type": "heartbeat"}
await self.websocket.send(json.dumps(message))
def is_connected(self) -> bool:
"""Check if client is connected"""
return self.connected
def get_connection_info(self) -> dict:
"""Get current connection information"""
return {
"connected": self.connected,
"workspace_id": self.workspace_id,
"room_id": self.room_id,
"role": self.role,
"participant_id": self.participant_id,
"base_url": self.base_url,
}
# ============= HOOKS FOR SUBCLASSES =============
async def _on_connected(self):
"""Called when connection is established - to be overridden by subclasses"""
async def _on_disconnected(self):
"""Called when connection is lost - to be overridden by subclasses"""
# ============= CONTEXT MANAGER SUPPORT =============
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.disconnect()
# ============= WORKSPACE HELPERS =============
def _generate_workspace_id(self) -> str:
"""Generate a UUID-like workspace ID"""
import uuid
return str(uuid.uuid4())
class RoboticsProducer(RoboticsClientCore):
    """Client that joins a room with the ``"producer"`` role.

    Sends joint updates and emergency-stop messages over the WebSocket and
    forwards server-side errors and emergency stops to an optional callback.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Initialise the base client and leave every user callback unset."""
        super().__init__(base_url)
        # User-registered callbacks; each stays None until its setter is called.
        self._on_error_callback: Callable[[str], None] | None = None
        self._on_connected_callback: Callable[[], None] | None = None
        self._on_disconnected_callback: Callable[[], None] | None = None

    async def connect(
        self, workspace_id: str, room_id: str, participant_id: str | None = None
    ) -> bool:
        """Join the given room as a producer and return the connection result."""
        joined = await self.connect_to_room(
            workspace_id, room_id, "producer", participant_id
        )
        return joined

    # ============= PRODUCER METHODS =============

    async def send_joint_update(self, joints: list[dict]):
        """Send a ``joint_update`` message whose ``data`` is *joints*.

        Raises:
            ValueError: If the client is not connected.
        """
        if self.connected:
            payload = json.dumps({"type": "joint_update", "data": joints})
            await self.websocket.send(payload)
            return
        raise ValueError("Must be connected to send joint updates")

    async def send_state_sync(self, state: dict):
        """Send *state* as a joint update.

        Each ``name: value`` item becomes ``{"name": name, "value": value}``.
        """
        joints = []
        for joint_name, joint_value in state.items():
            joints.append({"name": joint_name, "value": joint_value})
        await self.send_joint_update(joints)

    async def send_emergency_stop(self, reason: str = "Emergency stop"):
        """Send an ``emergency_stop`` message carrying *reason*.

        Raises:
            ValueError: If the client is not connected.
        """
        if self.connected:
            payload = json.dumps({"type": "emergency_stop", "reason": reason})
            await self.websocket.send(payload)
            return
        raise ValueError("Must be connected to send emergency stop")

    # ============= EVENT CALLBACKS =============

    def on_error(self, callback: Callable[[str], None]):
        """Register the function called with an error description."""
        self._on_error_callback = callback

    def on_connected(self, callback: Callable[[], None]):
        """Register the function called once a connection is established."""
        self._on_connected_callback = callback

    def on_disconnected(self, callback: Callable[[], None]):
        """Register the function called when the connection ends."""
        self._on_disconnected_callback = callback

    # ============= OVERRIDDEN HOOKS =============

    async def _on_connected(self):
        """Invoke the registered connection callback, if any."""
        handler = self._on_connected_callback
        if handler:
            handler()

    async def _on_disconnected(self):
        """Invoke the registered disconnection callback, if any."""
        handler = self._on_disconnected_callback
        if handler:
            handler()

    async def _handle_role_specific_message(self, data: dict):
        """Handle producer-specific messages.

        ``emergency_stop`` and ``error`` are logged and passed to the error
        callback when one is registered; any other type is logged as unknown.
        """
        kind = data.get("type")

        if kind == "emergency_stop":
            reason = data.get("reason", "Unknown reason")
            logger.warning(f"🚨 Emergency stop: {reason}")
            if self._on_error_callback:
                self._on_error_callback(f"Emergency stop: {reason}")
            return

        if kind == "error":
            error_msg = data.get("message", "Unknown error")
            logger.error(f"Server error: {error_msg}")
            if self._on_error_callback:
                self._on_error_callback(error_msg)
            return

        logger.warning(f"Unknown message type for producer: {kind}")
class RoboticsConsumer(RoboticsClientCore):
    """Client that joins a room with the ``"consumer"`` role.

    Dispatches incoming ``state_sync`` and ``joint_update`` messages to
    optional callbacks and reports errors and emergency stops through an
    error callback.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Initialise the base client and leave every user callback unset."""
        super().__init__(base_url)
        # User-registered callbacks; each stays None until its setter is called.
        self._on_state_sync_callback: Callable[[dict], None] | None = None
        self._on_joint_update_callback: Callable[[list], None] | None = None
        self._on_error_callback: Callable[[str], None] | None = None
        self._on_connected_callback: Callable[[], None] | None = None
        self._on_disconnected_callback: Callable[[], None] | None = None

    async def connect(
        self, workspace_id: str, room_id: str, participant_id: str | None = None
    ) -> bool:
        """Join the given room as a consumer and return the connection result."""
        joined = await self.connect_to_room(
            workspace_id, room_id, "consumer", participant_id
        )
        return joined

    # ============= CONSUMER METHODS =============

    async def get_state_sync(self) -> dict:
        """Fetch the room state over REST and return its ``"joints"`` entry.

        Returns ``{}`` when the state has no ``"joints"`` key.

        Raises:
            ValueError: If no workspace or room id is currently set.
        """
        if not (self.workspace_id and self.room_id):
            raise ValueError("Must be connected to a room")

        room_state = await self.get_room_state(self.workspace_id, self.room_id)
        return room_state.get("joints", {})

    # ============= EVENT CALLBACKS =============

    def on_state_sync(self, callback: Callable[[dict], None]):
        """Register the function called with each ``state_sync`` payload."""
        self._on_state_sync_callback = callback

    def on_joint_update(self, callback: Callable[[list], None]):
        """Register the function called with each ``joint_update`` payload."""
        self._on_joint_update_callback = callback

    def on_error(self, callback: Callable[[str], None]):
        """Register the function called with an error description."""
        self._on_error_callback = callback

    def on_connected(self, callback: Callable[[], None]):
        """Register the function called once a connection is established."""
        self._on_connected_callback = callback

    def on_disconnected(self, callback: Callable[[], None]):
        """Register the function called when the connection ends."""
        self._on_disconnected_callback = callback

    # ============= OVERRIDDEN HOOKS =============

    async def _on_connected(self):
        """Invoke the registered connection callback, if any."""
        handler = self._on_connected_callback
        if handler:
            handler()

    async def _on_disconnected(self):
        """Invoke the registered disconnection callback, if any."""
        handler = self._on_disconnected_callback
        if handler:
            handler()

    async def _handle_role_specific_message(self, data: dict):
        """Handle consumer-specific messages.

        ``state_sync`` and ``joint_update`` payloads (the ``"data"`` entry) go
        to their callbacks when registered. ``emergency_stop`` and ``error``
        are logged and passed to the error callback when one is registered.
        Any other type is logged as unknown.
        """
        kind = data.get("type")

        if kind == "state_sync":
            if self._on_state_sync_callback:
                self._on_state_sync_callback(data.get("data", {}))
            return

        if kind == "joint_update":
            if self._on_joint_update_callback:
                self._on_joint_update_callback(data.get("data", []))
            return

        if kind == "emergency_stop":
            reason = data.get("reason", "Unknown reason")
            logger.warning(f"🚨 Emergency stop: {reason}")
            if self._on_error_callback:
                self._on_error_callback(f"Emergency stop: {reason}")
            return

        if kind == "error":
            error_msg = data.get("message", "Unknown error")
            logger.error(f"Server error: {error_msg}")
            if self._on_error_callback:
                self._on_error_callback(error_msg)
            return

        logger.warning(f"Unknown message type for consumer: {kind}")
# ============= FACTORY FUNCTIONS =============
def create_client(role: str, base_url: str = "http://localhost:8000"):
    """Build a ``RoboticsProducer`` or ``RoboticsConsumer`` for *role*.

    Args:
        role: Either ``"producer"`` or ``"consumer"``.
        base_url: Server address handed to the client constructor.

    Raises:
        ValueError: If *role* is neither ``"producer"`` nor ``"consumer"``.
    """
    if role not in ("producer", "consumer"):
        raise ValueError(f"Invalid role: {role}. Must be 'producer' or 'consumer'")

    client_cls = RoboticsProducer if role == "producer" else RoboticsConsumer
    return client_cls(base_url)
async def create_producer_client(
    base_url: str = "http://localhost:8000",
    workspace_id: str | None = None,
    room_id: str | None = None,
) -> RoboticsProducer:
    """Create a room, then connect a new producer client to it.

    The ids passed on to ``connect`` are the ones returned by ``create_room``.
    The result of ``connect`` is not inspected; the client is returned either
    way.
    """
    producer = RoboticsProducer(base_url)
    created_workspace, created_room = await producer.create_room(
        workspace_id, room_id
    )
    await producer.connect(created_workspace, created_room)
    return producer
async def create_consumer_client(
    workspace_id: str, room_id: str, base_url: str = "http://localhost:8000"
) -> RoboticsConsumer:
    """Connect a new consumer client to an existing room.

    The result of ``connect`` is not inspected; the client is returned either
    way.
    """
    consumer = RoboticsConsumer(base_url)
    await consumer.connect(workspace_id, room_id)
    return consumer