# blanchon's picture
# Initial commit
# 02eac4b
import asyncio
import pytest
from lerobot_arena_client import RoboticsConsumer
class TestRoboticsConsumer:
    """Test RoboticsConsumer functionality.

    The ``consumer``, ``test_room``, ``connected_consumer`` and
    ``producer_consumer_pair`` arguments are pytest fixtures that are defined
    outside this file.
    """

    # Address passed to consumers that the tests construct directly; also the
    # value expected in the ``base_url`` field of the connection info.
    BASE_URL = "http://localhost:8000"

    @pytest.mark.asyncio
    async def test_consumer_connection(self, consumer, test_room):
        """Test basic consumer connection and disconnection."""
        assert not consumer.is_connected()

        success = await consumer.connect(test_room)
        assert success is True
        assert consumer.is_connected()
        assert consumer.room_id == test_room
        assert consumer.role == "consumer"

        await consumer.disconnect()
        assert not consumer.is_connected()

    @pytest.mark.asyncio
    async def test_consumer_connection_info(self, connected_consumer):
        """Test getting connection information."""
        consumer, room_id = connected_consumer

        info = consumer.get_connection_info()
        assert info["connected"] is True
        assert info["room_id"] == room_id
        assert info["role"] == "consumer"
        assert info["participant_id"] is not None
        assert info["base_url"] == self.BASE_URL

    @pytest.mark.asyncio
    async def test_get_state_sync(self, connected_consumer):
        """Test getting current state synchronously."""
        consumer, _room_id = connected_consumer

        state = await consumer.get_state_sync()
        assert isinstance(state, dict)
        # Initial state should be empty
        assert len(state) == 0

    @pytest.mark.asyncio
    async def test_consumer_callbacks_setup(self, consumer, test_room):
        """Test setting up consumer callbacks.

        All five callbacks are registered, but only the connected and
        disconnected callbacks are asserted on.
        """
        # Names of the callbacks that have fired so far.
        fired = set()

        def on_state_sync(state):
            fired.add("state_sync")

        def on_joint_update(joints):
            fired.add("joint_update")

        def on_error(error):
            fired.add("error")

        def on_connected():
            fired.add("connected")

        def on_disconnected():
            fired.add("disconnected")

        # Set all callbacks
        consumer.on_state_sync(on_state_sync)
        consumer.on_joint_update(on_joint_update)
        consumer.on_error(on_error)
        consumer.on_connected(on_connected)
        consumer.on_disconnected(on_disconnected)

        # Connect and test connection callbacks
        await consumer.connect(test_room)
        try:
            await asyncio.sleep(0.1)
            assert "connected" in fired
        finally:
            # Disconnect even when the assertion above fails so the
            # connection is not left open.
            await consumer.disconnect()

        await asyncio.sleep(0.1)
        assert "disconnected" in fired

    @pytest.mark.asyncio
    async def test_multiple_consumers(self, test_room):
        """Test multiple consumers connecting to same room."""
        consumer1 = RoboticsConsumer(self.BASE_URL)
        consumer2 = RoboticsConsumer(self.BASE_URL)

        try:
            # Both consumers should be able to connect
            success1 = await consumer1.connect(test_room)
            success2 = await consumer2.connect(test_room)

            assert success1 is True
            assert success2 is True
            assert consumer1.is_connected()
            assert consumer2.is_connected()
        finally:
            for candidate in (consumer1, consumer2):
                if candidate.is_connected():
                    await candidate.disconnect()

    @pytest.mark.asyncio
    async def test_consumer_receive_state_sync(self, producer_consumer_pair):
        """Test consumer receiving state sync from producer."""
        producer, consumer, _room_id = producer_consumer_pair

        received_states = []
        received_updates = []

        def on_state_sync(state):
            received_states.append(state)

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_state_sync(on_state_sync)
        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        # Producer sends state sync (which gets converted to joint updates)
        await producer.send_state_sync({"shoulder": 45.0, "elbow": -20.0})

        # Wait for message to be received
        await asyncio.sleep(0.2)

        # Consumer should have received the joint updates from the state sync.
        # The initial state sync during connection might be empty, so only
        # joint updates are checked.
        assert len(received_updates) >= 1

    @pytest.mark.asyncio
    async def test_consumer_receive_joint_updates(self, producer_consumer_pair):
        """Test consumer receiving joint updates from producer."""
        producer, consumer, _room_id = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        # Producer sends joint updates
        test_joints = [
            {"name": "shoulder", "value": 45.0},
            {"name": "elbow", "value": -20.0},
        ]
        await producer.send_joint_update(test_joints)

        # Wait for message to be received
        await asyncio.sleep(0.2)

        # Consumer should have received the joint update
        assert len(received_updates) >= 1

        # The assert above guarantees at least one entry, so the latest
        # update can be inspected unconditionally.
        received_joints = received_updates[-1]
        assert isinstance(received_joints, list)
        assert len(received_joints) == 2

    @pytest.mark.asyncio
    async def test_consumer_multiple_updates(self, producer_consumer_pair):
        """Test consumer receiving multiple updates."""
        producer, consumer, _room_id = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        # Send multiple updates
        for i in range(5):
            await producer.send_state_sync({
                "joint1": float(i * 10),
                "joint2": float(i * -5),
            })
            await asyncio.sleep(0.05)

        # Wait for all messages to be received
        await asyncio.sleep(0.3)

        # Should have received multiple updates (at least 3 of the 5 sent)
        assert len(received_updates) >= 3

    @pytest.mark.asyncio
    async def test_consumer_emergency_stop(self, producer_consumer_pair):
        """Test consumer receiving emergency stop."""
        producer, consumer, _room_id = producer_consumer_pair

        received_errors = []

        def on_error(error):
            received_errors.append(error)

        consumer.on_error(on_error)

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        # Producer sends emergency stop
        await producer.send_emergency_stop("Test emergency stop")

        # Wait for message to be received
        await asyncio.sleep(0.2)

        # Consumer should have received emergency stop as error
        assert len(received_errors) >= 1

        # The assert above guarantees at least one entry.
        assert "emergency stop" in received_errors[-1].lower()

    @pytest.mark.asyncio
    async def test_custom_participant_id(self, consumer, test_room):
        """Test connecting with custom participant ID."""
        custom_id = "custom-consumer-456"

        await consumer.connect(test_room, participant_id=custom_id)
        try:
            info = consumer.get_connection_info()
            assert info["participant_id"] == custom_id
        finally:
            # Do not leave the connection open once the test is over.
            if consumer.is_connected():
                await consumer.disconnect()

    @pytest.mark.asyncio
    async def test_context_manager(self, test_room):
        """Test using consumer as context manager."""
        async with RoboticsConsumer(self.BASE_URL) as consumer:
            await consumer.connect(test_room)
            assert consumer.is_connected()

            state = await consumer.get_state_sync()
            assert isinstance(state, dict)

        # Should be disconnected after context exit
        assert not consumer.is_connected()

    @pytest.mark.asyncio
    async def test_get_state_without_connection(self, consumer):
        """Test getting state without being connected."""
        assert not consumer.is_connected()

        with pytest.raises(ValueError, match="Must be connected to a room"):
            await consumer.get_state_sync()

    @pytest.mark.asyncio
    async def test_consumer_reconnection(self, consumer, test_room):
        """Test consumer reconnecting to same room."""
        # First connection
        await consumer.connect(test_room)
        assert consumer.is_connected()

        await consumer.disconnect()
        assert not consumer.is_connected()

        # Reconnect to same room
        await consumer.connect(test_room)
        try:
            assert consumer.is_connected()
            assert consumer.room_id == test_room
        finally:
            # Close the second connection as well; the original flow left it
            # open when the test returned.
            if consumer.is_connected():
                await consumer.disconnect()

    @pytest.mark.asyncio
    async def test_consumer_state_after_producer_updates(self, producer_consumer_pair):
        """Test that consumer can get updated state after producer sends updates."""
        producer, consumer, _room_id = producer_consumer_pair

        # Give some time for connection to stabilize
        await asyncio.sleep(0.1)

        # Producer sends some state updates
        await producer.send_state_sync({
            "shoulder": 45.0,
            "elbow": -20.0,
            "wrist": 10.0,
        })

        # Wait for state to propagate
        await asyncio.sleep(0.2)

        # Consumer should be able to get updated state
        state = await consumer.get_state_sync()
        assert isinstance(state, dict)

        # State should contain the joints we sent
        expected_joints = {"shoulder", "elbow", "wrist"}
        # NOTE(review): an empty state is tolerated here, so this test still
        # passes when nothing propagated - confirm that leniency is intended.
        if state:  # Only check if state is not empty
            assert set(state.keys()) == expected_joints