# Spaces:
# Running
# Running
import asyncio

import pytest

from lerobot_arena_client import RoboticsConsumer
# Upper bound (seconds) on how long a test waits for an asynchronous event
# (callback firing, message delivery) before failing. Polling up to this bound
# instead of sleeping a fixed amount keeps fast runs fast and slow CI stable.
_EVENT_TIMEOUT = 2.0


async def _wait_until(predicate, timeout=_EVENT_TIMEOUT, interval=0.02):
    """Poll ``predicate`` until it returns a truthy value or ``timeout`` elapses.

    Args:
        predicate: Zero-argument callable, re-evaluated every ``interval``
            seconds while yielding to the event loop.
        timeout: Maximum number of seconds to wait.
        interval: Seconds to sleep between evaluations.

    Returns:
        True if ``predicate`` became truthy within ``timeout``, otherwise
        False, so callers can write ``assert await _wait_until(...)``.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while not predicate():
        if loop.time() >= deadline:
            return False
        # Yield so the client's receive loop can deliver pending messages.
        await asyncio.sleep(interval)
    return True


class TestRoboticsConsumer:
    """Test RoboticsConsumer functionality.

    These are integration tests against a server at ``http://localhost:8000``.
    The ``consumer``, ``connected_consumer``, ``producer_consumer_pair`` and
    ``test_room`` fixtures are defined outside this file (presumably in
    ``conftest.py``), and the ``async def`` tests assume an asyncio pytest
    plugin is configured -- TODO confirm.
    """

    async def test_consumer_connection(self, consumer, test_room):
        """Test basic consumer connection."""
        assert not consumer.is_connected()

        success = await consumer.connect(test_room)
        assert success is True
        assert consumer.is_connected()
        assert consumer.room_id == test_room
        assert consumer.role == "consumer"

        await consumer.disconnect()
        assert not consumer.is_connected()

    async def test_consumer_connection_info(self, connected_consumer):
        """Test getting connection information."""
        consumer, room_id = connected_consumer

        info = consumer.get_connection_info()
        assert info["connected"] is True
        assert info["room_id"] == room_id
        assert info["role"] == "consumer"
        assert info["participant_id"] is not None
        assert info["base_url"] == "http://localhost:8000"

    async def test_get_state_sync(self, connected_consumer):
        """Test getting current state synchronously."""
        consumer, _ = connected_consumer

        state = await consumer.get_state_sync()
        assert isinstance(state, dict)
        # A freshly created room has no joint state yet.
        assert len(state) == 0

    async def test_consumer_callbacks_setup(self, consumer, test_room):
        """Test setting up consumer callbacks."""
        state_sync_called = False
        joint_update_called = False
        error_called = False
        connected_called = False
        disconnected_called = False

        def on_state_sync(state):
            nonlocal state_sync_called
            state_sync_called = True

        def on_joint_update(joints):
            nonlocal joint_update_called
            joint_update_called = True

        def on_error(error):
            nonlocal error_called
            error_called = True

        def on_connected():
            nonlocal connected_called
            connected_called = True

        def on_disconnected():
            nonlocal disconnected_called
            disconnected_called = True

        # Register every callback type, even though only the connection
        # callbacks are asserted on below.
        consumer.on_state_sync(on_state_sync)
        consumer.on_joint_update(on_joint_update)
        consumer.on_error(on_error)
        consumer.on_connected(on_connected)
        consumer.on_disconnected(on_disconnected)

        await consumer.connect(test_room)
        try:
            # The lambdas read the enclosing flags at call time (closure cells).
            assert await _wait_until(lambda: connected_called)
        finally:
            await consumer.disconnect()

        assert await _wait_until(lambda: disconnected_called)

    async def test_multiple_consumers(self, test_room):
        """Test multiple consumers connecting to same room."""
        consumer1 = RoboticsConsumer("http://localhost:8000")
        consumer2 = RoboticsConsumer("http://localhost:8000")

        try:
            # Both consumers should be able to connect
            success1 = await consumer1.connect(test_room)
            success2 = await consumer2.connect(test_room)

            assert success1 is True
            assert success2 is True
            assert consumer1.is_connected()
            assert consumer2.is_connected()
        finally:
            if consumer1.is_connected():
                await consumer1.disconnect()
            if consumer2.is_connected():
                await consumer2.disconnect()

    async def test_consumer_receive_state_sync(self, producer_consumer_pair):
        """Test consumer receiving state sync from producer."""
        producer, consumer, _ = producer_consumer_pair

        received_states = []
        received_updates = []

        def on_state_sync(state):
            received_states.append(state)

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_state_sync(on_state_sync)
        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        # Only count updates caused by the send below.
        baseline = len(received_updates)

        # Producer sends state sync (which gets converted to joint updates)
        await producer.send_state_sync({"shoulder": 45.0, "elbow": -20.0})

        # The initial state sync during connection might be empty, so assert
        # on the joint updates derived from this state sync instead.
        assert await _wait_until(lambda: len(received_updates) > baseline)

    async def test_consumer_receive_joint_updates(self, producer_consumer_pair):
        """Test consumer receiving joint updates from producer."""
        producer, consumer, _ = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        test_joints = [
            {"name": "shoulder", "value": 45.0},
            {"name": "elbow", "value": -20.0},
        ]
        baseline = len(received_updates)
        await producer.send_joint_update(test_joints)

        assert await _wait_until(lambda: len(received_updates) > baseline)
        # The assert above guarantees the list is non-empty.
        received_joints = received_updates[-1]
        assert isinstance(received_joints, list)
        assert len(received_joints) == 2

    async def test_consumer_multiple_updates(self, producer_consumer_pair):
        """Test consumer receiving multiple updates."""
        producer, consumer, _ = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        baseline = len(received_updates)

        # Send multiple updates, lightly paced
        for i in range(5):
            await producer.send_state_sync({
                "joint1": float(i * 10),
                "joint2": float(i * -5),
            })
            await asyncio.sleep(0.05)

        # Should have received multiple updates
        assert await _wait_until(lambda: len(received_updates) - baseline >= 3)

    async def test_consumer_emergency_stop(self, producer_consumer_pair):
        """Test consumer receiving emergency stop."""
        producer, consumer, _ = producer_consumer_pair

        received_errors = []

        def on_error(error):
            received_errors.append(error)

        consumer.on_error(on_error)

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        baseline = len(received_errors)
        await producer.send_emergency_stop("Test emergency stop")

        # The consumer surfaces the emergency stop through its error callback.
        assert await _wait_until(lambda: len(received_errors) > baseline)
        assert "emergency stop" in received_errors[-1].lower()

    async def test_custom_participant_id(self, consumer, test_room):
        """Test connecting with custom participant ID."""
        custom_id = "custom-consumer-456"

        await consumer.connect(test_room, participant_id=custom_id)

        info = consumer.get_connection_info()
        assert info["participant_id"] == custom_id

    async def test_context_manager(self, test_room):
        """Test using consumer as context manager."""
        async with RoboticsConsumer("http://localhost:8000") as consumer:
            await consumer.connect(test_room)
            assert consumer.is_connected()

            state = await consumer.get_state_sync()
            assert isinstance(state, dict)

        # Should be disconnected after context exit
        assert not consumer.is_connected()

    async def test_get_state_without_connection(self, consumer):
        """Test getting state without being connected."""
        assert not consumer.is_connected()

        with pytest.raises(ValueError, match="Must be connected to a room"):
            await consumer.get_state_sync()

    async def test_consumer_reconnection(self, consumer, test_room):
        """Test consumer reconnecting to same room."""
        # First connection
        await consumer.connect(test_room)
        assert consumer.is_connected()

        await consumer.disconnect()
        assert not consumer.is_connected()

        # Reconnect to same room
        await consumer.connect(test_room)
        assert consumer.is_connected()
        assert consumer.room_id == test_room

    async def test_consumer_state_after_producer_updates(self, producer_consumer_pair):
        """Test that consumer can get updated state after producer sends updates."""
        producer, consumer, _ = producer_consumer_pair

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        await producer.send_state_sync({
            "shoulder": 45.0,
            "elbow": -20.0,
            "wrist": 10.0,
        })

        expected_joints = {"shoulder", "elbow", "wrist"}

        # Poll the server until every sent joint is visible or we time out.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + _EVENT_TIMEOUT
        state = await consumer.get_state_sync()
        while not expected_joints.issubset(state) and loop.time() < deadline:
            await asyncio.sleep(0.05)
            state = await consumer.get_state_sync()

        assert isinstance(state, dict)
        # NOTE(review): this check is skipped when the returned state is
        # empty, so the test also passes if the server never records the
        # update. Tighten once the server's state semantics are confirmed.
        if state:
            assert set(state.keys()) == expected_joints